├── roadmap
├── __init__.py
├── utils
│ ├── logger.py
│ ├── freeze_pos_embedding.py
│ ├── count_parameters.py
│ ├── set_lr.py
│ ├── expand_path.py
│ ├── set_initial_lr.py
│ ├── get_lr.py
│ ├── list_or_tuple.py
│ ├── format_time.py
│ ├── freeze_batch_norm.py
│ ├── str_to_bool.py
│ ├── rgb_to_bgr.py
│ ├── get_gradient_norm.py
│ ├── override_config.py
│ ├── average_meter.py
│ ├── extract_progress.py
│ ├── moving_average.py
│ ├── create_label_matrix.py
│ ├── dict_average.py
│ ├── get_set_random_state.py
│ └── __init__.py
├── models
│ ├── __init__.py
│ ├── create_projection_head.py
│ └── net.py
├── config
│ ├── loss
│ │ ├── fastap.yaml
│ │ ├── smoothap.yaml
│ │ ├── pair_loss.yaml
│ │ ├── blackboxap.yaml
│ │ ├── softbinap.yaml
│ │ ├── affineap.yaml
│ │ ├── supap.yaml
│ │ ├── supap_inat.yaml
│ │ ├── roadmap.yaml
│ │ └── roadmap_inat.yaml
│ ├── memory
│ │ ├── cub.yaml
│ │ ├── default.yaml
│ │ ├── sop.yaml
│ │ └── inaturalist.yaml
│ ├── optimizer
│ │ ├── cub.yaml
│ │ ├── cub_deit.yaml
│ │ ├── inshop_deit.yaml
│ │ ├── sfm120k_deit.yaml
│ │ ├── sop_deit.yaml
│ │ ├── inaturalist_deit.yaml
│ │ ├── inshop.yaml
│ │ ├── sop.yaml
│ │ └── inaturalist.yaml
│ ├── dataset
│ │ ├── cub.yaml
│ │ ├── inaturalist.yaml
│ │ ├── sop.yaml
│ │ ├── inshop.yaml
│ │ └── sfm120k.yaml
│ ├── model
│ │ ├── resnet.yaml
│ │ ├── deit.yaml
│ │ └── resnet_max_ln.yaml
│ ├── hydra
│ │ └── launcher
│ │ │ └── ray_launcher.yaml
│ ├── transform
│ │ ├── sfm120k.yaml
│ │ ├── cub.yaml
│ │ ├── cub_big.yaml
│ │ ├── inaturalist.yaml
│ │ ├── inshop_big.yaml
│ │ ├── sop.yaml
│ │ └── sop_big.yaml
│ ├── default.yaml
│ └── experience
│ │ ├── landmarks.yaml
│ │ └── default.yaml
├── losses
│ ├── fast_ap.py
│ ├── __init__.py
│ ├── calibration_loss.py
│ ├── pair_loss.py
│ ├── blackbox_ap.py
│ ├── softbin_ap.py
│ └── smooth_rank_ap.py
├── samplers
│ ├── __init__.py
│ ├── random_sampler.py
│ ├── m_per_class_sampler.py
│ └── hierarchical_sampler.py
├── datasets
│ ├── __init__.py
│ ├── sop.py
│ ├── sfm120k.py
│ ├── inshop.py
│ ├── cub200.py
│ ├── inaturalist.py
│ ├── revisited_dataset.py
│ └── base_dataset.py
├── engine
│ ├── make_subset.py
│ ├── __init__.py
│ ├── chepoint.py
│ ├── get_knn.py
│ ├── memory.py
│ ├── base_update.py
│ ├── evaluate.py
│ ├── train.py
│ ├── cross_validation_splits.py
│ ├── accuracy_calculator.py
│ └── landmark_evaluation.py
├── single_experiment_runner.py
├── evaluate.py
├── getter.py
└── run.py
├── picture
└── outline.png
├── dev_requirements.txt
├── .gitignore
├── requirements.txt
├── setup.py
├── LICENSE
└── README.md
/roadmap/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/picture/outline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/elias-ramzi/ROADMAP/HEAD/picture/outline.png
--------------------------------------------------------------------------------
/roadmap/utils/logger.py:
--------------------------------------------------------------------------------
import logging

# Single shared logger for the whole ROADMAP package; other modules use
# `from .logger import LOGGER` (see utils/extract_progress.py).
LOGGER = logging.getLogger("ROADMAP")
4 |
--------------------------------------------------------------------------------
/roadmap/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .net import RetrievalNet
2 |
3 |
4 | __all__ = [
5 | 'RetrievalNet',
6 | ]
7 |
--------------------------------------------------------------------------------
/roadmap/config/loss/fastap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: FastAP
3 | weight: 1.0
4 | kwargs:
5 | num_bins: 10
6 |
--------------------------------------------------------------------------------
/roadmap/config/loss/smoothap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: SmoothAP
3 | weight: 1.0
4 | kwargs:
5 | tau: 0.01
6 |
--------------------------------------------------------------------------------
/roadmap/config/loss/pair_loss.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: PairLoss
3 | weight: 1.0
4 | kwargs:
5 | margin: 0.5
6 |
--------------------------------------------------------------------------------
/roadmap/utils/freeze_pos_embedding.py:
--------------------------------------------------------------------------------
def freeze_pos_embedding(net):
    """Disable gradient updates for the positional embedding of *net*.

    Returns the same network object for call chaining.
    """
    pos_embed = net.pos_embed
    pos_embed.requires_grad_(False)
    return net
4 |
--------------------------------------------------------------------------------
/roadmap/utils/count_parameters.py:
--------------------------------------------------------------------------------
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
3 |
--------------------------------------------------------------------------------
/roadmap/utils/set_lr.py:
--------------------------------------------------------------------------------
def set_lr(optimizer, lr):
    """Overwrite the learning rate of every param group of *optimizer* with *lr*."""
    for group in optimizer.param_groups:
        group['lr'] = lr
4 |
--------------------------------------------------------------------------------
/roadmap/losses/fast_ap.py:
--------------------------------------------------------------------------------
1 | from pytorch_metric_learning import losses
2 |
3 |
class FastAP(losses.FastAPLoss):
    """Thin subclass of pytorch-metric-learning's FastAPLoss."""
    # NOTE(review): presumably read by the training engine to decide that this
    # loss receives raw embeddings (rather than a similarity matrix) — confirm
    # against the engine code, which is not visible here.
    takes_embeddings = True
6 |
--------------------------------------------------------------------------------
/roadmap/config/loss/blackboxap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: BlackBoxAP
3 | weight: 1.0
4 | kwargs:
5 | lambda_val: 4.0
6 | margin: 0.02
7 |
--------------------------------------------------------------------------------
/roadmap/config/loss/softbinap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: SoftBinAP
3 | weight: 1.0
4 | kwargs:
5 | nq: 20
6 | min: -1
7 | max: 1
8 |
--------------------------------------------------------------------------------
/roadmap/config/memory/cub.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: XBM
3 |
4 | activate_after: -1
5 | weight: 1.0
6 | kwargs:
7 | size: 5824
8 | unique: True
9 |
--------------------------------------------------------------------------------
/roadmap/config/memory/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name:
3 |
4 | activate_after: -1
5 | weight: 1.0
6 | kwargs:
7 | size: null
8 | unique: null
9 |
--------------------------------------------------------------------------------
/roadmap/config/memory/sop.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: XBM
3 |
4 | activate_after: -1
5 | weight: 1.0
6 | kwargs:
7 | size: 59551
8 | unique: True
9 |
--------------------------------------------------------------------------------
/roadmap/config/loss/affineap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: AffineAP
3 | weight: 1.0
4 | kwargs:
5 | theta: 0.5
6 | mu_n: 0.025
7 | mu_p: 0.025
8 |
--------------------------------------------------------------------------------
/roadmap/config/memory/inaturalist.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: XBM
3 |
4 | activate_after: -1
5 | weight: 1.0
6 | kwargs:
7 | size: 60000
8 | unique: False
9 |
--------------------------------------------------------------------------------
/roadmap/utils/expand_path.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
def expand_path(pth):
    """Expand environment variables ($VAR) and the user home (~) in *pth*."""
    # Same order as before: env vars first, then '~'.
    return os.path.expanduser(os.path.expandvars(pth))
8 |
--------------------------------------------------------------------------------
/dev_requirements.txt:
--------------------------------------------------------------------------------
1 | wheel
2 | six
3 | appdirs
4 | ordered_set
5 | ipython
6 | jupyter
7 | ipdb
8 | autopep8
9 | flake8
10 | pylint
11 | isort
12 | jedi==0.17.1
13 |
--------------------------------------------------------------------------------
/roadmap/utils/set_initial_lr.py:
--------------------------------------------------------------------------------
def set_initial_lr(optimizer):
    """Copy each param group's current 'lr' into an 'initial_lr' entry."""
    for group in optimizer.param_groups:
        group['initial_lr'] = group['lr']
4 |
--------------------------------------------------------------------------------
/roadmap/config/loss/supap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: SupAP
3 | weight: 1.0
4 | kwargs:
5 | tau: 0.01
6 | rho: 100.0
7 | offset: 1.44
8 | delta: 0.05
9 |
--------------------------------------------------------------------------------
/roadmap/utils/get_lr.py:
--------------------------------------------------------------------------------
def get_lr(optimizer):
    """Return a dict mapping param-group index -> current learning rate."""
    return {idx: group['lr'] for idx, group in enumerate(optimizer.param_groups)}
6 |
--------------------------------------------------------------------------------
/roadmap/config/loss/supap_inat.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: SupAP
3 | weight: 1.0
4 | kwargs:
5 | tau: 0.01
6 | rho: 100.0
7 | offset: 1.0
8 | delta: 0.05
9 | start: 0.0
10 |
--------------------------------------------------------------------------------
/roadmap/utils/list_or_tuple.py:
--------------------------------------------------------------------------------
1 | from omegaconf.listconfig import ListConfig
2 |
3 |
def list_or_tuple(lst):
    """Return True when *lst* is sequence-like: list, tuple, or OmegaConf ListConfig."""
    # The isinstance result is already the boolean we want; the previous
    # `if ...: return True / return False` pattern was redundant.
    return isinstance(lst, (ListConfig, tuple, list))
9 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | __pycache__
3 | .ipynb_checkpoints
4 | .pytest_cache
5 | .DS_Store
6 |
7 | build/
8 | dist/
9 | *.egg-info
10 | *.so
11 | *.c
12 |
13 | tmp
14 | personal_experiment
15 | .idea
16 | outputs
17 | scripts
18 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/cub.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: Adam
3 | params:
4 | kwargs:
5 | lr: 0.000001
6 | weight_decay: 0.0004
7 | scheduler_on_epoch:
8 | scheduler_on_step:
9 | scheduler_on_val:
10 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/cub_deit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: AdamW
3 | params:
4 | kwargs:
5 | lr: 0.000001
6 | weight_decay: 0.0005
7 | scheduler_on_epoch:
8 | scheduler_on_step:
9 | scheduler_on_val:
10 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/inshop_deit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: AdamW
3 | params:
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0005
7 | scheduler_on_epoch:
8 | scheduler_on_step:
9 | scheduler_on_val:
10 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/sfm120k_deit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: AdamW
3 | params:
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0005
7 | scheduler_on_epoch:
8 | scheduler_on_step:
9 | scheduler_on_val:
10 |
--------------------------------------------------------------------------------
/roadmap/utils/format_time.py:
--------------------------------------------------------------------------------
def format_time(seconds):
    """Format a duration in seconds as '<h>h<m>m<s>s' (fractions truncated)."""
    total_minutes, secs = divmod(int(seconds), 60)
    hours, mins = divmod(total_minutes, 60)
    return f"{hours}h{mins}m{secs}s"
8 |
--------------------------------------------------------------------------------
/roadmap/utils/freeze_batch_norm.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
def freeze_batch_norm(model):
    """Put every BatchNorm2d of *model* in eval mode and pin it there.

    Each matching module's `train` method is shadowed with a no-op so a later
    `model.train()` cannot flip the frozen layers back to training mode
    (keeping their running statistics fixed).

    Returns the same model object.
    """
    for module in model.modules():
        # isinstance instead of `type(m) == nn.BatchNorm2d` (PEP 8 / E721).
        if isinstance(module, nn.BatchNorm2d):
            module.eval()
            # Default arg: the original `lambda _: None` raised TypeError on a
            # direct zero-argument `module.train()` call; parent Module.train
            # passes `mode` positionally, which still binds here.
            module.train = lambda mode=True: None
    return model
9 |
--------------------------------------------------------------------------------
/roadmap/config/dataset/cub.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: Cub200Dataset
3 | kwargs:
4 | data_dir: /local/DEEPLEARNING/image_retrieval/CUB_200_2011
5 |
6 | sampler:
7 | name: MPerClassSampler
8 | kwargs:
9 | batch_size: 128
10 | samples_per_class: 4
11 |
--------------------------------------------------------------------------------
/roadmap/utils/str_to_bool.py:
--------------------------------------------------------------------------------
def str_to_bool(condition):
    """Convert the strings 'true'/'false' (any case) to bool; pass everything else through.

    Fixes a crash in the original: after `'true'` replaced *condition* with the
    bool True, the second `condition.lower()` raised AttributeError.
    """
    if isinstance(condition, str):
        lowered = condition.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
    return condition
9 |
--------------------------------------------------------------------------------
/roadmap/config/dataset/inaturalist.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: INaturalistDataset
3 | kwargs:
4 | data_dir: /local/DEEPLEARNING/image_retrieval/Inaturalist
5 |
6 | sampler:
7 | name: MPerClassSampler
8 | kwargs:
9 | batch_size: 128
10 | samples_per_class: 4
11 |
--------------------------------------------------------------------------------
/roadmap/config/model/resnet.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: RetrievalNet
3 | freeze_batch_norm: True
4 | freeze_pos_embedding: False
5 | kwargs:
6 | backbone_name: resnet50
7 | embed_dim: 512
8 | norm_features: False
9 | without_fc: False
10 | with_autocast: True
11 |
--------------------------------------------------------------------------------
/roadmap/config/model/deit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: RetrievalNet
3 | freeze_batch_norm: False
4 | freeze_pos_embedding: False
5 | kwargs:
6 | backbone_name: vit_deit_distilled
7 | embed_dim: 384
8 | norm_features: False
9 | without_fc: True
10 | with_autocast: True
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.20.3
2 | torch==1.8.1
3 | torchvision==0.9.1
4 | faiss-gpu==1.6.5
5 | pillow==8.2.0
6 | pandas==1.2.4
7 | tqdm==4.61.0
8 | tensorboard==2.5.0
9 | matplotlib==3.4.2
10 | pytorch-metric-learning==0.9.99
11 | timm==0.4.12
12 | hydra-core==1.0.6
13 | hydra_colorlog==1.0.1
14 |
--------------------------------------------------------------------------------
/roadmap/config/hydra/launcher/ray_launcher.yaml:
--------------------------------------------------------------------------------
1 | # @package hydra.launcher
2 | _target_: hydra_plugins.hydra_ray_launcher.ray_launcher.RayLauncher
3 | ray:
4 | init:
5 | address: null
6 | _temp_dir: ${env:HOME}/experiments/tmp
7 |
8 | remote:
9 | num_gpus: 1
10 | num_cpus: 16
11 |
--------------------------------------------------------------------------------
/roadmap/config/loss/roadmap.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: CalibrationLoss
3 | weight: 1.0
4 | kwargs:
5 | pos_margin: 0.9
6 | neg_margin: 0.6
7 |
8 | - name: SupAP
9 | weight: 1.0
10 | kwargs:
11 | tau: 0.01
12 | rho: 100.0
13 | offset: 1.44
14 | delta: 0.05
15 |
--------------------------------------------------------------------------------
/roadmap/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .hierarchical_sampler import HierarchicalSampler
2 | from .m_per_class_sampler import MPerClassSampler
3 | from .random_sampler import RandomSampler
4 |
5 |
6 | __all__ = [
7 | 'HierarchicalSampler',
8 | 'MPerClassSampler',
9 | 'RandomSampler',
10 | ]
11 |
--------------------------------------------------------------------------------
/roadmap/config/model/resnet_max_ln.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: RetrievalNet
3 | freeze_batch_norm: True
4 | freeze_pos_embedding: False
5 | kwargs:
6 | backbone_name: resnet50
7 | embed_dim: 512
8 | norm_features: True
9 | without_fc: False
10 | with_autocast: True
11 | pooling: max
12 |
--------------------------------------------------------------------------------
/roadmap/config/dataset/sop.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: SOPDataset
3 | kwargs:
4 | data_dir: /local/DEEPLEARNING/image_retrieval/Stanford_Online_Products
5 |
6 | sampler:
7 | name: HierarchicalSampler
8 | kwargs:
9 | batch_size: 128
10 | samples_per_class: 4
11 | batches_per_super_pair: 10
12 |
--------------------------------------------------------------------------------
/roadmap/config/loss/roadmap_inat.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: CalibrationLoss
3 | weight: 1.0
4 | kwargs:
5 | pos_margin: 0.9
6 | neg_margin: 0.6
7 |
8 | - name: SupAP
9 | weight: 1.0
10 | kwargs:
11 | tau: 0.01
12 | rho: 100.0
13 | offset: 1.0
14 | delta: 0.05
15 | start: 0.0
16 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/sop_deit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: AdamW
3 | params:
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0005
7 | scheduler_on_epoch:
8 | name: MultiStepLR
9 | kwargs:
10 | milestones: [25, 50]
11 | gamma: 0.3
12 | scheduler_on_step:
13 | scheduler_on_val:
14 |
--------------------------------------------------------------------------------
/roadmap/config/dataset/inshop.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: InShopDataset
3 | kwargs:
4 | data_dir: /local/DEEPLEARNING/image_retrieval/inshop
5 | hierarchy_mode: 'all'
6 |
7 | sampler:
8 | name: HierarchicalSampler
9 | kwargs:
10 | batch_size: 128
11 | samples_per_class: 8
12 | batches_per_super_pair: 4
13 | nb_categories: 2
14 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/inaturalist_deit.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: AdamW
3 | params:
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0005
7 | scheduler_on_epoch:
8 | name: MultiStepLR
9 | kwargs:
10 | milestones: [30, 70]
11 | gamma: 0.3
12 | last_epoch: -1
13 | scheduler_on_step:
14 | scheduler_on_val:
15 |
--------------------------------------------------------------------------------
/roadmap/utils/rgb_to_bgr.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 |
class RGBToBGR():
    """Transform swapping the R and B channels of a PIL RGB image."""

    def __call__(self, im):
        assert im.mode == 'RGB'
        r, g, b = im.getchannel(0), im.getchannel(1), im.getchannel(2)
        # RGB mode also for BGR, `3x8-bit pixels, true color`, see PIL doc
        return Image.merge('RGB', [b, g, r])
11 |
--------------------------------------------------------------------------------
/roadmap/utils/get_gradient_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
def get_gradient_norm(net):
    """Return the L2 norm of the gradient of the network's final layer weights.

    Raises:
        AttributeError: if *net* exposes neither a ``fc`` nor a ``blocks``
            attribute (the original fell through to an opaque NameError).
    """
    if torch.cuda.device_count() > 1:
        # Multi-GPU wrappers (e.g. DataParallel) keep the real model in `.module`.
        net = net.module

    if hasattr(net, 'fc'):
        final_layer = net.fc

    elif hasattr(net, 'blocks'):
        # ViT-style backbone: last MLP layer of the last transformer block.
        final_layer = net.blocks[-1].mlp.fc2

    else:
        raise AttributeError(
            "get_gradient_norm expects a net with either a 'fc' or a 'blocks' attribute"
        )

    # First parameter of the layer is its weight tensor.
    return torch.norm(list(final_layer.parameters())[0].grad, 2)
15 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/inshop.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: Adam
3 | params: backbone
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0001
7 | scheduler_on_epoch:
8 | scheduler_on_step:
9 | scheduler_on_val:
10 |
11 | - name: Adam
12 | params: fc
13 | kwargs:
14 | lr: 0.00002
15 | weight_decay: 0.0001
16 | scheduler_on_epoch:
17 | scheduler_on_step:
18 | scheduler_on_val:
19 |
--------------------------------------------------------------------------------
/roadmap/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .cub200 import Cub200Dataset
2 | from .inaturalist import INaturalistDataset
3 | from .inshop import InShopDataset
4 | from .revisited_dataset import RevisitedDataset
5 | from .sfm120k import SfM120kDataset
6 | from .sop import SOPDataset
7 |
8 |
9 | __all__ = [
10 | 'Cub200Dataset',
11 | 'INaturalistDataset',
12 | 'InShopDataset',
13 | 'RevisitedDataset',
14 | 'SfM120kDataset',
15 | 'SOPDataset',
16 | ]
17 |
--------------------------------------------------------------------------------
/roadmap/config/transform/sfm120k.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | RandomResizedCrop:
4 | scale: [0.16, 1]
5 | ratio: [0.75, 1.33]
6 | size: 224
7 | RandomHorizontalFlip:
8 | p: 0.5
9 | ToTensor: {}
10 | Normalize:
11 | mean: [0.485, 0.456, 0.406]
12 | std: [0.229, 0.224, 0.225]
13 |
14 | test:
15 | Resize:
16 | size: [224,224]
17 | ToTensor: {}
18 | Normalize:
19 | mean: [0.485, 0.456, 0.406]
20 | std: [0.229, 0.224, 0.225]
21 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

# Read runtime dependencies from requirements.txt, one requirement per line.
# (The original passed the raw file contents as a single string.)
with open('requirements.txt') as open_file:
    install_requires = open_file.read().splitlines()

setuptools.setup(
    name='roadmap',
    version='0.0.0',
    # find_packages picks up the `roadmap` package and all its sub-packages;
    # the previous `packages=['']` did not install the actual code.
    packages=setuptools.find_packages(),
    url='https://github.com/elias-ramzi/ROADMAP',
    license='',
    author='Elias Ramzi',
    author_email='elias.ramzi@lecnam.net',
    description='',
    python_requires='>=3.6',
    install_requires=install_requires,
)
18 |
--------------------------------------------------------------------------------
/roadmap/config/transform/cub.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | RandomResizedCrop:
4 | scale: [0.16, 1]
5 | ratio: [0.75, 1.33]
6 | size: 224
7 | RandomHorizontalFlip:
8 | p: 0.5
9 | ToTensor: {}
10 | Normalize:
11 | mean: [0.485, 0.456, 0.406]
12 | std: [0.229, 0.224, 0.225]
13 |
14 | test:
15 | Resize:
16 | size: 256
17 | CenterCrop:
18 | size: 224
19 | ToTensor: {}
20 | Normalize:
21 | mean: [0.485, 0.456, 0.406]
22 | std: [0.229, 0.224, 0.225]
23 |
--------------------------------------------------------------------------------
/roadmap/config/transform/cub_big.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | RandomResizedCrop:
4 | scale: [0.16, 1]
5 | ratio: [0.75, 1.33]
6 | size: 256
7 | RandomHorizontalFlip:
8 | p: 0.5
9 | ToTensor: {}
10 | Normalize:
11 | mean: [0.485, 0.456, 0.406]
12 | std: [0.229, 0.224, 0.225]
13 |
14 | test:
15 | Resize:
16 | size: 288
17 | CenterCrop:
18 | size: 256
19 | ToTensor: {}
20 | Normalize:
21 | mean: [0.485, 0.456, 0.406]
22 | std: [0.229, 0.224, 0.225]
23 |
--------------------------------------------------------------------------------
/roadmap/utils/override_config.py:
--------------------------------------------------------------------------------
def set_attribute(cfg, key, value):
    """Set a dotted attribute path (e.g. 'a.b.c' or 'seq.0.x') on *cfg*.

    Integer path components index into sequences; everything else is
    resolved with getattr. The final component is always set via setattr.
    Returns *cfg* (mutated in place).
    """
    *parents, leaf = key.split('.')
    node = cfg
    for part in parents:
        try:
            node = node[int(part)]
        except ValueError:
            node = getattr(node, part)
    setattr(node, leaf, value)
    return cfg
11 |
12 |
def override_config(hyperparameters, config):
    """Apply every dotted-key -> value override from *hyperparameters* to *config*."""
    for dotted_key, value in hyperparameters.items():
        config = set_attribute(config, dotted_key, value)
    return config
18 |
--------------------------------------------------------------------------------
/roadmap/config/transform/inaturalist.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | RandomResizedCrop:
4 | scale: [0.16, 1]
5 | ratio: [0.75, 1.33]
6 | size: 224
7 | RandomHorizontalFlip:
8 | p: 0.5
9 | ToTensor: {}
10 | Normalize:
11 | mean: [0.485, 0.456, 0.406]
12 | std: [0.229, 0.224, 0.225]
13 |
14 | test:
15 | Resize:
16 | size: 256
17 | CenterCrop:
18 | size: 224
19 | ToTensor: {}
20 | Normalize:
21 | mean: [0.485, 0.456, 0.406]
22 | std: [0.229, 0.224, 0.225]
23 |
--------------------------------------------------------------------------------
/roadmap/config/transform/inshop_big.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | RandomResizedCrop:
4 | scale: [0.16, 1]
5 | ratio: [0.75, 1.33]
6 | size: 256
7 | RandomHorizontalFlip:
8 | p: 0.5
9 | ToTensor: {}
10 | Normalize:
11 | mean: [0.485, 0.456, 0.406]
12 | std: [0.229, 0.224, 0.225]
13 |
14 | test:
15 | Resize:
16 | size: [288,288]
17 | CenterCrop:
18 | size: 256
19 | ToTensor: {}
20 | Normalize:
21 | mean: [0.485, 0.456, 0.406]
22 | std: [0.229, 0.224, 0.225]
23 |
--------------------------------------------------------------------------------
/roadmap/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .blackbox_ap import BlackBoxAP
2 | from .calibration_loss import CalibrationLoss
3 | from .fast_ap import FastAP
4 | from .softbin_ap import SoftBinAP
5 | from .pair_loss import PairLoss
6 | from .smooth_rank_ap import (
7 | HeavisideAP,
8 | SmoothAP,
9 | SupAP,
10 | )
11 |
12 |
13 | __all__ = [
14 | 'BlackBoxAP',
15 | 'CalibrationLoss',
16 | 'FastAP',
17 | 'SoftBinAP',
18 | 'PairLoss',
19 | 'HeavisideAP',
20 | 'SmoothAP',
21 | 'SupAP',
22 | ]
23 |
--------------------------------------------------------------------------------
/roadmap/config/transform/sop.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | Resize:
4 | size: 256
5 | RandomResizedCrop:
6 | scale: [0.16, 1]
7 | ratio: [0.75, 1.33]
8 | size: 224
9 | RandomHorizontalFlip:
10 | p: 0.5
11 | ToTensor: {}
12 | Normalize:
13 | mean: [0.485, 0.456, 0.406]
14 | std: [0.229, 0.224, 0.225]
15 |
16 | test:
17 | Resize:
18 | size: [256, 256]
19 | CenterCrop:
20 | size: 224
21 | ToTensor: {}
22 | Normalize:
23 | mean: [0.485, 0.456, 0.406]
24 | std: [0.229, 0.224, 0.225]
25 |
--------------------------------------------------------------------------------
/roadmap/config/dataset/sfm120k.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: SfM120kDataset
3 | kwargs:
4 | data_dir: /local/DEEPLEARNING/image_retrieval/landmarks/sfm120k
5 |
6 | sampler:
7 | name: MPerClassSampler
8 | kwargs:
9 | batch_size: 128
10 | samples_per_class: 4
11 |
12 | evaluation:
13 | - name: RevisitedDataset
14 | kwargs:
15 | data_dir: /local/DEEPLEARNING/image_retrieval/landmarks/rparis6k
16 |
17 | - name: RevisitedDataset
18 | kwargs:
19 | data_dir: /local/DEEPLEARNING/image_retrieval/landmarks/roxford5k
20 |
--------------------------------------------------------------------------------
/roadmap/config/transform/sop_big.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | train:
3 | Resize:
4 | size: 288
5 | RandomResizedCrop:
6 | scale: [0.16, 1]
7 | ratio: [0.75, 1.33]
8 | size: 256
9 | RandomHorizontalFlip:
10 | p: 0.5
11 | ToTensor: {}
12 | Normalize:
13 | mean: [0.485, 0.456, 0.406]
14 | std: [0.229, 0.224, 0.225]
15 |
16 | test:
17 | Resize:
18 | size: [288, 288]
19 | CenterCrop:
20 | size: 256
21 | ToTensor: {}
22 | Normalize:
23 | mean: [0.485, 0.456, 0.406]
24 | std: [0.229, 0.224, 0.225]
25 |
--------------------------------------------------------------------------------
/roadmap/config/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |
3 | - experience: default
4 |
5 | - dataset: sop
6 |
7 | - loss: roadmap
8 |
9 | - memory: default
10 |
11 | - model: resnet
12 |
13 | - optimizer: sop
14 |
15 | - transform: sop
16 |
17 | - hydra/job_logging: colorlog
18 |
19 | - hydra/hydra_logging: colorlog
20 |
21 | hydra:
22 | run:
23 | dir: ${experience.log_dir}/${experience.experiment_name}/outputs
24 |
25 | sweep:
26 | dir: ${experience.log_dir}
27 | subdir: ${experience.experiment_name}/outputs
28 |
--------------------------------------------------------------------------------
/roadmap/utils/average_meter.py:
--------------------------------------------------------------------------------
1 | def _handle_types(value):
2 | if hasattr(value, "detach"):
3 | return value.detach().item()
4 | else:
5 | return value
6 |
7 |
class AverageMeter:
    """Tracks the latest value plus a running sum, count, and mean."""

    def __init__(self) -> None:
        self.val = 0    # most recent value
        self.avg = 0    # running mean
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight seen so far

    def update(self, val, n=1) -> None:
        """Record *val*, observed *n* times."""
        val = _handle_types(val)
        self.val = val
        self.count += n
        self.sum += n * val
        self.avg = self.sum / self.count
21 |
--------------------------------------------------------------------------------
/roadmap/utils/extract_progress.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | import tarfile
3 |
4 | from tqdm import tqdm
5 |
6 | from .logger import LOGGER
7 |
8 |
def extract_progress(compressed_obj):
    """Yield the members of a tar or zip archive while showing a tqdm progress bar.

    Raises:
        TypeError: for unsupported archive objects (the original left
            ``iterable`` unbound and crashed with NameError).
    """
    LOGGER.info("Extracting dataset")
    if isinstance(compressed_obj, tarfile.TarFile):
        iterable = compressed_obj
        length = len(compressed_obj.getmembers())
    elif isinstance(compressed_obj, zipfile.ZipFile):
        iterable = compressed_obj.namelist()
        length = len(iterable)
    else:
        raise TypeError(f"Unsupported archive type: {type(compressed_obj)!r}")

    # `tqdm` is imported as the function itself (`from tqdm import tqdm`),
    # so the original `tqdm.tqdm(...)` raised AttributeError at runtime.
    for member in tqdm(iterable, total=length):
        yield member
18 | yield member
19 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/sop.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: Adam
3 | params: backbone
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0001
7 | scheduler_on_epoch:
8 | name: MultiStepLR
9 | kwargs:
10 | milestones: [30,70]
11 | gamma: 0.3
12 | scheduler_on_step:
13 | scheduler_on_val:
14 |
15 | - name: Adam
16 | params: fc
17 | kwargs:
18 | lr: 0.00002
19 | weight_decay: 0.0001
20 | scheduler_on_epoch:
21 | name: MultiStepLR
22 | kwargs:
23 | milestones: [30,70]
24 | gamma: 0.3
25 | scheduler_on_step:
26 | scheduler_on_val:
27 |
--------------------------------------------------------------------------------
/roadmap/engine/make_subset.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 |
def make_subset(dts, idxs, transform=None, mode=None):
    """Return a deep copy of dataset *dts* restricted to the indices *idxs*.

    Optionally overrides the copy's `transform` and `mode`; the original
    dataset is left untouched.
    """
    subset = deepcopy(dts)

    subset.paths = [subset.paths[i] for i in idxs]
    subset.labels = [subset.labels[i] for i in idxs]

    if getattr(subset, 'super_labels', None) is not None:
        subset.super_labels = [subset.super_labels[i] for i in idxs]

    # Rebuild the label -> indices mappings for the reduced dataset.
    subset.get_instance_dict()
    subset.get_super_dict()

    if transform is not None:
        subset.transform = transform

    if mode is not None:
        subset.mode = mode

    return subset
23 |
--------------------------------------------------------------------------------
/roadmap/config/optimizer/inaturalist.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | - name: Adam
3 | params: backbone
4 | kwargs:
5 | lr: 0.00001
6 | weight_decay: 0.0004
7 | scheduler_on_epoch:
8 | name: MultiStepLR
9 | kwargs:
10 | milestones: [30, 70]
11 | gamma: 0.3
12 | last_epoch: -1
13 | scheduler_on_step:
14 | scheduler_on_val:
15 |
16 | - name: Adam
17 | params: fc
18 | kwargs:
19 | lr: 0.00002
20 | weight_decay: 0.0004
21 | scheduler_on_epoch:
22 | name: MultiStepLR
23 | kwargs:
24 | milestones: [30, 70]
25 | gamma: 0.3
26 | last_epoch: -1
27 | scheduler_on_step:
28 | scheduler_on_val:
29 |
--------------------------------------------------------------------------------
/roadmap/utils/moving_average.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
class MovingAverage:
    """Average over a sliding window of the last `ws` values.

    BUG FIX: `ws` used to be ignored and `pop(0)` ran on every call, so the
    buffer never held more than one element and the "moving average" was
    simply the last value. The window size is now stored and enforced.
    """

    def __init__(self, ws):
        # Maximum number of values kept in the window.
        self.ws = ws
        self.data = []

    def __call__(self, value):
        """Append `value` and return the mean of the (at most `ws`) last values."""
        if len(self.data) >= self.ws:
            self.data.pop(0)
        self.data.append(value)
        return np.mean(self.data)

    def mean_first(self, value):
        """Return the mean of the current window, then append `value`.

        When the window is empty, `value` itself is returned (and stored).
        """
        if self.data:
            mean = np.mean(self.data)
            if len(self.data) >= self.ws:
                self.data.pop(0)
            self.data.append(value)
            return mean

        self.data.append(value)
        return value
28 |
--------------------------------------------------------------------------------
/roadmap/config/experience/landmarks.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | experiment_name: ???
3 | log_dir: /share/DEEPLEARNING/datasets/image_retrieval/experiments/
4 | seed: ???
5 | resume: null
6 | maybe_resume: False
7 | force_lr: null
8 |
9 | warm_up: -1
10 | warm_up_key: fc
11 |
12 | max_iter: 50
13 |
14 | train_eval_freq: -1
15 | val_eval_freq: -1
16 | test_eval_freq: 5
17 | eval_bs: 96
18 | save_model: 50
19 | eval_split: rparis6k
20 | principal_metric: mapH
21 | log_grad: False
22 | with_AP: False
23 | landmarks: True
24 |
25 | split: null
26 | kfold: null
27 | split_random_state: null
28 | with_super_labels: False
29 |
30 | num_workers: 16
31 | pin_memory: True
32 |
33 | sub_batch: 128
34 | update_type: base_update
35 |
--------------------------------------------------------------------------------
/roadmap/config/experience/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | experiment_name: ???
3 | log_dir: /share/DEEPLEARNING/datasets/image_retrieval/experiments/
4 | seed: ???
5 | resume: null
6 | maybe_resume: False
7 | force_lr: null
8 |
9 | warm_up: -1
10 | warm_up_key: fc
11 |
12 | max_iter: 50
13 |
14 | train_eval_freq: -1
15 | val_eval_freq: -1
16 | test_eval_freq: 5
17 | eval_bs: 96
18 | save_model: 50
19 | eval_split: test
20 | principal_metric: mean_average_precision_at_r_level0
21 | log_grad: False
22 | with_AP: False
23 | landmarks: False
24 |
25 | split: null
26 | kfold: null
27 | split_random_state: null
28 | with_super_labels: False
29 |
30 | num_workers: 16
31 | pin_memory: True
32 |
33 | sub_batch: 128
34 | update_type: base_update
35 |
--------------------------------------------------------------------------------
/roadmap/utils/create_label_matrix.py:
--------------------------------------------------------------------------------
def create_label_matrix(labels, other_labels=None):
    """Build a {0., 1.} float matrix marking pairs with equal labels.

    Args:
        labels: 1-D tensor of N labels (or 2-D, batched — see below).
        other_labels: optional 1-D tensor of M reference labels; when given,
            entry (i, j) compares labels[i] to other_labels[j].

    Returns:
        Float tensor of shape (N, N) or (N, M) with 1. where labels match.

    Raises:
        NotImplementedError: for unsupported dimension combinations.
    """
    labels = labels.squeeze()

    if labels.ndim == 1:
        if other_labels is None:
            # (N, 1) == (N,) broadcasts to an (N, N) pairwise comparison.
            return (labels.view(-1, 1) == labels.t()).float()

        return (labels.view(-1, 1) == other_labels.t()).float()

    elif labels.ndim == 2:
        size = labels.size(0)
        if other_labels is None:
            return (labels.view(size, size, 1) == labels.view(size, 1, size)).float()

        # BUG FIX: error message previously read "comparated".
        raise NotImplementedError(f"Function for tensor dimension {labels.ndim} compared to tensor of dimension {other_labels.ndim} not implemented")

    raise NotImplementedError(f"Function for tensor dimension {labels.ndim} not implemented")
18 |
--------------------------------------------------------------------------------
/roadmap/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .accuracy_calculator import CustomCalculator
2 | from .base_update import base_update
3 | from .chepoint import checkpoint
4 | from .cross_validation_splits import (
5 | get_class_disjoint_splits,
6 | get_hierarchical_class_disjoint_splits,
7 | get_closed_set_splits,
8 | get_splits,
9 | )
10 | from .evaluate import evaluate, get_tester
11 | from .landmark_evaluation import landmark_evaluation
12 | from .make_subset import make_subset
13 | from .memory import XBM
14 | from .train import train
15 |
16 |
17 | __all__ = [
18 | 'CustomCalculator',
19 | 'base_update',
20 | 'checkpoint',
21 | 'get_class_disjoint_splits',
22 | 'get_hierarchical_class_disjoint_splits',
23 | 'get_closed_set_splits',
24 | 'get_splits',
25 | 'evaluate',
26 | 'get_tester',
27 | 'landmark_evaluation',
28 | 'make_subset',
29 | 'XBM',
30 | 'train',
31 | ]
32 |
--------------------------------------------------------------------------------
/roadmap/utils/dict_average.py:
--------------------------------------------------------------------------------
1 | from .average_meter import AverageMeter
2 |
3 |
class DictAverage:
    """Maintain one `AverageMeter` per key of the dictionaries passed to `update`."""

    def __init__(self,) -> None:
        # key -> AverageMeter
        self.dict_avg = {}

    def update(self, dict_values: dict) -> None:
        """Feed each (key, value) pair into its meter, creating meters lazily."""
        for key, item in dict_values.items():
            try:
                self.dict_avg[key].update(item)
            except KeyError:
                self.dict_avg[key] = AverageMeter()
                self.dict_avg[key].update(item)

    def keys(self,):
        return self.dict_avg.keys()

    def __getitem__(self, name):
        # BUG FIX: the `return` was missing, so indexing always yielded None.
        return self.dict_avg[name]

    def get(self, name, other=None):
        """Return the meter for `name`, falling back to the meter for `other`.

        Note: unlike dict.get, `other` is itself a key; if it is also absent
        (e.g. the default None) a KeyError propagates.
        """
        try:
            return self.dict_avg[name]
        except KeyError:
            return self.dict_avg[other]

    @property
    def avg(self,) -> dict:
        """Current average of every tracked key."""
        return {key: item.avg for key, item in self.dict_avg.items()}
32 |
--------------------------------------------------------------------------------
/roadmap/samplers/random_sampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils.data import BatchSampler
3 |
4 | import roadmap.utils as lib
5 |
6 |
class RandomSampler(BatchSampler):
    """Yield random, non-overlapping, fixed-size batches of dataset indices.

    Indices are reshuffled at the start of every iteration; a trailing
    incomplete batch is dropped.
    """

    def __init__(
        self,
        dataset,
        batch_size,
    ):
        self.batch_size = batch_size

        self.length = len(dataset)
        self.reshuffle()

    def __iter__(self,):
        self.reshuffle()
        yield from self.batches

    def __len__(self,):
        return len(self.batches)

    def __repr__(self,):
        return f"{self.__class__.__name__}(batch_size={self.batch_size})"

    def reshuffle(self):
        """Draw a fresh permutation and cut it into full batches."""
        lib.LOGGER.info("Shuffling data")
        permutation = list(range(self.length))
        np.random.shuffle(permutation)
        num_full = self.length // self.batch_size
        self.batches = [
            permutation[start:start + self.batch_size]
            for start in range(0, num_full * self.batch_size, self.batch_size)
        ]
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Elias Ramzi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/roadmap/models/create_projection_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | import roadmap.utils as lib
4 |
5 |
def create_projection_head(
    input_dimension=2048,
    layer_dim=512,
    normalization_layer='none',
):
    """Build the projection head mapping backbone features to embeddings.

    Args:
        input_dimension: size of the incoming feature vector.
        layer_dim: an int (single linear layer) or a list/tuple of ints
            giving the output size of each successive linear layer.
        normalization_layer: 'bn', 'ln' or 'none' — normalization inserted
            after each hidden (non-final) linear layer.

    Returns:
        nn.Linear when `layer_dim` is an int, otherwise an nn.Sequential of
        Linear (+ optional norm) + ReLU stacks with no activation after the
        last layer.

    Raises:
        ValueError: unknown `normalization_layer`.
        TypeError: `layer_dim` is neither an int nor a list/tuple.
    """
    if isinstance(layer_dim, int):
        return nn.Linear(input_dimension, layer_dim)

    elif lib.list_or_tuple(layer_dim):
        layers = []
        prev_dim = input_dimension
        for i, dim in enumerate(layer_dim):
            layers.append(nn.Linear(prev_dim, dim))
            prev_dim = dim
            # Norm + activation only between layers, never after the last one.
            if i < len(layer_dim) - 1:
                if normalization_layer == 'bn':
                    layers.append(nn.BatchNorm1d(dim))
                elif normalization_layer == 'ln':
                    layers.append(nn.LayerNorm(dim))
                elif normalization_layer == 'none':
                    pass
                else:
                    raise ValueError(f"Unknown normalization layer : {normalization_layer}")
                layers.append(nn.ReLU(inplace=True))

        return nn.Sequential(*layers)

    # BUG FIX: previously fell through and silently returned None.
    raise TypeError(f"layer_dim must be an int, list or tuple, got {type(layer_dim)}")
32 |
--------------------------------------------------------------------------------
/roadmap/utils/get_set_random_state.py:
--------------------------------------------------------------------------------
1 | import random
2 | from functools import wraps
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from .logger import LOGGER
8 |
9 |
def get_random_state():
    """Snapshot the RNG state of `random`, numpy, torch CPU and all CUDA devices."""
    LOGGER.debug("Getting random state")
    return {
        "RANDOM_STATE": random.getstate(),
        "NP_STATE": np.random.get_state(),
        "TORCH_STATE": torch.random.get_rng_state(),
        "TORCH_CUDA_STATE": torch.cuda.get_rng_state_all(),
    }
18 |
19 |
def set_random_state(RANDOM_STATE):
    """Restore the RNG snapshot produced by `get_random_state`."""
    LOGGER.debug("Setting random state")
    state = RANDOM_STATE
    random.setstate(state["RANDOM_STATE"])
    np.random.set_state(state["NP_STATE"])
    torch.random.set_rng_state(state["TORCH_STATE"])
    torch.cuda.set_rng_state_all(state["TORCH_CUDA_STATE"])
26 |
27 |
def get_set_random_state(func):
    """Decorator: run `func` and restore the global RNG state afterwards.

    Note: the state is restored only on normal return, matching the
    original behavior (no restore if `func` raises).
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        saved_state = get_random_state()
        result = func(*args, **kwargs)
        set_random_state(saved_state)
        return result
    return wrapper
36 |
--------------------------------------------------------------------------------
/roadmap/datasets/sop.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
3 | import pandas as pd
4 |
5 | from .base_dataset import BaseDataset
6 |
7 |
class SOPDataset(BaseDataset):
    """Stanford Online Products dataset.

    Reads the `Ebay_train.txt` / `Ebay_test.txt` ground-truth files and
    exposes zero-based class and super-class labels.
    """

    def __init__(self, data_dir, mode, transform=None, **kwargs):
        super().__init__(**kwargs)

        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform

        splits_for_mode = {
            'train': ['train'],
            'test': ['test'],
            'all': ['train', 'test'],
        }
        if mode not in splits_for_mode:
            raise ValueError(f"Mode unrecognized {mode}")
        splits = splits_for_mode[mode]

        self.paths = []
        self.labels = []
        self.super_labels = []
        for splt in splits:
            gt = pd.read_csv(join(self.data_dir, f'Ebay_{splt}.txt'), sep=' ')
            self.paths.extend(gt["path"].apply(lambda p: join(self.data_dir, p)).tolist())
            # Ids in the files are 1-based; shift them to start at 0.
            self.labels.extend((gt["class_id"] - 1).tolist())
            self.super_labels.extend((gt["super_class_id"] - 1).tolist())

        self.get_instance_dict()
        self.get_super_dict()
37 |
--------------------------------------------------------------------------------
/roadmap/datasets/sfm120k.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
3 | from scipy.io import loadmat
4 |
5 | from .base_dataset import BaseDataset
6 |
7 |
def cid2filename(cid, prefix):
    """
    https://github.com/filipradenovic/cnnimageretrieval-pytorch

    Creates a training image path out of its CID name

    Arguments
    ---------
    cid : name of the image
    prefix : root directory where images are saved

    Returns
    -------
    filename : full image filename
    """
    # Images are sharded into nested directories named after the last three
    # 2-character chunks of the CID (read right to left).
    level1 = cid[-2:]
    level2 = cid[-4:-2]
    level3 = cid[-6:-4]
    return join(prefix, level1, level2, level3, cid)
24 |
25 |
class SfM120kDataset(BaseDataset):
    """SfM-120k landmark dataset.

    Image names (CIDs) and cluster ids are read from the
    `retrieval-SfM-120k-imagenames-clusterids.mat` metadata file; each
    cluster id serves as the instance label.
    """

    def __init__(self, data_dir, mode, transform=None, **kwargs):
        super().__init__(**kwargs)

        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform

        metadata = loadmat(join(self.data_dir, "retrieval-SfM-120k-imagenames-clusterids.mat"))

        image_root = join(self.data_dir, "ims")
        self.paths = [cid2filename(entry[0], image_root) for entry in metadata['cids'][0]]
        self.labels = [int(cluster) for cluster in metadata['cluster'][0]]

        self.get_instance_dict()
42 |
--------------------------------------------------------------------------------
/roadmap/losses/calibration_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from pytorch_metric_learning import losses, distances
3 | from pytorch_metric_learning.utils import common_functions as c_f
4 | from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
5 |
6 |
class CalibrationLoss(losses.ContrastiveLoss):
    """Contrastive loss computed between a batch and memory-bank embeddings.

    Extends pytorch_metric_learning's ContrastiveLoss with a dot-product
    similarity and an optional set of reference embeddings/labels (e.g. from
    a cross-batch memory). When references are given, the batch and the
    references are concatenated and only batch-vs-reference pairs are used.
    """

    # Flag read by the training loop: this loss consumes raw embeddings.
    takes_embeddings = True

    def get_default_distance(self):
        # Use dot-product similarity instead of the parent's default distance.
        return distances.DotProductSimilarity()

    def forward(self, embeddings, labels, ref_embeddings=None, ref_labels=None):
        """Compute the loss; falls back to plain in-batch contrastive loss
        when no reference embeddings are supplied."""
        if ref_embeddings is None:
            return super().forward(embeddings, labels)

        # Pair indices between the batch and the references, shifted so they
        # address positions in the concatenated tensor below.
        indices_tuple = self.create_indices_tuple(
            embeddings.size(0),
            embeddings,
            labels,
            ref_embeddings,
            ref_labels,
        )

        combined_embeddings = torch.cat([embeddings, ref_embeddings], dim=0)
        combined_labels = torch.cat([labels, ref_labels], dim=0)
        return super().forward(combined_embeddings, combined_labels, indices_tuple)

    def create_indices_tuple(
        self,
        batch_size,
        embeddings,
        labels,
        E_mem,
        L_mem,
    ):
        """Build (anchor, pos, neg) index tuples pairing batch anchors with
        memory entries; reference indices are offset by `batch_size`."""
        indices_tuple = lmu.get_all_pairs_indices(labels, L_mem)
        indices_tuple = c_f.shift_indices_tuple(indices_tuple, batch_size)
        return indices_tuple
40 |
--------------------------------------------------------------------------------
/roadmap/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .average_meter import AverageMeter
2 | from .count_parameters import count_parameters
3 | from .create_label_matrix import create_label_matrix
4 | from .dict_average import DictAverage
5 | from .expand_path import expand_path
6 | from .extract_progress import extract_progress
7 | from .format_time import format_time
8 | from .freeze_batch_norm import freeze_batch_norm
9 | from .freeze_pos_embedding import freeze_pos_embedding
10 | from .get_gradient_norm import get_gradient_norm
11 | from .get_lr import get_lr
12 | from .get_set_random_state import get_random_state, set_random_state, get_set_random_state
13 | from .list_or_tuple import list_or_tuple
14 | from .logger import LOGGER
15 | from .moving_average import MovingAverage
16 | from .override_config import override_config
17 | from .rgb_to_bgr import RGBToBGR
18 | from .set_initial_lr import set_initial_lr
19 | from .set_lr import set_lr
20 | from .str_to_bool import str_to_bool
21 |
22 |
23 | __all__ = [
24 | 'AverageMeter',
25 | 'count_parameters',
26 | 'create_label_matrix',
27 | 'DictAverage',
28 | 'expand_path',
29 | 'extract_progress',
30 | 'format_time',
31 | 'freeze_batch_norm',
32 | 'freeze_pos_embedding',
33 | 'get_gradient_norm',
34 | 'get_random_state',
35 | 'set_random_state',
36 | 'get_set_random_state',
37 | 'get_lr',
38 | 'list_or_tuple',
39 | 'LOGGER',
40 | 'MovingAverage',
41 | 'override_config',
42 | 'RGBToBGR',
43 | 'set_initial_lr',
44 | 'set_lr',
45 | 'str_to_bool',
46 | ]
47 |
--------------------------------------------------------------------------------
/roadmap/datasets/inshop.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from .base_dataset import BaseDataset
4 |
5 |
class InShopDataset(BaseDataset):
    """In-Shop Clothes Retrieval dataset.

    Parses `list_eval_partition.txt` and keeps the samples belonging to the
    requested split. Super-labels are derived from one or both directory
    levels of the image path, depending on `hierarchy_mode`.
    """

    def __init__(self, data_dir, mode, transform=None, hierarchy_mode='all', **kwargs):
        super().__init__(**kwargs)

        assert mode in ["train", "query", "gallery"], f"Mode : {mode} unknown"
        assert hierarchy_mode in ['1', '2', 'all'], f"Hierarchy mode : {hierarchy_mode} unknown"
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform

        # Skip the two header lines and the trailing empty line.
        with open(os.path.join(data_dir, "list_eval_partition.txt")) as f:
            entries = f.read().split("\n")[2:-1]

        paths = []
        labels = []
        super_labels_name = []
        for raw in entries:
            tokens = [tok for tok in raw.split(" ") if tok]
            if tokens[2] != mode:
                continue
            paths.append(os.path.join(data_dir, tokens[0]))
            labels.append(int(tokens[1].split("_")[-1]))
            path_parts = tokens[0].split("/")
            if hierarchy_mode == '2':
                super_labels_name.append(path_parts[2])
            elif hierarchy_mode == '1':
                super_labels_name.append(path_parts[1])
            elif hierarchy_mode == 'all':
                super_labels_name.append('/'.join(path_parts[1:3]))

        self.paths = paths
        self.labels = labels

        slb_to_id = {name: i for i, name in enumerate(set(super_labels_name))}
        self.super_labels = [slb_to_id[name] for name in super_labels_name]
        self.super_labels_name = super_labels_name

        self.get_instance_dict()
        self.get_super_dict()
45 |
--------------------------------------------------------------------------------
/roadmap/losses/pair_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Malong Technologies Co., Ltd.
2 | # All rights reserved.
3 | #
4 | # Contact: github@malong.com
5 | #
6 | # This source code is licensed under the LICENSE file in
7 | # https://github.com/msight-tech/research-xbm/blob/master/LICENSE
8 | import torch
9 | from torch import nn
10 |
11 |
class PairLoss(nn.Module):
    """Contrastive pair loss over dot-product similarities.

    Positives (same label) contribute `1 - sim`; negatives (different label)
    contribute `sim` when their similarity exceeds `margin`. Near-identical
    positive pairs (similarity ~1, i.e. self-pairs) are ignored.
    """

    # Flag read by the training loop: this loss consumes raw embeddings.
    takes_embeddings = True

    def __init__(self, margin=0.5):
        super().__init__()
        self.margin = margin

    def compute_loss(self, inputs_col, targets_col, inputs_row, target_row):
        """Loss of `inputs_col` anchors against the `inputs_row` gallery,
        averaged over the anchors."""
        batch_size = inputs_col.size(0)
        similarity = torch.matmul(inputs_col, inputs_row.t())
        epsilon = 1e-5

        per_anchor = []
        for anchor in range(batch_size):
            same_label = targets_col[anchor] == target_row
            positives = torch.masked_select(similarity[anchor], same_label)
            # Discard (near-)self matches whose similarity is ~1.
            positives = torch.masked_select(positives, positives < 1 - epsilon)
            negatives = torch.masked_select(similarity[anchor], ~same_label)
            # Only negatives violating the margin are penalized.
            hard_negatives = torch.masked_select(negatives, negatives > self.margin)

            pos_term = torch.sum(1 - positives)
            neg_term = torch.sum(hard_negatives) if len(hard_negatives) > 0 else 0

            per_anchor.append(pos_term + neg_term)

        return sum(per_anchor) / batch_size

    def forward(self, embeddings, labels, ref_embeddings=None, ref_labels=None):
        """Compare the batch against itself, or against reference embeddings
        (e.g. a cross-batch memory) when those are provided."""
        if ref_embeddings is None:
            return self.compute_loss(embeddings, labels, embeddings, labels)

        return self.compute_loss(embeddings, labels, ref_embeddings, ref_labels)

    def extra_repr(self,):
        return f"margin={self.margin}"
55 |
--------------------------------------------------------------------------------
/roadmap/datasets/cub200.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | from torchvision import datasets
5 |
6 | from .base_dataset import BaseDataset
7 |
8 |
class Cub200Dataset(BaseDataset):
    """CUB-200-2011 dataset with the standard retrieval split.

    The first half of the classes (by id) forms the train split, the second
    half the test split. Optional super-labels are read from `classes.txt`
    (bird family parsed from the class name).
    """

    def __init__(self, data_dir, mode, transform=None, load_super_labels=False, **kwargs):
        super().__init__(**kwargs)
        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform
        self.load_super_labels = load_super_labels

        dataset = datasets.ImageFolder(os.path.join(self.data_dir, 'images'))
        paths = np.array([a for (a, b) in dataset.imgs])
        labels = np.array([b for (a, b) in dataset.imgs])

        sorted_lb = list(sorted(set(labels)))
        if mode == 'train':
            set_labels = set(sorted_lb[:len(sorted_lb) // 2])
        elif mode == 'test':
            set_labels = set(sorted_lb[len(sorted_lb) // 2:])
        elif mode == 'all':
            set_labels = set(sorted_lb)
        else:
            # BUG FIX: an unknown mode previously left `set_labels` unbound,
            # crashing below with NameError; raise the same ValueError as the
            # sibling datasets (SOPDataset, INaturalistDataset).
            raise ValueError(f"Mode unrecognized {mode}")

        self.paths = []
        self.labels = []
        for lb, pth in zip(labels, paths):
            if lb in set_labels:
                self.paths.append(pth)
                self.labels.append(lb)

        self.super_labels = None
        if self.load_super_labels:
            with open(os.path.join(self.data_dir, "classes.txt")) as f:
                lines = f.read().split("\n")
                lines.remove("")
            # Each line is "<1-based id> <index>.<Family>_<Species>"; the
            # third space-separated token holds the super-class name.
            labels_id = list(map(lambda x: int(x.split(" ")[0])-1, lines))
            super_labels_name = list(map(lambda x: x.split(" ")[2], lines))
            slb_names_to_id = {x: i for i, x in enumerate(sorted(set(super_labels_name)))}
            super_labels = [slb_names_to_id[x] for x in super_labels_name]
            labels_to_super_labels = {lb: slb for lb, slb in zip(labels_id, super_labels)}
            self.super_labels = [labels_to_super_labels[x] for x in self.labels]

        self.get_instance_dict()
50 |
--------------------------------------------------------------------------------
/roadmap/engine/chepoint.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
3 | import torch
4 |
5 | import roadmap.utils as lib
6 |
7 |
def checkpoint(
    log_dir,
    save_checkpoint,
    net,
    optimizer,
    scheduler,
    scaler,
    epoch,
    seed,
    args,
    score,
    best_model,
    best_score,
):
    """Serialize the full training state to a `rolling.ckpt` file.

    Also writes a per-epoch `epoch_{epoch}.ckpt` when `save_checkpoint` is
    truthy. When `log_dir` is None, checkpoints go to the current Ray Tune
    trial directory instead.
    """
    # Unwrap DataParallel so the weights reload without the `module.` prefix.
    if torch.cuda.device_count() > 1:
        net_state = net.module.state_dict()
    else:
        net_state = net.state_dict()

    state_dict = {
        "net_state": net_state,
        "optimizer_state": {key: opt.state_dict() for key, opt in optimizer.items()},
        "scheduler_on_epoch_state": [sch.state_dict() for sch in scheduler["on_epoch"]],
        "scheduler_on_step_state": [sch.state_dict() for sch in scheduler["on_step"]],
        "scheduler_on_val_state": [sch.state_dict() for sch, _ in scheduler["on_val"]],
    }

    if scaler is not None:
        state_dict["scaler_state"] = scaler.state_dict()

    state_dict.update({
        "epoch": epoch,
        "seed": seed,
        "config": args,
        "score": score,
        "best_score": best_score,
        "best_model": f"{best_model}.ckpt",
    })

    # Store RNG states so training can resume deterministically.
    state_dict.update(lib.get_random_state())

    if log_dir is None:
        # Ray Tune run: save into the trial directory.
        from ray import tune
        torch.save(state_dict, join(tune.get_trial_dir(), "rolling.ckpt"))
        if save_checkpoint:
            with tune.checkpoint_dir(step=epoch) as checkpoint_dir:
                lib.LOGGER.info(f"Checkpoint of epoch {epoch} created")
                torch.save(state_dict, join(checkpoint_dir, f"epoch_{epoch}.ckpt"))
    else:
        torch.save(state_dict, join(log_dir, 'weights', "rolling.ckpt"))
        if save_checkpoint:
            lib.LOGGER.info(f"Checkpoint of epoch {epoch} created")
            torch.save(state_dict, join(log_dir, 'weights', f"epoch_{epoch}.ckpt"))
60 |
--------------------------------------------------------------------------------
/roadmap/engine/get_knn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import faiss
4 | import pytorch_metric_learning.utils.common_functions as c_f
5 |
6 | import roadmap.utils as lib
7 |
8 |
def get_knn(references, queries, num_k, embeddings_come_from_same_source, with_faiss=True):
    """Retrieve the `num_k` nearest references for each query.

    Returns (indices, distances). When queries and references are the same
    set, one extra neighbour is fetched and the self-match column dropped.
    """
    num_k += embeddings_come_from_same_source

    lib.LOGGER.info("running k-nn with k=%d" % num_k)
    lib.LOGGER.info("embedding dimensionality is %d" % references.size(-1))

    search = get_knn_faiss if with_faiss else get_knn_torch
    distances, indices = search(references, queries, num_k)

    if embeddings_come_from_same_source:
        # First column is each query matching itself — discard it.
        return indices[:, 1:], distances[:, 1:]

    return indices, distances
24 |
25 |
def get_knn_faiss(references, queries, num_k):
    """Exact L2 k-nearest-neighbour search with faiss.

    Returns (distances, indices) as torch tensors moved back to the device
    of `references`. Falls back to CPU faiss when the GPU bindings are
    unavailable.
    """
    lib.LOGGER.debug("Computing k-nn with faiss")

    d = references.size(-1)
    device = references.device
    # faiss operates on float32 numpy arrays.
    references = c_f.to_numpy(references).astype(np.float32)
    queries = c_f.to_numpy(queries).astype(np.float32)

    index = faiss.IndexFlatL2(d)
    try:
        if torch.cuda.device_count() > 1:
            # Shard the index across all available GPUs.
            co = faiss.GpuMultipleClonerOptions()
            co.shards = True
            index = faiss.index_cpu_to_all_gpus(index, co)
        else:
            res = faiss.StandardGpuResources()
            index = faiss.index_cpu_to_gpu(res, 0, index)
    except AttributeError:
        # Only faiss CPU is installed
        pass

    index.add(references)
    distances, indices = index.search(queries, num_k)
    distances = c_f.to_device(torch.from_numpy(distances), device=device)
    indices = c_f.to_device(torch.from_numpy(indices), device=device)
    # Free the (possibly GPU-resident) index storage immediately.
    index.reset()
    return distances, indices
53 |
54 |
def get_knn_torch(references, queries, num_k):
    """Top-k retrieval by dot-product similarity, computed with torch.

    NOTE: unlike the faiss variant, "distances" here are similarity scores
    (higher is closer), taken from the queries x references product.
    """
    lib.LOGGER.debug("Computing k-nn with torch")

    similarity = queries @ references.t()
    scores, neighbours = torch.topk(similarity, num_k)
    return scores, neighbours
60 | return distances, indices
61 |
--------------------------------------------------------------------------------
/roadmap/samplers/m_per_class_sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | adapted from :
3 | https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/datasets.py
4 | """
5 | import copy
6 |
7 | import numpy as np
8 |
9 | import roadmap.utils as lib
10 |
11 |
def flatten(list_):
    """Concatenate a list of lists into a single flat list."""
    flat = []
    for sublist in list_:
        flat.extend(sublist)
    return flat
14 |
15 |
class MPerClassSampler:
    def __init__(
        self,
        dataset,
        batch_size,
        samples_per_class=4,
    ):
        """
        Batch sampler yielding batches containing exactly `samples_per_class`
        images for each sampled class.

        Args:
            dataset: dataset exposing `instance_dict`, a mapping
                class_id -> list of image indices sharing that label.
            batch_size: number of indices per batch; must be a multiple of
                `samples_per_class`.
            samples_per_class: images drawn per class within a batch.
        """
        assert samples_per_class > 1
        assert batch_size % samples_per_class == 0
        self.image_dict = dataset.instance_dict.copy()
        self.batch_size = batch_size
        self.samples_per_class = samples_per_class

        # Pre-build the batches for the first epoch.
        self.reshuffle()

    def __iter__(self,):
        # NOTE(review): unlike RandomSampler, batches are NOT reshuffled here;
        # the training loop presumably calls reshuffle() between epochs — confirm.
        for batch in self.batches:
            yield batch

    def __len__(self,):
        return len(self.batches)

    def __repr__(self,):
        return (
            f"{self.__class__.__name__}(\n"
            f"    batch_size={self.batch_size},\n"
            f"    samples_per_class={self.samples_per_class}\n)"
        )

    def reshuffle(self):
        """Regenerate the batch list for a new epoch.

        Classes with fewer than `samples_per_class` images left are skipped,
        and generation stops once a pass cannot fill a batch, so leftover
        images are dropped.
        """
        lib.LOGGER.info("Shuffling data")
        # Work on a deep copy: chunks are consumed from it while batching.
        image_dict = copy.deepcopy(self.image_dict)
        for sub in image_dict:
            np.random.shuffle(image_dict[sub])

        classes = [*image_dict]
        np.random.shuffle(classes)
        total_batches = []
        batch = []
        finished = 0
        while finished == 0:
            for sub_class in classes:
                # Take one chunk of `samples_per_class` indices if this class
                # still has enough and the batch still has room.
                if (len(image_dict[sub_class]) >= self.samples_per_class) and (len(batch) < self.batch_size/self.samples_per_class):
                    batch.append(image_dict[sub_class][:self.samples_per_class])
                    image_dict[sub_class] = image_dict[sub_class][self.samples_per_class:]

                if len(batch) == self.batch_size/self.samples_per_class:
                    # Batch complete: flatten the per-class chunks and shuffle.
                    batch = flatten(batch)
                    np.random.shuffle(batch)
                    total_batches.append(batch)
                    batch = []
                else:
                    # NOTE(review): `finished` flips on ANY class that leaves the
                    # batch incomplete, but the for-loop still runs to its end,
                    # so batches completed later in this pass are kept.
                    finished = 1

        np.random.shuffle(total_batches)
        self.batches = total_batches
77 |
--------------------------------------------------------------------------------
/roadmap/datasets/inaturalist.py:
--------------------------------------------------------------------------------
1 | import json
2 | from os.path import join
3 |
4 | from .base_dataset import BaseDataset
5 |
6 |
class INaturalistDataset(BaseDataset):
    """iNaturalist retrieval dataset.

    Image lists come from the `Inat_dataset_splits` text files; class labels
    are parsed from the species-id directory in each path, and the full
    taxonomy of each sample is looked up in `train2018.json`.
    """

    def __init__(self, data_dir, mode, transform=None, **kwargs):
        super().__init__(**kwargs)

        self.data_dir = data_dir
        self.mode = mode
        self.transform = transform

        # Map the requested mode to the split files to read.
        if mode == 'train':
            mode = ['train']
        elif mode == 'test':
            mode = ['test']
        elif mode == 'all':
            mode = ['train', 'test']
        else:
            raise ValueError(f"Mode unrecognized {mode}")

        self.paths = []
        for splt in mode:
            with open(join(self.data_dir, f'Inat_dataset_splits/Inaturalist_{splt}_set1.txt')) as f:
                paths = f.read().split("\n")
                paths.remove("")
            self.paths.extend([join(self.data_dir, pth) for pth in paths])

        # Category metadata: species id -> its remaining taxonomy fields
        # (the json "name" field is discarded, "id" is re-exposed as "species").
        with open(join(self.data_dir, 'train2018.json')) as f:
            db = json.load(f)['categories']
        self.db = {}
        for x in db:
            _ = x.pop("name")
            id_ = x.pop("id")
            x["species"] = id_
            self.db[id_] = x

        # Fine-grained labels: species id is the second-to-last path component.
        self.labels_name = [int(x.split("/")[-2]) for x in self.paths]
        self.labels_to_id = {cl: i for i, cl in enumerate(sorted(set(self.labels_name)))}
        self.labels = [self.labels_to_id[x] for x in self.labels_name]

        # Per-sample value lists for every taxonomy level present in the json.
        self.hierarchy_name = {}
        for x in self.labels_name:
            for key, val in self.db[x].items():
                try:
                    self.hierarchy_name[key].append(val)
                except KeyError:
                    self.hierarchy_name[key] = [val]

        # Contiguous integer ids per taxonomy level.
        self.hierarchy_name_to_id = {}
        self.hierarchy_labels = {}
        for key, lst in self.hierarchy_name.items():
            self.hierarchy_name_to_id[key] = {cl: i for i, cl in enumerate(sorted(set(lst)))}
            self.hierarchy_labels[key] = [self.hierarchy_name_to_id[key][x] for x in lst]

        # Super-labels: the directory above the species id in each path
        # (presumably the super-category name — confirm against the dataset layout).
        self.super_labels_name = [x.split("/")[-3] for x in self.paths]
        self.super_labels_to_id = {scl: i for i, scl in enumerate(sorted(set(self.super_labels_name)))}
        self.super_labels = [self.super_labels_to_id[x] for x in self.super_labels_name]

        self.get_instance_dict()
        self.get_super_dict()
65 |
--------------------------------------------------------------------------------
/roadmap/engine/memory.py:
--------------------------------------------------------------------------------
1 | from collections import deque
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
def get_mask(lst):
    """Return a boolean list with True at the first occurrence of each value."""
    first_positions = [lst.index(x) for x in lst]
    return [pos == i for i, pos in enumerate(first_positions)]
9 |
10 |
class XBM(nn.Module):
    """Cross-batch memory of embeddings and labels.

    Two storage modes:
      * unique=True — dicts keyed by sample id; re-inserting a key overwrites
        its entry (note: `size` is not enforced in this mode).
      * unique=False — FIFO deques capped at `size` entries.
    """

    def __init__(self, size=None, weight=1.0, activate_after=-1, unique=True):
        super().__init__()
        self.size = size
        self.unique = unique
        self.weight = weight
        self.activate_after = activate_after

        if self.unique:
            self.features_memory = {}
            self.labels_memory = {}
        else:
            self.features_memory = deque()
            self.labels_memory = deque()

    def add_without_keys(self, features, labels):
        """FIFO insert: evict the oldest entries until the new batch fits."""
        incoming = features.size(0)
        while len(self.features_memory) + incoming > self.size:
            self.features_memory.popleft()
            self.labels_memory.popleft()

        for feature, label in zip(features, labels):
            self.features_memory.append(feature)
            self.labels_memory.append(label)

    def add_with_keys(self, features, labels, keys):
        """Keyed insert: overwrite any entry already stored under the same key."""
        for key, feature, label in zip(keys, features, labels):
            self.features_memory[key] = feature
            self.labels_memory[key] = label

    def get_occupied_storage(self,):
        """Stack and return (features, labels); empty tensors when memory is empty."""
        if not self.features_memory:
            return torch.tensor([]), torch.tensor([])

        if self.unique:
            feature_store = self.features_memory.values()
            label_store = self.labels_memory.values()
        else:
            feature_store = self.features_memory
            label_store = self.labels_memory

        return torch.stack(list(feature_store)), torch.stack(list(label_store))

    def forward(self, features, labels, keys=None):
        """Insert the batch, then return the full memory content."""
        if self.unique:
            assert keys is not None
            self.add_with_keys(features, labels, keys)
        else:
            self.add_without_keys(features, labels)

        return self.get_occupied_storage()

    def extra_repr(self,):
        return f"size={self.size}, unique={self.unique}"
64 |
65 |
if __name__ == '__main__':
    # Smoke test of the FIFO memory. The previous demo was broken: it
    # passed a tuple (56, 128) as `size` (TypeError when compared against
    # an int length) and printed `mem.index`, an attribute XBM never defines.
    mem = XBM(size=56, unique=False)

    mem(torch.ones(32, 128), torch.ones(32,))
    features, labels = mem(torch.ones(32, 128), torch.ones(32,))

    print(mem)
    print(features.shape, labels.shape)
75 |
--------------------------------------------------------------------------------
/roadmap/datasets/revisited_dataset.py:
--------------------------------------------------------------------------------
1 | from os.path import join
2 |
3 | import torch
4 | import pickle
5 | from PIL import Image
6 |
7 | from .base_dataset import BaseDataset
8 |
9 |
def imresize(img, imsize):
    """Shrink `img` in place so its longest side is at most `imsize`.

    Returns the (mutated) image for convenience.
    """
    # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the exact
    # same filter and has been available under this name since Pillow 2.7.
    img.thumbnail((imsize, imsize), Image.LANCZOS)
    return img
13 |
14 |
def path_to_label(pth):
    """Landmark label of an image path: the file stem minus its trailing index.

    e.g. ".../jpg/paris_defense_000041.jpg" -> "paris_defense".
    """
    filename = pth.split("/")[-1]
    stem = filename.split(".")[0]
    name_parts = stem.split("_")
    return "_".join(name_parts[:-1])
17 |
18 |
class RevisitedDataset(BaseDataset):
    """Query/gallery dataset built from a ``gnd_{city}.pkl`` ground-truth file
    (the "revisited" retrieval benchmark format, with per-query easy/hard/junk
    relevance lists — presumably rOxford/rParis; confirm against the data)."""

    def __init__(self, data_dir, mode, imsize=None, transform=None, **kwargs):
        # data_dir: benchmark root containing gnd_{city}.pkl and a jpg/ folder.
        # mode: "query" loads qimlist (+ crop boxes and relevance lists),
        #       "gallery" loads imlist.
        # imsize: optional bound on the longest image side (see __getitem__).
        # transform: optional transform applied last in __getitem__.
        super().__init__(**kwargs)
        assert mode in ["query", "gallery"]

        self.data_dir = data_dir
        self.mode = mode
        self.imsize = imsize
        self.transform = transform
        # City name is the last non-empty path component (handles a trailing slash).
        self.city = self.data_dir.split('/')
        self.city = self.city[-1] if self.city[-1] else self.city[-2]

        with open(join(self.data_dir, f"gnd_{self.city}.pkl"), "rb") as f:
            db = pickle.load(f)

        self.paths = [join(self.data_dir, "jpg", f"{x}.jpg") for x in db["qimlist" if self.mode == "query" else "imlist"]]
        # Label = landmark name parsed from the file name (see path_to_label),
        # remapped to contiguous integer ids.
        self.labels_name = [path_to_label(x) for x in self.paths]
        labels_name_to_id = {lb: i for i, lb in enumerate(sorted(set(self.labels_name)))}
        self.labels = [labels_name_to_id[x] for x in self.labels_name]

        if self.mode == "query":
            # Per-query crop box and ground-truth relevance lists.
            self.bbx = [x["bbx"] for x in db["gnd"]]
            self.easy = [x["easy"] for x in db["gnd"]]
            self.hard = [x["hard"] for x in db["gnd"]]
            self.junk = [x["junk"] for x in db["gnd"]]

        self.get_instance_dict()

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx])
        imfullsize = max(img.size)

        # Queries are evaluated on the annotated region only.
        if self.mode == 'query':
            img = img.crop(self.bbx[idx])

        if self.imsize is not None:
            if self.mode == 'query':
                # Scale the crop relative to the full image size, so the crop
                # is resized by the same factor the full image would be.
                img = imresize(img, self.imsize * max(img.size) / imfullsize)
            else:
                img = imresize(img, self.imsize)

        if self.transform is not None:
            img = self.transform(img)

        out = {"image": img, "label": torch.tensor([self.labels[idx]])}
        if self.mode == 'query':
            # Relevance lists are forwarded so the evaluator can score mAP.
            out["easy"] = self.easy[idx]
            out["hard"] = self.hard[idx]
            out["junk"] = self.junk[idx]

        return out

    def __repr__(self,):
        return f"{self.city.title()}Dataset(mode={self.mode}, imsize={self.imsize}, len={len(self)})"
74 |
--------------------------------------------------------------------------------
/roadmap/single_experiment_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import hydra
4 | import torch
5 |
6 | import roadmap.run as run
7 | import roadmap.utils as lib
8 |
9 |
@hydra.main(config_path='config', config_name='default')
def single_experiment_runner(cfg):
    """
    Parses hydra config, check for potential resuming of training
    and launches training
    """
    # When a live ray cluster is reachable, disable TQDM progress bars
    # (their carriage-return output pollutes ray's captured logs).
    # NOTE: a second enclosing try/except catching ray.exceptions.RaySystemError
    # was removed: it was unreachable (the inner `except Exception` already
    # swallows RaySystemError) and its except clause would itself raise
    # NameError whenever `import ray` had failed.
    try:
        import ray
        ray.cluster_resources()
        lib.LOGGER.info("Experiment running with ray : deactivating TQDM")
        os.environ['TQDM_DISABLE'] = "1"
    except Exception:
        pass

    cfg.experience.log_dir = lib.expand_path(cfg.experience.log_dir)

    if cfg.experience.resume is not None:
        # Explicit resume: either a full checkpoint path, or a file name
        # looked up under the experiment's weights directory.
        if os.path.isfile(lib.expand_path(cfg.experience.resume)):
            resume = lib.expand_path(cfg.experience.resume)
        else:
            resume = os.path.join(cfg.experience.log_dir, cfg.experience.experiment_name, 'weights', cfg.experience.resume)

        if not os.path.isfile(resume):
            lib.LOGGER.warning("Checkpoint does not exists")
            return

        # Nothing to do when the checkpoint already reached max_iter.
        at_epoch = torch.load(resume, map_location='cpu')["epoch"]
        if at_epoch >= cfg.experience.max_iter:
            lib.LOGGER.warning(f"Exiting trial, experiment {cfg.experience.experiment_name} already finished")
            return

    elif cfg.experience.maybe_resume:
        # Opportunistic resume from the rolling checkpoint if one exists.
        state_path = os.path.join(cfg.experience.log_dir, cfg.experience.experiment_name, 'weights', 'rolling.ckpt')
        if os.path.isfile(state_path):
            resume = state_path
            lib.LOGGER.warning(f"Resuming experience because weights were found @ {resume}")
            at_epoch = torch.load(resume, map_location='cpu')["epoch"]
            if at_epoch >= cfg.experience.max_iter:
                lib.LOGGER.warning(f"Exiting trial, experiment {cfg.experience.experiment_name} already finished")
                return
        else:
            resume = None

    else:
        # Fresh run: refuse to overwrite an existing experiment directory.
        resume = None
        if os.path.isdir(os.path.join(cfg.experience.log_dir, cfg.experience.experiment_name, 'weights')):
            lib.LOGGER.warning(f"Exiting trial, experiment {cfg.experience.experiment_name} already exists")
            return

    metrics = run.run(
        config=cfg,
        base_config=None,
        checkpoint_dir=resume,
    )

    # Return the principal metric so hydra sweepers/optimizers can rank trials.
    if metrics is not None:
        return metrics[cfg.experience.eval_split][cfg.experience.principal_metric]
71 |
72 |
if __name__ == '__main__':
    # Hydra entry point: CLI arguments are parsed by the @hydra.main
    # decorator, which injects the composed config as `cfg`.
    single_experiment_runner()
75 |
--------------------------------------------------------------------------------
/roadmap/samplers/hierarchical_sampler.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import numpy as np
4 | from torch.utils.data.sampler import BatchSampler
5 |
6 | import roadmap.utils as lib
7 |
8 |
def safe_random_choice(input_data, size):
    """Draw `size` random elements, allowing repeats only when the pool is smaller than `size`."""
    needs_replacement = size > len(input_data)
    return np.random.choice(input_data, size=size, replace=needs_replacement)
12 |
13 |
14 | # Inspired by
15 | # https://github.com/kunhe/Deep-Metric-Learning-Baselines/blob/master/datasets.py
class HierarchicalSampler(BatchSampler):
    """BatchSampler whose batches each mix `nb_categories` super-labels.

    Every combination of `nb_categories` super-labels contributes
    `batches_per_super_pair` batches; each batch is split evenly between the
    chosen super-labels, and each share is filled class by class with
    `samples_per_class` instances.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        samples_per_class,
        batches_per_super_pair,
        nb_categories=2,
    ):
        """
        dataset: must expose `super_dict` ({super_label: {class: [indices]}})
            and `super_labels`
        batch_size: because this is a BatchSampler the batch size must be specified
        samples_per_class: number of instances to sample for a specific class. set to 0 to take all elements of a class
        batches_per_super_pair: number of batches to create for a pair of categories (or super labels)
        nb_categories: how many super labels are combined in each batch
        """
        self.batch_size = int(batch_size)
        self.batches_per_super_pair = int(batches_per_super_pair)
        self.samples_per_class = int(samples_per_class)
        self.nb_categories = int(nb_categories)

        # checks
        assert self.batch_size % self.nb_categories == 0, f"batch_size should be a multiple of {self.nb_categories}"
        self.sub_batch_len = self.batch_size // self.nb_categories
        if self.samples_per_class > 0:
            assert self.sub_batch_len % self.samples_per_class == 0, "batch_size not a multiple of samples_per_class"
        else:
            # 0 means "take every instance of each selected class".
            self.samples_per_class = None

        self.super_image_lists = dataset.super_dict.copy()
        self.super_pairs = list(itertools.combinations(set(dataset.super_labels), self.nb_categories))
        self.reshuffle()

    def __iter__(self,):
        # Re-randomize batch composition at the start of every epoch.
        self.reshuffle()
        for batch in self.batches:
            yield batch

    def __len__(self,):
        return len(self.batches)

    def __repr__(self,):
        return (
            f"{self.__class__.__name__}(\n"
            f"    batch_size={self.batch_size},\n"
            f"    samples_per_class={self.samples_per_class},\n"
            f"    batches_per_super_pair={self.batches_per_super_pair},\n"
            f"    nb_categories={self.nb_categories}\n)"
        )

    def reshuffle(self):
        """Rebuild self.batches from scratch with fresh randomness."""
        lib.LOGGER.info("Shuffling data")
        batches = []
        for combinations in self.super_pairs:

            for b in range(self.batches_per_super_pair):

                batch = []
                for slb in combinations:

                    # Fill this super-label's share of the batch, visiting
                    # its classes in random order.
                    sub_batch = []
                    all_classes = list(self.super_image_lists[slb].keys())
                    np.random.shuffle(all_classes)
                    for cl in all_classes:
                        instances = self.super_image_lists[slb][cl]
                        samples_per_class = self.samples_per_class if self.samples_per_class else len(instances)
                        # Skip classes that would overflow the share; a later
                        # (smaller) class may still fit.
                        if len(sub_batch) + samples_per_class > self.sub_batch_len:
                            continue
                        sub_batch.extend(safe_random_choice(instances, size=samples_per_class))

                    batch.extend(sub_batch)

                np.random.shuffle(batch)
                batches.append(batch)

        np.random.shuffle(batches)
        self.batches = batches
94 |
--------------------------------------------------------------------------------
/roadmap/evaluate.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import argparse
3 |
4 | import torch
5 | import numpy as np
6 |
7 | from roadmap.getter import Getter
8 | import roadmap.utils as lib
9 | import roadmap.engine as eng
10 |
11 |
def load_and_evaluate(
    path,
    set,  # NOTE(review): shadows the builtin `set`; kept because callers pass it by keyword
    bs,
    nw,
    data_dir=None,
):
    """Load a checkpoint and run retrieval evaluation on one split.

    path: checkpoint file holding "net_state", "config" and "epoch".
    set: name of the split to evaluate on (e.g. 'test' or 'train').
    bs / nw: batch size and num_workers for the evaluation DataLoader.
    data_dir: optional override of the dataset root stored in the config.
    Returns the metrics dict produced by eng.evaluate.
    """
    lib.LOGGER.info(f"Evaluating : \033[92m{path}\033[0m")
    state = torch.load(lib.expand_path(path), map_location='cpu')
    cfg = state["config"]

    lib.LOGGER.info("Loading model...")
    cfg.model.kwargs.with_autocast = True
    net = Getter().get_model(cfg.model)
    net.load_state_dict(state["net_state"])
    if torch.cuda.device_count() > 1:
        net = torch.nn.DataParallel(net)
    net.cuda()
    net.eval()

    if data_dir is not None:
        cfg.dataset.kwargs.data_dir = lib.expand_path(data_dir)

    getter = Getter()
    transform = getter.get_transform(cfg.transform.test)
    # When the run was trained on a k-fold split, rebuild the same subset with
    # the stored random state so evaluation matches the training split.
    if hasattr(cfg.experience, 'split') and (cfg.experience.split is not None):
        assert isinstance(cfg.experience.split, int)
        dts = getter.get_dataset(None, 'all', cfg.dataset)
        splits = eng.get_splits(dts.labels, dts.super_labels, cfg.experience.kfold, random_state=cfg.experience.split_random_state)
        dts = eng.make_subset(dts, splits[cfg.experience.split]['train' if set == 'train' else 'val'], transform, set)
        lib.LOGGER.info(dts)
    else:
        dts = getter.get_dataset(transform, set, cfg.dataset)

    lib.LOGGER.info("Dataset created...")

    metrics = eng.evaluate(
        net=net,
        test_dataset=dts,
        epoch=state["epoch"],
        batch_size=bs,
        num_workers=nw,
        exclude=['mean_average_precision'],
    )

    lib.LOGGER.info("Evaluation completed...")
    # Log every metric, scaled to percentages for readability.
    for split, mtrc in metrics.items():
        for k, v in mtrc.items():
            if k == 'epoch':
                continue
            lib.LOGGER.info(f"{split} --> {k} : {np.around(v*100, decimals=2)}")

    return metrics
65 |
66 |
if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, nargs='+', help='Path.s to checkpoint')
    parser.add_argument("--parse-file", default=False, action='store_true', help='allows to pass a .txt file with several models to evaluate')
    parser.add_argument("--set", type=str, default='test', help='Set on which to evaluate')
    parser.add_argument("--bs", type=int, default=128, help='Batch size for DataLoader')
    parser.add_argument("--nw", type=int, default=10, help='Num workers for DataLoader')
    parser.add_argument("--data-dir", type=str, default=None, help='Possible override of the datadir in the dataset config')
    parser.add_argument("--metric-dir", type=str, default=None, help='Path in which to store the metrics')
    args = parser.parse_args()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S %p',
        level=logging.INFO,
    )

    if args.parse_file:
        # One checkpoint path per line; drop all blank lines. The previous
        # `paths.remove("")` raised ValueError when the file ended without a
        # trailing newline (no empty entry) and removed only the FIRST blank
        # when the file contained several.
        with open(args.config[0], 'r') as f:
            paths = [p for p in f.read().splitlines() if p]
        args.config = paths

    for path in args.config:
        metrics = load_and_evaluate(
            path=path,
            set=args.set,
            bs=args.bs,
            nw=args.nw,
            data_dir=args.data_dir,
        )
        print()
        print()

        if args.metric_dir is not None:
            # Append a human-readable record of this model's metrics.
            with open(args.metric_dir, 'a') as f:
                f.write(path)
                f.write("\n")
                for split, mtrc in metrics.items():
                    for k, v in mtrc.items():
                        if k == 'epoch':
                            continue
                        f.write(f"{split} --> {k} : {np.around(v*100, decimals=2)}\n")
                f.write("\n\n")
--------------------------------------------------------------------------------
/roadmap/losses/blackbox_ap.py:
--------------------------------------------------------------------------------
1 | # MIT License
2 | #
3 | # Copyright (c) 2019 Autonomous Learning Group
4 | #
5 | # Permission is hereby granted, free of charge, to any person obtaining a copy
6 | # of this software and associated documentation files (the "Software"), to deal
7 | # in the Software without restriction, including without limitation the rights
8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | # copies of the Software, and to permit persons to whom the Software is
10 | # furnished to do so, subject to the following conditions:
11 | #
12 | # The above copyright notice and this permission notice shall be included in all
13 | # copies or substantial portions of the Software.
14 | #
15 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | # SOFTWARE.
22 | import torch
23 |
24 |
def rank(seq):
    """Per-row rank of each entry of a 2D tensor, largest value -> rank 0."""
    ascending_order = torch.argsort(seq)
    descending_order = ascending_order.flip(1)
    return torch.argsort(descending_order)
27 |
28 |
def rank_normalised(seq):
    """Ranks shifted to 1-based and scaled into (0, 1] by the row length."""
    one_based_ranks = (rank(seq) + 1).float()
    return one_based_ranks / seq.size()[1]
31 |
32 |
class TrueRanker(torch.autograd.Function):
    """Blackbox-differentiable ranking operator.

    Forward computes exact normalised ranks (a non-differentiable step);
    backward implements the finite-difference surrogate gradient used in
    'Blackbox differentiation of Ranking-based Metrics': re-rank an input
    perturbed along the incoming gradient and take the scaled rank difference.
    """

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, sequence, lambda_val):
        # Exact normalised ranks of `sequence`; saved (with the input) for
        # the backward pass.
        rank = rank_normalised(sequence)
        ctx.lambda_val = lambda_val
        ctx.save_for_backward(sequence, rank)
        return rank

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        sequence, rank = ctx.saved_tensors
        assert grad_output.shape == rank.shape
        # Perturb the scores along the incoming gradient (step size
        # lambda_val), re-rank, and use the rank difference as the gradient
        # estimate; 1e-8 guards against division by a zero lambda_val.
        sequence_prime = sequence + ctx.lambda_val * grad_output
        rank_prime = rank_normalised(sequence_prime)
        gradient = -(rank - rank_prime) / (ctx.lambda_val + 1e-8)
        # No gradient w.r.t. lambda_val.
        return gradient, None
51 |
52 |
class BlackBoxAP(torch.nn.Module):
    """ Torch module for computing recall-based loss as in 'Blackbox differentiation of Ranking-based Metrics' """
    def __init__(self,
                 lambda_val=4.,
                 margin=0.02,
                 return_type='1-mAP',
                 ):
        """
        :param lambda_val: hyperparameter of black-box backprop
        :param margin: margin to be enforced between positives and negatives (alpha in the paper)
        :param return_type: reduction of the per-query AP: "AP", "mAP", "1-mAP" or "1-AP"
        """
        super().__init__()
        assert return_type in ["AP", "mAP", "1-mAP", "1-AP"]
        self.lambda_val = lambda_val
        self.margin = margin
        self.return_type = return_type

    def raw_map_computation(self, scores, targets):
        """
        :param scores: NxM predicted similarity scores
        :param targets: NxM ground truth relevances
        """
        # Compute map
        HIGH_CONSTANT = 2.0
        epsilon = 1e-5
        # Random per-sample margins: (targets - 0.5) is +0.5 for positives and
        # -0.5 for negatives, so subtracting margin*deviations handicaps the
        # positives and boosts the negatives — a positive must outrank
        # negatives by roughly `margin` to get a good rank.
        deviations = torch.abs(torch.randn_like(targets, device=scores.device, dtype=scores.dtype)) * (targets - 0.5)

        scores = scores - self.margin * deviations
        # Differentiable (blackbox) global ranks of every sample.
        ranks_of_positive = TrueRanker.apply(scores, self.lambda_val)
        # Adding HIGH_CONSTANT to positives guarantees every positive scores
        # above every negative here, so the next ranking yields each
        # positive's rank *among the positives only*.
        scores_for_ranking_positives = -ranks_of_positive + HIGH_CONSTANT * targets
        ranks_within_positive = rank_normalised(scores_for_ranking_positives)
        # Treated as a constant in the backward pass.
        ranks_within_positive.requires_grad = False
        # Sanity check: a positive's rank among positives never exceeds its
        # global rank (up to epsilon).
        assert torch.all(ranks_within_positive * targets < ranks_of_positive * targets + epsilon)

        # AP per query = mean over positives of (rank among positives / global rank).
        sum_of_precisions_at_j_per_class = ((ranks_within_positive / ranks_of_positive) * targets).sum(dim=1)
        precisions_per_class = sum_of_precisions_at_j_per_class / (targets.sum(dim=1) + epsilon)

        if self.return_type == '1-mAP':
            return 1.0 - precisions_per_class.mean()
        elif self.return_type == '1-AP':
            return 1.0 - precisions_per_class
        elif self.return_type == 'mAP':
            return precisions_per_class.mean()
        elif self.return_type == 'AP':
            return precisions_per_class

    def forward(self, output, target):
        # Cast targets to the score dtype so the randn_like above matches.
        return self.raw_map_computation(output, target.type(output.dtype))

    def extra_repr(self,):
        return f"lambda_val={self.lambda_val}, margin={self.margin}, return_type={self.return_type}"
106 |
--------------------------------------------------------------------------------
/roadmap/losses/softbin_ap.py:
--------------------------------------------------------------------------------
1 | # BSD 3-Clause License
2 | #
3 | # Copyright (c) 2019, NAVER LABS
4 | # All rights reserved.
5 | #
6 | # Redistribution and use in source and binary forms, with or without
7 | # modification, are permitted provided that the following conditions are met:
8 | #
9 | # 1. Redistributions of source code must retain the above copyright notice, this
10 | # list of conditions and the following disclaimer.
11 | #
12 | # 2. Redistributions in binary form must reproduce the above copyright notice,
13 | # this list of conditions and the following disclaimer in the documentation
14 | # and/or other materials provided with the distribution.
15 | #
16 | # 3. Neither the name of the copyright holder nor the names of its
17 | # contributors may be used to endorse or promote products derived from
18 | # this software without specific prior written permission.
19 | #
20 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 | import numpy as np
31 | import torch
32 | import torch.nn as nn
33 |
34 |
class SoftBinAP (nn.Module):
    """ Differentiable AP loss, through quantization. From the paper:

        Learning with Average Precision: Training Image Retrieval with a Listwise Loss
        Jerome Revaud, Jon Almazan, Rafael Sampaio de Rezende, Cesar de Souza
        https://arxiv.org/abs/1906.07589

        Input: (N, M)   values in [min, max]
        label: (N, M)   values in {0, 1}

        Returns: 1 - mAP (mean AP for each n in {1..N})
        Note: typically, this is what you wanna minimize
    """
    def __init__(
        self,
        nq=20,
        min=-1,
        max=1,
        return_type='1-mAP',
    ):
        # nq: number of quantization bins spanning [min, max].
        # min/max: expected range of the similarity scores.
        # return_type: reduction applied to the per-query AP.
        super().__init__()
        assert isinstance(nq, int) and 2 <= nq <= 100
        assert return_type in ["1-mAP", "AP", "1-AP", "mAP", "debug"]
        self.nq = nq
        self.min = min
        self.max = max
        self.return_type = return_type

        gap = max - min
        assert gap > 0
        # Initialize quantizer as non-trainable convolution: channel q holds
        # the (soft, triangular) membership of each score to bin q.
        self.quantizer = q = nn.Conv1d(1, 2*nq, kernel_size=1, bias=True)
        q.weight = nn.Parameter(q.weight.detach(), requires_grad=False)
        q.bias = nn.Parameter(q.bias.detach(), requires_grad=False)
        a = (nq-1) / gap
        # First half equal to lines passing to (min+x,1) and (min+x+1/a,0)
        # with x = {nq-1..0}*gap/(nq-1)
        q.weight[:nq] = -a
        q.bias[:nq] = torch.from_numpy(a*min + np.arange(nq, 0, -1))  # b = 1 + a*(min+x)
        # First half equal to lines passing to (min+x,1) and (min+x-1/a,0)
        # with x = {nq-1..0}*gap/(nq-1)
        q.weight[nq:] = a
        q.bias[nq:] = torch.from_numpy(np.arange(2-nq, 2, 1) - a*min)  # b = 1 - a*(min+x)
        # First and last one as a horizontal straight line
        q.weight[0] = q.weight[-1] = 0
        q.bias[0] = q.bias[-1] = 1

    def forward(self, x, label, qw=None):
        # x: similarity scores, label: binary relevances; qw: optional
        # per-query weights (only used for the '1-mAP' reduction).
        assert x.shape == label.shape  # N x M
        N, M = x.shape
        # Quantize all predictions: min of the two line-pairs built in
        # __init__, clamped at 0, gives triangular bin memberships.
        q = self.quantizer(x.unsqueeze(1))
        q = torch.min(q[:, :self.nq], q[:, self.nq:]).clamp(min=0)  # N x Q x M

        nbs = q.sum(dim=-1)  # number of samples  N x Q = c
        rec = (q * label.view(N, 1, M).float()).sum(dim=-1)  # number of correct samples = c+ N x Q
        prec = rec.cumsum(dim=-1) / (1e-16 + nbs.cumsum(dim=-1))  # precision
        rec /= rec.sum(dim=-1).unsqueeze(1)  # norm in [0,1]

        ap = (prec * rec).sum(dim=-1)  # per-image AP

        if self.return_type == '1-mAP':
            if qw is not None:
                ap *= qw  # query weights
            loss = 1 - ap.mean()
            return loss
        elif self.return_type == 'AP':
            assert qw is None
            return ap
        elif self.return_type == 'mAP':
            assert qw is None
            return ap.mean()
        elif self.return_type == '1-AP':
            return 1 - ap
        elif self.return_type == 'debug':
            return prec, rec

    def extra_repr(self,):
        return f"nq={self.nq}, min={self.min}, max={self.max}, return_type={self.return_type}"
114 |
--------------------------------------------------------------------------------
/roadmap/engine/base_update.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import numpy as np
5 | from tqdm import tqdm
6 |
7 | import roadmap.utils as lib
8 |
9 |
def _batch_optimization(
    config,
    net,
    batch,
    criterion,
    optimizer,
    scaler,
    epoch,
    memory
):
    """Forward/backward pass for one batch; returns a dict of scalar logs.

    Gradients are accumulated on `net` (the optimizer step happens in the
    caller, base_update). `criterion` is a list of (loss_module, weight)
    pairs; a weight may be the string 'adaptative', in which case per-loss
    weights are derived from gradient norms on net.fc.
    """
    with torch.cuda.amp.autocast(enabled=(scaler is not None)):
        di = net(batch["image"].cuda())
        labels = batch["label"].cuda()
        # Pairwise similarity scores within the batch.
        scores = torch.mm(di, di.t())
        label_matrix = lib.create_label_matrix(labels)

        if memory:
            # Embeddings are detached before storage: the memory bank only
            # supplies comparison targets, never a gradient path.
            memory_embeddings, memory_labels = memory(di.detach(), labels, batch["path"])
            if epoch >= config.memory.activate_after:
                memory_scores = torch.mm(di, memory_embeddings.t())
                memory_label_matrix = lib.create_label_matrix(labels, memory_labels)

        logs = {}
        losses = []
        for crit, weight in criterion:
            # Criteria flagged with `takes_embeddings` consume raw embeddings;
            # the others consume the score/label matrices.
            if hasattr(crit, 'takes_embeddings'):
                loss = crit(di, labels.view(-1))
                if memory:
                    if epoch >= config.memory.activate_after:
                        mem_loss = crit(di, labels.view(-1), memory_embeddings, memory_labels.view(-1))

            else:
                loss = crit(scores, label_matrix)
                if memory:
                    if epoch >= config.memory.activate_after:
                        mem_loss = crit(memory_scores, memory_label_matrix)

            loss = loss.mean()
            if weight == 'adaptative':
                # Weighting is deferred: adaptative weights are computed
                # below from gradient norms.
                losses.append(loss)
            else:
                losses.append(weight * loss)

            logs[crit.__class__.__name__] = loss.item()
            if memory:
                if epoch >= config.memory.activate_after:
                    mem_loss = mem_loss.mean()
                    if weight == 'adaptative':
                        losses.append(config.memory.weight * mem_loss)
                    else:
                        losses.append(weight * config.memory.weight * mem_loss)
                    logs[f"memory_{crit.__class__.__name__}"] = mem_loss.item()

        # NOTE(review): `weight` here holds the value from the LAST loop
        # iteration — the adaptative branch implicitly assumes all criteria
        # share the 'adaptative' weight. Confirm configs respect this.
        if weight == 'adaptative':
            grads = []
            for i, lss in enumerate(losses):
                # Gradient norm of each loss w.r.t. the final layer only.
                g = torch.autograd.grad(lss, net.fc.parameters(), retain_graph=True)
                grads.append(torch.norm(g[0]).item())
            mean_grad = np.mean(grads)
            # Rescale each loss so its fc-gradient norm matches the mean norm.
            weights = [mean_grad / g for g in grads]
            losses = [w * lss for w, lss in zip(weights, losses)]
            logs.update({
                f"weight_{crit.__class__.__name__}": w for (crit, _), w in zip(criterion, weights)
            })
            logs.update({
                f"grad_{crit.__class__.__name__}": w for (crit, _), w in zip(criterion, grads)
            })

    total_loss = sum(losses)
    if scaler is None:
        total_loss.backward()
    else:
        # AMP: scale the loss to avoid fp16 gradient underflow.
        scaler.scale(total_loss).backward()

    logs["total_loss"] = total_loss.item()
    # Detach in place so the autograd graph can be freed between iterations.
    _ = [loss.detach_() for loss in losses]
    total_loss.detach_()
    return logs
88 |
89 |
def base_update(
    config,
    net,
    loader,
    criterion,
    optimizer,
    scheduler,
    scaler,
    epoch,
    memory=None,
):
    """Run one training epoch over `loader`; returns the dict of averaged logs.

    optimizer: dict of named optimizers (see Getter.get_optimizer).
    scheduler: dict with an "on_step" list stepped after every batch.
    scaler: optional torch.cuda.amp.GradScaler (None disables AMP).
    memory: optional cross-batch memory module forwarded to the batch step.
    """
    meter = lib.DictAverage()
    net.train()
    net.zero_grad()

    iterator = tqdm(loader, disable=os.getenv('TQDM_DISABLE'))
    for i, batch in enumerate(iterator):
        # Forward + backward; gradients are left on the parameters.
        logs = _batch_optimization(
            config,
            net,
            batch,
            criterion,
            optimizer,
            scaler,
            epoch,
            memory,
        )

        if config.experience.log_grad:
            grad_norm = lib.get_gradient_norm(net)
            logs["grad_norm"] = grad_norm.item()

        # Step each optimizer; during warm-up epochs only the optimizer
        # matching warm_up_key is updated.
        for key, opt in optimizer.items():
            if epoch < config.experience.warm_up and key != config.experience.warm_up_key:
                lib.LOGGER.info(f"Warming up @epoch {epoch}")
                continue
            if scaler is None:
                opt.step()
            else:
                scaler.step(opt)

        net.zero_grad()
        # Some criteria hold learnable parameters of their own.
        _ = [crit.zero_grad() for crit, w in criterion]

        for sch in scheduler["on_step"]:
            sch.step()

        if scaler is not None:
            scaler.update()

        meter.update(logs)
        if not os.getenv('TQDM_DISABLE'):
            iterator.set_postfix(meter.avg)
        else:
            # TQDM disabled (e.g. under ray): log periodically instead.
            if (i + 1) % 50 == 0:
                lib.LOGGER.info(f'Iteration : {i}/{len(loader)}')
                for k, v in logs.items():
                    lib.LOGGER.info(f'Loss: {k}: {v} ')

    # Let criteria with an internal schedule advance once per epoch.
    for crit, _ in criterion:
        if hasattr(crit, 'step'):
            crit.step()

    return meter.avg
154 |
--------------------------------------------------------------------------------
/roadmap/getter.py:
--------------------------------------------------------------------------------
1 | from torch import optim
2 | import torchvision.transforms as transforms
3 |
4 | from roadmap import losses
5 | from roadmap import samplers
6 | from roadmap import datasets
7 | from roadmap import models
8 | from roadmap import engine
9 | from roadmap import utils as lib
10 |
11 |
class Getter:
    """
    This class allows to create differents object (model, loss functions, optimizer...)
    based on the config
    """

    def get(self, obj, *args, **kwargs):
        # Generic dispatch: get("loss", cfg) -> self.get_loss(cfg).
        return getattr(self, f"get_{obj}")(*args, **kwargs)

    def get_transform(self, config):
        """Build a torchvision Compose from {TransformName: kwargs} config entries."""
        t_list = []
        for k, v in config.items():
            t_list.append(getattr(transforms, k)(**v))

        transform = transforms.Compose(t_list)
        lib.LOGGER.info(transform)
        return transform

    def get_optimizer(self, net, config):
        """Create one optimizer per config entry, plus their schedulers.

        Each entry may target a named sub-module (`opt.params`) or the whole
        net. Returns (optimizers, schedulers) where schedulers are grouped
        by when they are stepped: "on_epoch", "on_step" or "on_val".
        """
        optimizers = {}
        schedulers = {
            "on_epoch": [],
            "on_step": [],
            "on_val": [],
        }
        for opt in config:
            optimizer = getattr(optim, opt.name)
            if opt.params is not None:
                # Optimize only the parameters of the named sub-module.
                optimizer = optimizer(getattr(net, opt.params).parameters(), **opt.kwargs)
                optimizers[opt.params] = optimizer
            else:
                optimizer = optimizer(net.parameters(), **opt.kwargs)
                optimizers["net"] = optimizer
            lib.LOGGER.info(optimizer)
            if opt.scheduler_on_epoch is not None:
                schedulers["on_epoch"].append(self.get_scheduler(optimizer, opt.scheduler_on_epoch))
            if opt.scheduler_on_step is not None:
                schedulers["on_step"].append(self.get_scheduler(optimizer, opt.scheduler_on_step))
            if opt.scheduler_on_val is not None:
                # Validation schedulers also carry the metric key they watch.
                schedulers["on_val"].append(
                    (self.get_scheduler(optimizer, opt.scheduler_on_val), opt.scheduler_on_val.key)
                )

        return optimizers, schedulers

    def get_scheduler(self, opt, config):
        """Instantiate a torch.optim.lr_scheduler by name."""
        sch = getattr(optim.lr_scheduler, config.name)(opt, **config.kwargs)
        lib.LOGGER.info(sch)
        return sch

    def get_loss(self, config):
        """Return a list of (loss_module, weight) pairs from the loss config."""
        criterion = []
        for crit in config:
            loss = getattr(losses, crit.name)(**crit.kwargs)
            weight = crit.weight
            lib.LOGGER.info(f"{loss} with weight {weight}")
            criterion.append((loss, weight))
        return criterion

    def get_sampler(self, dataset, config):
        """Instantiate a batch sampler by name, bound to `dataset`."""
        sampler = getattr(samplers, config.name)(dataset, **config.kwargs)
        lib.LOGGER.info(sampler)
        return sampler

    def get_dataset(self, transform, mode, config):
        """Instantiate the dataset(s) for `mode`.

        Some benchmarks need several datasets at test time (query/gallery
        pairs, distractors); those cases return a dict or a list instead of
        a single dataset.
        """
        if (config.name == "InShopDataset") and (mode == "test"):
            dataset = {
                "test": getattr(datasets, config.name)(transform=transform, mode="query", **config.kwargs),
                "gallery": getattr(datasets, config.name)(transform=transform, mode="gallery", **config.kwargs),
            }
            lib.LOGGER.info(dataset)
            return dataset
        elif (config.name == "DyMLDataset") and mode.startswith("test"):
            dataset = {
                "test": getattr(datasets, config.name)(transform=transform, mode="test_query_fine", **config.kwargs),
                "distractor": getattr(datasets, config.name)(transform=transform, mode="test_gallery_fine", **config.kwargs),
            }
            lib.LOGGER.info(dataset)
            return dataset
        elif (config.name == "SfM120kDataset") and (mode == "test"):
            # Landmark retrieval: one query/gallery pair per evaluation city.
            dataset = []
            for dts in config.evaluation:
                test_dts = getattr(datasets, dts.name)(transform=transform, mode="query", **dts.kwargs)
                dataset.append({
                    f"query_{test_dts.city}": test_dts,
                    f"gallery_{test_dts.city}": getattr(datasets, dts.name)(transform=transform, mode="gallery", **dts.kwargs),
                })
            lib.LOGGER.info(dataset)
            return dataset
        else:
            dataset = getattr(datasets, config.name)(
                transform=transform,
                mode=mode,
                **config.kwargs,
            )
            lib.LOGGER.info(dataset)
            return dataset

    def get_model(self, config):
        """Instantiate the network, optionally freezing batch norm / positional embeddings."""
        net = getattr(models, config.name)(**config.kwargs)
        if config.freeze_batch_norm:
            lib.LOGGER.info("Freezing batch norm")
            net = lib.freeze_batch_norm(net)
        if config.freeze_pos_embedding:
            lib.LOGGER.info("Freezing pos embeddings")
            net.backbone = lib.freeze_pos_embedding(net.backbone)
        return net

    def get_memory(self, config):
        """Instantiate a cross-batch memory module (e.g. XBM) by name."""
        memory = getattr(engine, config.name)(**config.kwargs)
        lib.LOGGER.info(memory)
        return memory
124 |
--------------------------------------------------------------------------------
/roadmap/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | import random
2 | from collections import Counter
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import Dataset
7 | import torchvision.transforms as transforms
8 | from PIL import Image
9 | from PIL import ImageFilter
10 |
11 |
class BaseDataset(Dataset):
    """Base class for the image-retrieval datasets.

    Subclasses are expected to populate ``self.paths`` (image file paths),
    ``self.labels`` (one class id per path), ``self.transform`` and
    ``self.mode``; ``self.super_labels`` (coarse class ids) is optional.

    Args:
        multi_crop: if True, ``__getitem__`` returns several augmented crops
            per image (SwAV-style); otherwise a single transformed image.
        size_crops: crop resolution for each crop group.
        nmb_crops: number of crops generated per group.
        min_scale_crops: lower RandomResizedCrop scale bound per group.
        max_scale_crops: upper RandomResizedCrop scale bound per group.
        size_dataset: kept for subclasses; unused in this base class.
        return_label: multi-crop label mode — 'none', 'real' (dataset
            label) or 'hash' (path-derived pseudo label).
    """

    def __init__(
        self,
        multi_crop=False,
        # Immutable tuple defaults: the original list defaults were shared
        # across all calls and could be silently mutated (classic
        # mutable-default-argument pitfall).
        size_crops=(224, 96),
        nmb_crops=(2, 6),
        min_scale_crops=(0.14, 0.05),
        max_scale_crops=(1., 0.14),
        size_dataset=-1,
        return_label='none',
    ):
        super().__init__()

        if not multi_crop:
            self.get_fn = self.simple_get
        else:
            # adapted from
            # https://github.com/facebookresearch/swav/blob/master/src/multicropdataset.py
            self.get_fn = self.multiple_crop_get

        self.return_label = return_label
        assert self.return_label in ["none", "real", "hash"]

        color_transform = [get_color_distortion(), PILRandomGaussianBlur()]
        mean = [0.485, 0.456, 0.406]
        # NOTE(review): the canonical ImageNet std is [0.229, 0.224, 0.225];
        # 0.228 is kept unchanged for compatibility with already-trained models.
        std = [0.228, 0.224, 0.225]
        trans = []
        for i in range(len(size_crops)):
            randomresizedcrop = transforms.RandomResizedCrop(
                size_crops[i],
                scale=(min_scale_crops[i], max_scale_crops[i]),
            )
            # Repeat the same pipeline nmb_crops[i] times for this crop size.
            trans.extend([transforms.Compose([
                randomresizedcrop,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Compose(color_transform),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std)])
            ] * nmb_crops[i])
        self.trans = trans

    def __len__(self,):
        return len(self.paths)

    @property
    def my_at_R(self,):
        # Size of the largest class; computed once and cached on the instance.
        if not hasattr(self, '_at_R'):
            self._at_R = max(Counter(self.labels).values())
        return self._at_R

    def get_instance_dict(self,):
        """Build ``self.instance_dict``: class id -> list of sample indexes."""
        self.instance_dict = {cl: [] for cl in set(self.labels)}
        for idx, cl in enumerate(self.labels):
            self.instance_dict[cl].append(idx)

    def get_super_dict(self,):
        """Build ``self.super_dict``: super class -> class -> sample indexes."""
        if hasattr(self, 'super_labels') and self.super_labels is not None:
            self.super_dict = {ct: {} for ct in set(self.super_labels)}
            for idx, cl, ct in zip(range(len(self.labels)), self.labels, self.super_labels):
                try:
                    self.super_dict[ct][cl].append(idx)
                except KeyError:
                    self.super_dict[ct][cl] = [idx]

    def simple_get(self, idx):
        """Return one transformed image with its label(s) and path."""
        pth = self.paths[idx]
        img = Image.open(pth).convert('RGB')
        if self.transform:
            img = self.transform(img)

        label = self.labels[idx]
        label = torch.tensor([label])
        out = {"image": img, "label": label, "path": pth}

        if hasattr(self, 'super_labels') and self.super_labels is not None:
            super_label = self.super_labels[idx]
            super_label = torch.tensor([super_label])
            out['super_label'] = super_label

        return out

    def multiple_crop_get(self, idx):
        """Return every multi-crop view of the image at ``idx``."""
        pth = self.paths[idx]
        image = Image.open(pth).convert('RGB')
        multi_crops = list(map(lambda trans: trans(image), self.trans))

        if self.return_label == 'real':
            label = self.labels[idx]
            labels = [label] * len(multi_crops)
            return {"image": multi_crops, "label": labels, "path": pth}

        if self.return_label == 'hash':
            # Pseudo label derived from the file path (self-supervised setting).
            label = abs(hash(pth))
            labels = [label] * len(multi_crops)
            return {"image": multi_crops, "label": labels, "path": pth}

        return {"image": multi_crops, "path": pth}

    def __getitem__(self, idx):
        return self.get_fn(idx)

    def __repr__(self,):
        return f"{self.__class__.__name__}(mode={self.mode}, len={len(self)})"
116 |
117 |
class PILRandomGaussianBlur(object):
    """Randomly apply a Gaussian blur to a PIL image.

    With probability ``p`` the image is blurred with a radius drawn
    uniformly from ``[radius_min, radius_max]``; otherwise it is returned
    untouched.
    This transform was used in SimCLR - https://arxiv.org/abs/2002.05709
    """

    def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # Draw once from numpy's RNG (same call order as always) and skip
        # the blur for this sample when the draw exceeds the probability.
        if np.random.rand() > self.prob:
            return img

        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))
140 |
141 |
def get_color_distortion(s=1.0):
    """Return the SimCLR color-distortion pipeline with strength ``s``."""
    jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    random_jitter = transforms.RandomApply([jitter], p=0.8)
    random_grayscale = transforms.RandomGrayscale(p=0.2)
    return transforms.Compose([random_jitter, random_grayscale])
149 |
--------------------------------------------------------------------------------
/roadmap/engine/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from pytorch_metric_learning import testers
5 | import pytorch_metric_learning.utils.common_functions as c_f
6 | from tqdm import tqdm
7 |
8 | import roadmap.utils as lib
9 | from .accuracy_calculator import get_accuracy_calculator
10 |
11 |
class GlobalEmbeddingSpaceTester(testers.GlobalEmbeddingSpaceTester):
    """PML tester extended with multi-level label handling, pre-allocated
    embedding buffers and an env-var toggle for the progress bar."""

    def label_levels_to_evaluate(self, query_labels):
        """Return which columns of the (N, L) label matrix to evaluate.

        ``self.label_hierarchy_level`` may be "all", a single int, or a
        list/tuple of ints.
        NOTE(review): if none of the branches match, this implicitly
        returns None — confirm callers never pass another type.
        """
        num_levels_available = query_labels.shape[1]
        if self.label_hierarchy_level == "all":
            return range(num_levels_available)
        elif isinstance(self.label_hierarchy_level, int):
            assert self.label_hierarchy_level < num_levels_available
            return [self.label_hierarchy_level]
        elif c_f.is_list_or_tuple(self.label_hierarchy_level):
            # assert max(self.label_hierarchy_level) < num_levels_available
            return self.label_hierarchy_level

    def compute_all_embeddings(self, dataloader, trunk_model, embedder_model):
        """Embed the whole dataloader; return (embeddings, labels) tensors.

        Buffers are allocated once from the first batch's shapes and filled
        slice by slice through a running ``[s, e)`` window, avoiding a
        list-append + concatenate.
        """
        s, e = 0, 0
        with torch.no_grad():
            lib.LOGGER.info("Computing embeddings")
            # added the option of disabling TQDM
            # NOTE(review): os.getenv returns a string, so TQDM_DISABLE="0"
            # still disables the bar (any non-empty value is truthy) — confirm.
            for i, data in enumerate(tqdm(dataloader, disable=os.getenv('TQDM_DISABLE'))):
                img, label = self.data_and_label_getter(data)
                label = c_f.process_label(label, "all", self.label_mapper)
                q = self.get_embeddings_for_eval(trunk_model, embedder_model, img)
                # Labels are kept 2D so multi-level hierarchies share one buffer.
                if label.dim() == 1:
                    label = label.unsqueeze(1)
                if i == 0:
                    # Allocate full-dataset buffers using the first batch's widths.
                    labels = torch.zeros(
                        len(dataloader.dataset),
                        label.size(1),
                        device=self.data_device,
                        dtype=label.dtype,
                    )
                    all_q = torch.zeros(
                        len(dataloader.dataset),
                        q.size(1),
                        device=self.data_device,
                        dtype=q.dtype,
                    )
                e = s + q.size(0)
                all_q[s:e] = q
                labels[s:e] = label
                s = e
        return all_q, labels
54 |
55 |
def get_tester(
    normalize_embeddings=False,
    batch_size=64,
    num_workers=16,
    pca=None,
    exclude_ranks=None,
    k=2047,
    **kwargs,
):
    """Build a GlobalEmbeddingSpaceTester wired to the custom accuracy calculator.

    Extra keyword arguments are forwarded to ``get_accuracy_calculator``.
    """
    accuracy_calculator = get_accuracy_calculator(
        exclude_ranks=exclude_ranks,
        k=k,
        **kwargs,
    )

    def _data_and_label_getter(batch):
        # Batches are dicts produced by the datasets in this project.
        return batch["image"], batch["label"]

    return GlobalEmbeddingSpaceTester(
        normalize_embeddings=normalize_embeddings,
        data_and_label_getter=_data_and_label_getter,
        batch_size=batch_size,
        dataloader_num_workers=num_workers,
        accuracy_calculator=accuracy_calculator,
        data_device=None,
        pca=pca,
    )
80 |
81 |
@lib.get_set_random_state
def evaluate(
    net,
    train_dataset=None,
    val_dataset=None,
    test_dataset=None,
    epoch=None,
    tester=None,
    custom_eval=None,
    **kwargs
):
    """Evaluate ``net`` on any combination of train/val/test splits.

    ``test_dataset`` may be a single dataset, a dict containing a
    ``gallery`` or ``distractor`` reference split, or a list of
    query/gallery dict pairs (landmark-style). Extra kwargs are forwarded
    to ``get_tester`` when no ``tester`` is supplied. Returns the metrics
    dict produced by ``tester.test``.
    """
    # Largest class size seen across splits (used by mAP@R-style metrics).
    at_R = 0

    dataset_dict = {}
    splits_to_eval = []
    if train_dataset is not None:
        dataset_dict["train"] = train_dataset
        splits_to_eval.append(('train', ['train']))
        at_R = max(at_R, train_dataset.my_at_R)

    if val_dataset is not None:
        dataset_dict["val"] = val_dataset
        splits_to_eval.append(('val', ['val']))
        at_R = max(at_R, val_dataset.my_at_R)

    if test_dataset is not None:
        if isinstance(test_dataset, dict):
            if 'gallery' in test_dataset:
                # Query/gallery protocol (e.g. InShop): rank the gallery for each query.
                dataset_dict.update(test_dataset)
                splits_to_eval.append(('test', ['gallery']))
                at_R = max(at_R, test_dataset['test'].my_at_R, test_dataset['gallery'].my_at_R)
            elif 'distractor' in test_dataset:
                # Test set ranked against itself plus distractors (e.g. DyML).
                dataset_dict.update(test_dataset)
                splits_to_eval.append(('test', ['test', 'distractor']))
                at_R = max(at_R, test_dataset['test'].my_at_R, test_dataset['distractor'].my_at_R)
        elif isinstance(test_dataset, list):
            # List of {query_<city>: ..., gallery_<city>: ...} pairs (SfM120k-style).
            # NOTE(review): relies on each dict holding exactly one query and one
            # gallery entry, keyed with those prefixes — confirm against get_dataset.
            for dts in test_dataset:
                dataset_dict.update(dts)
                names = list(dts.keys())
                at_R = max(at_R, list(dts.values())[0].my_at_R, list(dts.values())[1].my_at_R)
                splits_to_eval.append((
                    names[0] if names[0].startswith("query") else names[1],
                    [names[0] if names[0].startswith("gallery") else names[1]]
                ))
        else:
            dataset_dict["test"] = test_dataset
            splits_to_eval.append(('test', ['test']))
            at_R = max(at_R, test_dataset.my_at_R)

    if custom_eval is not None:
        # A fully custom evaluation replaces everything built above.
        dataset_dict = custom_eval["dataset"]
        splits_to_eval = custom_eval["splits"]

    if tester is None:
        # next lines usefull when computing only the mAP@R and small recall values
        # if ('k' not in kwargs) and (at_R != 0):
        #     kwargs["k"] = at_R + 1
        tester = get_tester(**kwargs)

    return tester.test(
        dataset_dict=dataset_dict,
        epoch=f"{epoch}",
        trunk_model=net,
        splits_to_eval=splits_to_eval,
    )
147 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Robust And Decomposable Average Precision for Image Retrieval (NeurIPS 2021)
2 |
3 | This repository contains the source code for our [ROADMAP paper (NeurIPS 2021)](https://arxiv.org/abs/2110.01445).
4 |
5 | 
6 |
7 | ## Use ROADMAP
8 |
9 | ```
10 | python3 -m venv .venv
11 | source .venv/bin/activate
12 | pip install -e .
13 | ```
14 |
15 | [](https://paperswithcode.com/sota/image-retrieval-on-inaturalist?p=robust-and-decomposable-average-precision-for)
16 | [](https://paperswithcode.com/sota/image-retrieval-on-sop?p=robust-and-decomposable-average-precision-for)
17 | [](https://paperswithcode.com/sota/image-retrieval-on-cub-200-2011?p=robust-and-decomposable-average-precision-for)
18 |
19 | ## Datasets
20 |
We use the following datasets for our submission:
22 |
23 | - CUB-200-2011 (download link available on this website : http://www.vision.caltech.edu/visipedia/CUB-200.html)
24 | - Stanford Online Products (you can download it here : https://cvgl.stanford.edu/projects/lifted_struct/)
25 | - INaturalist-2018 (obtained from here https://github.com/visipedia/inat_comp/tree/master/2018#Data)
26 |
27 |
28 | ## Run the code
29 |
30 |
31 | SOP
32 |
The following command reproduces our results from Table 4.
34 |
35 | ```
36 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \
37 | 'experience.experiment_name=sop_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota' \
38 | experience.seed=333 \
39 | experience.max_iter=100 \
'experience.log_dir=${env:HOME}/experiments/ROADMAP' \
41 | optimizer=sop \
42 | model=resnet \
43 | transform=sop_big \
44 | dataset=sop \
45 | dataset.sampler.kwargs.batch_size=128 \
46 | dataset.sampler.kwargs.batches_per_super_pair=10 \
47 | loss=roadmap
48 | ```
49 |
50 | With the transformer backbone :
51 |
52 | ```
53 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \
54 | 'experience.experiment_name=sop_ROADMAP_${dataset.sampler.kwargs.batch_size}_DeiT' \
55 | experience.seed=333 \
56 | experience.max_iter=75 \
57 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \
58 | optimizer=sop_deit \
59 | model=deit \
60 | transform=sop \
61 | dataset=sop \
62 | dataset.sampler.kwargs.batch_size=128 \
63 | dataset.sampler.kwargs.batches_per_super_pair=10 \
64 | loss=roadmap
65 | ```
66 |
67 |
68 |
69 |
70 | INaturalist
71 |
72 | For ROADMAP sota results:
73 |
74 | ```
75 | CUDA_VISIBLE_DEVICES='0,1,2' python roadmap/single_experiment_runner.py \
76 | 'experience.experiment_name=inat_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota' \
77 | experience.seed=333 \
78 | experience.max_iter=90 \
79 | 'experience.log_dir=experiments/ROADMAP' \
80 | optimizer=inaturalist \
81 | model=resnet \
82 | transform=inaturalist \
83 | dataset=inaturalist \
84 | dataset.sampler.kwargs.batch_size=384 \
85 | loss=roadmap_inat
86 | ```
87 |
88 |
89 |
90 |
91 | CUB-200-2011
92 |
93 | For ROADMAP sota results:
94 |
95 | ```
96 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \
97 | 'experience.experiment_name=cub_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota' \
98 | experience.seed=333 \
99 | experience.max_iter=200 \
100 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \
101 | optimizer=cub \
102 | model=resnet_max_ln \
103 | transform=cub_big \
104 | dataset=cub \
105 | dataset.sampler.kwargs.batch_size=128 \
106 | loss=roadmap
107 | ```
108 |
109 | ```
110 | CUDA_VISIBLE_DEVICES=0 python roadmap/single_experiment_runner.py \
111 | 'experience.experiment_name=cub_ROADMAP_${dataset.sampler.kwargs.batch_size}_sota_DeiT' \
112 | experience.seed=333 \
113 | experience.max_iter=150 \
114 | 'experience.log_dir=${env:HOME}/experiments/ROADMAP' \
115 | optimizer=cub_deit \
116 | model=deit \
117 | transform=cub \
118 | dataset=cub \
119 | dataset.sampler.kwargs.batch_size=128 \
120 | loss=roadmap
121 | ```
122 |
123 |
124 |
125 |
The results may not be exactly the same, as the code has changed a bit since the paper (for instance, the random seeds are not the same).
127 |
128 |
129 | ## Contacts
130 |
131 | If you have any questions don't hesitate to create an issue on this repository. Or send me an email at elias.ramzi@lecnam.net.
132 |
133 | Don't hesitate to cite our work:
134 | ```
135 | @inproceedings{
136 | ramzi2021robust,
137 | title={Robust and Decomposable Average Precision for Image Retrieval},
138 | author={Elias Ramzi and Nicolas THOME and Cl{\'e}ment Rambour and Nicolas Audebert and Xavier Bitot},
139 | booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
140 | year={2021},
141 | url={https://openreview.net/forum?id=VjQw3v3FpJx}
142 | }
143 | ```
144 |
145 |
146 | ## Resources
147 | - Pytorch Metric Learning (PML): https://github.com/KevinMusgrave/pytorch-metric-learning
148 | - SmoothAP: https://github.com/Andrew-Brown1/Smooth_AP
149 | - Blackbox: https://github.com/martius-lab/blackbox-backprop
150 | - FastAP: https://github.com/kunhe/FastAP-metric-learning
151 | - SoftBinAP: https://github.com/naver/deep-image-retrieval
152 | - timm: https://github.com/rwightman/pytorch-image-models
153 | - PyTorch: https://github.com/pytorch/pytorch
154 | - Hydra: https://github.com/facebookresearch/hydra
155 | - Faiss: https://github.com/facebookresearch/faiss
156 | - Ray: https://github.com/ray-project/ray
157 |
--------------------------------------------------------------------------------
/roadmap/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import join
3 | import logging
4 | import random
5 | import numpy as np
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | import roadmap.utils as lib
12 | import roadmap.engine as eng
13 | from roadmap.getter import Getter
14 |
15 |
def run(config, base_config=None, checkpoint_dir=None, splits=None):
    """Create every object required to launch a training and run it.

    Args:
        config: hydra config, or a hyperparameter dict that overrides
            ``base_config`` during a hyperparameter search.
        base_config: base config for search runs; when given, logs go to
            the ray-tune trial directory instead of a local ``log_dir``.
        checkpoint_dir: optional checkpoint path to resume from.
        splits: NOTE(review): unused in this function body (it is rebound
            in the cross-validation branch) — confirm it is only kept for
            interface compatibility.

    Returns:
        The metrics dict returned by ``eng.train``.
    """
    # """""""""""""""""" Handle Config """"""""""""""""""""""""""
    if base_config is not None:
        # Hyperparameter-search mode: merge the overrides into the base config.
        log_dir = None
        config = lib.override_config(
            hyperparameters=config,
            config=base_config,
        )

    else:
        log_dir = lib.expand_path(config.experience.log_dir)
        log_dir = join(log_dir, config.experience.experiment_name)
        os.makedirs(join(log_dir, 'logs'), exist_ok=True)
        os.makedirs(join(log_dir, 'weights'), exist_ok=True)

    # """""""""""""""""" Handle Logging """"""""""""""""""""""""""
    # NOTE(review): '%I' is a 12-hour clock without '%p' (AM/PM) — confirm intended.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%m/%d/%Y %I:%M:%S',
        level=logging.INFO,
    )

    if checkpoint_dir is None:
        state = None
        restore_epoch = 0
    else:
        lib.LOGGER.info(f"Resuming from state : {checkpoint_dir}")
        state = torch.load(checkpoint_dir, map_location='cpu')
        restore_epoch = state['epoch']

    if log_dir is None:
        # Ray-tune run: write tensorboard logs inside the trial directory.
        from ray import tune
        writer = SummaryWriter(join(tune.get_trial_dir(), "logs"), purge_step=restore_epoch)
    else:
        writer = SummaryWriter(join(log_dir, "logs"), purge_step=restore_epoch)

    # Seed every RNG and make cuDNN deterministic for reproducibility.
    lib.LOGGER.info(f"Training with seed {config.experience.seed}")
    random.seed(config.experience.seed)
    np.random.seed(config.experience.seed)
    torch.manual_seed(config.experience.seed)
    torch.cuda.manual_seed_all(config.experience.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    getter = Getter()

    # """""""""""""""""" Create Data """"""""""""""""""""""""""
    train_transform = getter.get_transform(config.transform.train)
    test_transform = getter.get_transform(config.transform.test)
    if config.experience.split is not None:
        # Cross-validation mode: carve train/val subsets out of the full dataset.
        assert isinstance(config.experience.split, int)
        dts = getter.get_dataset(None, 'all', config.dataset)
        splits = eng.get_splits(
            dts.labels, dts.super_labels,
            config.experience.kfold,
            random_state=config.experience.split_random_state,
            with_super_labels=config.experience.with_super_labels)
        train_dts = eng.make_subset(dts, splits[config.experience.split]['train'], train_transform, 'train')
        test_dts = eng.make_subset(dts, splits[config.experience.split]['val'], test_transform, 'test')
        val_dts = None
        lib.LOGGER.info(train_dts)
        lib.LOGGER.info(test_dts)
    else:
        train_dts = getter.get_dataset(train_transform, 'train', config.dataset)
        test_dts = getter.get_dataset(test_transform, 'test', config.dataset)
        val_dts = None

    sampler = getter.get_sampler(train_dts, config.dataset.sampler)

    # """""""""""""""""" Create Network """"""""""""""""""""""""""
    net = getter.get_model(config.model)

    scaler = None
    if config.model.kwargs.with_autocast:
        # Mixed-precision training: restore the GradScaler state when resuming.
        scaler = torch.cuda.amp.GradScaler()
        if checkpoint_dir:
            scaler.load_state_dict(state['scaler_state'])

    if checkpoint_dir:
        net.load_state_dict(state['net_state'])
        net.cuda()

    # """""""""""""""""" Create Optimizer & Scheduler """"""""""""""""""""""""""
    optimizer, scheduler = getter.get_optimizer(net, config.optimizer)

    if checkpoint_dir:
        for key, opt in optimizer.items():
            opt.load_state_dict(state['optimizer_state'][key])

    if config.experience.force_lr is not None:
        # Optionally override the (possibly restored) learning rates.
        _ = [lib.set_lr(opt, config.experience.force_lr) for opt in optimizer.values()]
    lib.LOGGER.info(optimizer)

    if checkpoint_dir:
        for key, schs in scheduler.items():
            for sch, sch_state in zip(schs, state[f'scheduler_{key}_state']):
                sch.load_state_dict(sch_state)

    # """""""""""""""""" Create Criterion """"""""""""""""""""""""""
    criterion = getter.get_loss(config.loss)

    # """""""""""""""""" Create Memory """"""""""""""""""""""""""
    memory = None
    if config.memory.name is not None:
        lib.LOGGER.info("Using cross batch memory")
        memory = getter.get_memory(config.memory)
        memory.cuda()

    # """""""""""""""""" Handle Cuda """"""""""""""""""""""""""
    if torch.cuda.device_count() > 1:
        lib.LOGGER.info("Model is parallelized")
        net = nn.DataParallel(net)

    net.cuda()
    _ = [crit.cuda() for crit, _ in criterion]

    # """""""""""""""""" Handle RANDOM_STATE """"""""""""""""""""""""""
    if state is not None:
        # set random NumPy and Torch random states
        lib.set_random_state(state)

    return eng.train(
        config=config,
        log_dir=log_dir,
        net=net,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        memory=memory,
        train_dts=train_dts,
        val_dts=val_dts,
        test_dts=test_dts,
        sampler=sampler,
        writer=writer,
        restore_epoch=restore_epoch,
    )
156 |
157 |
if __name__ == '__main__':
    # NOTE(review): run() requires a `config` argument, so invoking this
    # module directly raises a TypeError — confirm the intended entry point
    # (e.g. single_experiment_runner.py) should be used instead.
    run()
160 |
--------------------------------------------------------------------------------
/roadmap/engine/train.py:
--------------------------------------------------------------------------------
1 | import random
2 | from time import time
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 |
8 | import roadmap.utils as lib
9 | from .base_update import base_update
10 | from .evaluate import evaluate
11 | from .landmark_evaluation import landmark_evaluation
12 | from . import checkpoint
13 |
14 |
def train(
    config,
    log_dir,
    net,
    criterion,
    optimizer,
    scheduler,
    scaler,
    memory,
    train_dts,
    val_dts,
    test_dts,
    sampler,
    writer,
    restore_epoch,
):
    """Main training loop: train, periodically evaluate, log and checkpoint.

    Iterates from ``restore_epoch + 1`` to ``config.experience.max_iter``
    (inclusive) and returns the metrics dict of the last evaluation
    performed (None if the last epoch was not evaluated).
    """
    # """""""""""""""""" Iter over epochs """"""""""""""""""""""""""
    lib.LOGGER.info(f"Training of model {config.experience.experiment_name}")
    best_score = 0.
    best_model = None

    metrics = None
    for e in range(1 + restore_epoch, config.experience.max_iter + 1):

        lib.LOGGER.info(f"Training : @epoch #{e} for model {config.experience.experiment_name}")
        start_time = time()

        # """""""""""""""""" Training Loop """"""""""""""""""""""""""
        # Reshuffle the batch sampler each epoch, then rebuild the loader on it.
        sampler.reshuffle()
        loader = DataLoader(
            train_dts,
            batch_sampler=sampler,
            num_workers=config.experience.num_workers,
            pin_memory=config.experience.pin_memory,
        )
        logs = base_update(
            config=config,
            net=net,
            loader=loader,
            criterion=criterion,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            epoch=e,
            memory=memory,
        )

        # Epoch-level schedulers step once per epoch.
        for sch in scheduler["on_epoch"]:
            sch.step()

        end_train_time = time()

        # Decide which splits to evaluate this epoch (freq == -1 disables a split;
        # the last epoch is always evaluated for enabled splits).
        dataset_dict = {}
        if (config.experience.train_eval_freq > -1) and ((e % config.experience.train_eval_freq == 0) or (e == config.experience.max_iter)):
            dataset_dict["train_dataset"] = train_dts

        if (config.experience.val_eval_freq > -1) and ((e % config.experience.val_eval_freq == 0) or (e == config.experience.max_iter)):
            dataset_dict["val_dataset"] = val_dts

        if (config.experience.test_eval_freq > -1) and ((e % config.experience.test_eval_freq == 0) or (e == config.experience.max_iter)):
            dataset_dict["test_dataset"] = test_dts

        metrics = None
        if dataset_dict:
            # Snapshot every RNG state so evaluation cannot perturb the
            # training randomness (restored right after the evaluation).
            RANDOM_STATE = random.getstate()
            NP_STATE = np.random.get_state()
            TORCH_STATE = torch.random.get_rng_state()
            TORCH_CUDA_STATE = torch.cuda.get_rng_state_all()

            lib.LOGGER.info(f"Evaluation : @epoch #{e} for model {config.experience.experiment_name}")
            torch.cuda.empty_cache()
            if config.experience.landmarks:
                metrics = landmark_evaluation(
                    net=net,
                    datasets=test_dts,
                    batch_size=config.experience.eval_bs,
                    num_workers=config.experience.num_workers,
                )
            else:
                metrics = evaluate(
                    net,
                    epoch=e,
                    batch_size=config.experience.eval_bs,
                    num_workers=config.experience.num_workers,
                    with_AP=config.experience.with_AP,
                    **dataset_dict,
                )
            torch.cuda.empty_cache()

            # Restore the RNG states captured before evaluation.
            random.setstate(RANDOM_STATE)
            np.random.set_state(NP_STATE)
            torch.random.set_rng_state(TORCH_STATE)
            torch.cuda.set_rng_state_all(TORCH_CUDA_STATE)

        # """""""""""""""""" Evaluate Model """"""""""""""""""""""""""
        score = None
        if metrics is not None:
            # Track the best epoch according to the principal metric.
            score = metrics[config.experience.eval_split][config.experience.principal_metric]
            if score > best_score:
                best_model = f"epoch_{e}"
                best_score = score

            if log_dir is None:
                # Ray-tune run: report metrics to the scheduler instead of disk.
                from ray import tune
                tune.report(**metrics[config.experience.eval_split])

            # Metric-driven schedulers (e.g. ReduceLROnPlateau) step on a value.
            for sch, key in scheduler["on_val"]:
                sch.step(metrics[config.experience.eval_split][key])

        # """""""""""""""""" Logging Step """"""""""""""""""""""""""
        for grp, opt in optimizer.items():
            writer.add_scalar(f"LR/{grp}", list(lib.get_lr(opt).values())[0], e)

        for k, v in logs.items():
            lib.LOGGER.info(f"{k} : {v:.4f}")
            writer.add_scalar(f"Train/{k}", v, e)

        if metrics is not None:
            for split, mtrc in metrics.items():
                for k, v in mtrc.items():
                    if k == 'epoch':
                        continue
                    lib.LOGGER.info(f"{split} --> {k} : {np.around(v*100, decimals=2)}")
                    writer.add_scalar(f"{split.title()}/Evaluation/{k}", v, e)
                print()

        end_time = time()

        elapsed_time = lib.format_time(end_time - start_time)
        elapsed_time_train = lib.format_time(end_train_time - start_time)
        elapsed_time_eval = lib.format_time(end_time - end_train_time)

        lib.LOGGER.info(f"Epoch took : {elapsed_time}")
        lib.LOGGER.info(f"Training loop took : {elapsed_time_train}")
        if metrics is not None:
            lib.LOGGER.info(f"Evaluation step took : {elapsed_time_eval}")

        print()
        print()

        # """""""""""""""""" Checkpointing """"""""""""""""""""""""""
        checkpoint(
            log_dir=log_dir,
            save_checkpoint=(e % config.experience.save_model == 0),
            net=net,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            epoch=e,
            seed=config.experience.seed,
            args=config,
            score=score,
            best_model=best_model,
            best_score=best_score,
        )

    return metrics
172 |
--------------------------------------------------------------------------------
/roadmap/engine/cross_validation_splits.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from sklearn.model_selection import StratifiedKFold
5 |
6 | import roadmap.utils as lib
7 |
8 |
@lib.get_set_random_state
def get_class_disjoint_splits(labels, kfold, random_state=None):
    """Split sample indexes into ``kfold`` class-disjoint train/val folds.

    Every class lands in exactly one fold; fold ``i`` is the validation
    set of split ``i`` and the remaining folds form its train set.

    Args:
        labels: per-sample class ids.
        kfold: number of folds / splits.
        random_state: optional seed for the class shuffling.

    Returns:
        A list of ``kfold`` dicts with ``'train'`` and ``'val'`` index lists.
    """
    if random_state is not None:
        random.seed(random_state)

    unique_labels = list(set(labels))
    n = len(unique_labels)
    random.shuffle(unique_labels)

    # Partition the shuffled classes into kfold groups.
    classes = []
    for i in range(kfold):
        # NOTE(review): a balanced partition usually gives the extra class to
        # folds with i < n % kfold; the `i > n % kfold` condition makes early
        # folds slightly larger and can leave the last fold smaller. Coverage
        # is still complete (the del clamps), but confirm the imbalance is
        # intended before changing it — fixing it alters split composition.
        num_to_sample = n // kfold + (0 if i > n % kfold else 1)
        classes.append(set(unique_labels[:num_to_sample]))
        del unique_labels[:num_to_sample]

    # Map each sample index to the fold that owns its class.
    indexes = [[] for _ in range(kfold)]
    for idx, cl in enumerate(labels):
        for i in range(kfold):
            if cl in classes[i]:
                indexes[i].append(idx)
                break

    # Fold i is validation for split i; all other folds are train.
    splits = [{'train': [], 'val': []} for _ in range(kfold)]
    for i in range(kfold):
        tmp = indexes.copy()
        splits[i]['val'] = tmp[i]
        del tmp[i]
        splits[i]['train'] = [idx for sublist in tmp for idx in sublist]

    return splits
39 |
40 |
@lib.get_set_random_state
def get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=None):
    """Class-disjoint k-fold splits stratified by super label.

    Fine classes are partitioned within each super label so every fold
    receives a comparable share of each super category, while folds stay
    class-disjoint.

    Args:
        labels: per-sample fine class ids.
        super_labels: per-sample coarse class ids (same length as labels).
        kfold: number of folds / splits.
        random_state: optional seed for the per-super-label shuffling.

    Returns:
        A list of ``kfold`` dicts with ``'train'`` and ``'val'`` index lists.
    """
    if random_state is not None:
        random.seed(random_state)

    # Group the fine classes under their super label (sorted for determinism).
    unique_super_labels = sorted(set(super_labels))
    super_dict = {slb: set() for slb in unique_super_labels}
    for slb, lb in zip(super_labels, labels):
        super_dict[slb].add(lb)

    super_dict = {slb: sorted(lb) for slb, lb in super_dict.items()}
    for slb in unique_super_labels:
        # random.shuffle works in place and returns None.
        _ = random.shuffle(super_dict[slb])
    # Distribute each super label's classes across the kfold groups.
    classes = [[] for _ in range(kfold)]
    for slb in unique_super_labels:
        unique_labels = super_dict[slb].copy()
        n = len(unique_labels)
        for i in range(kfold):
            # NOTE(review): same `i > n % kfold` condition as
            # get_class_disjoint_splits — early folds get the extra classes;
            # confirm the slight imbalance is intended.
            num_to_sample = n // kfold + (0 if i > n % kfold else 1)
            classes[i].extend(unique_labels[:num_to_sample])
            del unique_labels[:num_to_sample]

    # Map each sample index to the fold owning its class.
    indexes = [[] for _ in range(kfold)]
    for idx, cl in enumerate(labels):
        for i in range(kfold):
            if cl in classes[i]:
                indexes[i].append(idx)
                break

    # Fold i is validation for split i; the remaining folds form the train set.
    splits = [{'train': [], 'val': []} for _ in range(kfold)]
    for i in range(kfold):
        tmp = indexes.copy()
        splits[i]['val'] = tmp[i]
        del tmp[i]
        splits[i]['train'] = [idx for sublist in tmp for idx in sublist]

    return splits
78 |
79 |
@lib.get_set_random_state
def get_closed_set_splits(labels, kfold, random_state=None):
    """Stratified k-fold splits where train and val share the same classes."""
    split_generator = StratifiedKFold(n_splits=kfold, shuffle=True, random_state=random_state)

    # StratifiedKFold only uses the labels; features are a dummy placeholder.
    dummy_features = np.zeros(len(labels))
    splits = []
    for train_index, test_index in split_generator.split(dummy_features, labels):
        splits.append({'train': train_index, 'val': test_index})

    return splits
90 |
91 |
def get_splits(labels, super_labels, kfold, random_state=None, with_super_labels=False, open_set=True):
    """Dispatch to the split strategy matching the evaluation protocol.

    Open-set protocols keep train and val class-disjoint (optionally using
    super labels); closed-set uses stratified folds over the same classes.
    """
    if not open_set:
        return get_closed_set_splits(labels, kfold, random_state)
    if super_labels is None:
        return get_class_disjoint_splits(labels, kfold, random_state)
    if with_super_labels:
        # Disjointness is enforced at the super-label level instead.
        return get_class_disjoint_splits(super_labels, kfold, random_state)
    return get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state)
102 |
103 |
if __name__ == '__main__':
    # Smoke tests for the split helpers: verify full coverage of the dataset
    # and class-disjointness between train and val in every split.
    random.seed(10)
    num_superlabels = 12
    dataset_size = 15023
    num_labels_per_super_labels = 17
    num_labels = num_superlabels * num_labels_per_super_labels
    kfold = 4
    labels = [random.randint(0, num_labels - 1) for _ in range(dataset_size)]
    splits = get_class_disjoint_splits(labels, kfold)

    for spl in splits:
        label_train = set([labels[x] for x in spl['train']])
        label_val = set([labels[x] for x in spl['val']])
        assert (len(set(spl['train'])) + len(set(spl['val']))) == dataset_size
        assert not label_train.intersection(label_val)

    num_superlabels = 12
    dataset_size = 15023
    num_labels_per_super_labels = 17
    num_labels = num_superlabels * num_labels_per_super_labels
    labels = [random.randint(0, num_labels - 1) for _ in range(dataset_size)]
    # Assign each contiguous range of fine labels to one super label.
    super_labels = [0] * len(labels)
    for slb in range(num_superlabels):
        mask = []
        for idx, lb in enumerate(labels):
            if (lb >= slb * num_labels_per_super_labels) & (lb < (slb + 1) * num_labels_per_super_labels):
                mask.append(idx)

        for idx in mask:
            super_labels[idx] = slb

    assert len(set(zip(super_labels, labels))) == num_labels

    # Same random_state must reproduce identical splits; different ones must differ.
    h_splits_1 = get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=1)
    h_splits_1_p = get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=1)
    h_splits_2 = get_hierarchical_class_disjoint_splits(labels, super_labels, kfold, random_state=2)
    # import ipdb; ipdb.set_trace()
    for spl in h_splits_1:
        label_train = set([labels[x] for x in spl['train']])
        label_val = set([labels[x] for x in spl['val']])
        assert (len(set(spl['train'])) + len(set(spl['val']))) == dataset_size
        assert not label_train.intersection(label_val)

    for spl in h_splits_2:
        label_train = set([labels[x] for x in spl['train']])
        label_val = set([labels[x] for x in spl['val']])
        assert (len(set(spl['train'])) + len(set(spl['val']))) == dataset_size
        assert not label_train.intersection(label_val)

    for spl_1, spl_1_p, spl_2 in zip(h_splits_1, h_splits_1_p, h_splits_2):
        assert set(spl_1['train']) == set(spl_1_p['train'])
        assert set(spl_1['val']) == set(spl_1_p['val'])

        assert set(spl_1['train']) != set(spl_2['train'])
        assert set(spl_1['val']) != set(spl_2['val'])
159 |
--------------------------------------------------------------------------------
/roadmap/engine/accuracy_calculator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pytorch_metric_learning.utils.common_functions as c_f
3 | from pytorch_metric_learning.utils.accuracy_calculator import (
4 | AccuracyCalculator,
5 | get_label_match_counts,
6 | get_lone_query_labels,
7 | )
8 |
9 | import roadmap.utils as lib
10 | from .get_knn import get_knn
11 |
12 |
# element-wise label equality; usable as a label_comparison_fn
EQUALITY = torch.eq
14 |
15 |
class CustomCalculator(AccuracyCalculator):
    """AccuracyCalculator with recall@k metrics and a custom k-NN search.

    Adds ``calculate_recall_at_<k>`` metrics (recall@k = fraction of queries
    with at least one matching label among their top-k neighbours) and
    replaces the library's nearest-neighbour search with :func:`get_knn`.

    Args:
        with_faiss (bool): forwarded to :func:`get_knn` to select the faiss
            backend for the k-NN search.
    """

    def __init__(
        self,
        *args,
        with_faiss=True,
        **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.with_faiss = with_faiss

    def recall_at_k(self, knn_labels, query_labels, k):
        """Fraction of queries whose top-k neighbours contain a matching label."""
        recall = self.label_comparison_fn(query_labels, knn_labels[:, :k])
        return recall.any(1).float().mean().item()

    # One explicit method per evaluated k: the base class collects metrics by
    # the `calculate_<name>` method names (this is also what the
    # 'recall_at_{r}' exclude strings refer to), so they must be spelled out.
    def calculate_recall_at_1(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 1)

    def calculate_recall_at_2(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 2)

    def calculate_recall_at_4(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 4)

    def calculate_recall_at_8(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 8)

    def calculate_recall_at_10(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 10)

    def calculate_recall_at_16(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 16)

    def calculate_recall_at_20(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 20)

    def calculate_recall_at_30(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 30)

    def calculate_recall_at_32(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 32)

    def calculate_recall_at_100(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 100)

    def calculate_recall_at_1000(self, knn_labels, query_labels, **kwargs):
        return self.recall_at_k(knn_labels, query_labels[:, None], 1000)

    def requires_knn(self):
        return super().requires_knn() + ["recall_classic"]

    def get_accuracy(
        self,
        query,
        reference,
        query_labels,
        reference_labels,
        embeddings_come_from_same_source,
        include=(),
        exclude=(),
        return_indices=False,
    ):
        """Compute the requested metrics.

        Same contract as ``AccuracyCalculator.get_accuracy`` except that the
        k-NN search is performed by :func:`get_knn` and that, when
        ``return_indices`` is True, the k-NN indices are returned alongside
        the metric dict (``None`` if no k-NN based metric was requested).
        """
        [query, reference, query_labels, reference_labels] = [
            c_f.numpy_to_torch(x)
            for x in [query, reference, query_labels, reference_labels]
        ]

        self.curr_function_dict = self.get_function_dict(include, exclude)

        kwargs = {
            "query": query,
            "reference": reference,
            "query_labels": query_labels,
            "reference_labels": reference_labels,
            "embeddings_come_from_same_source": embeddings_come_from_same_source,
            "label_comparison_fn": self.label_comparison_fn,
        }

        # BUGFIX: previously `knn_indices` was unbound (NameError) when
        # return_indices=True but no k-NN based metric was requested.
        knn_indices = None

        if any(x in self.requires_knn() for x in self.get_curr_metrics()):
            label_counts = get_label_match_counts(
                query_labels, reference_labels, self.label_comparison_fn,
            )

            lone_query_labels, not_lone_query_mask = get_lone_query_labels(
                query_labels,
                label_counts,
                embeddings_come_from_same_source,
                self.label_comparison_fn,
            )

            num_k = self.determine_k(
                label_counts[1], len(reference), embeddings_come_from_same_source
            )

            # USE OUR OWN KNN SEARCH
            knn_indices, knn_distances = get_knn(
                reference, query, num_k, embeddings_come_from_same_source,
                with_faiss=self.with_faiss,
            )
            # free the (potentially large) intermediate GPU buffers
            torch.cuda.empty_cache()

            knn_labels = reference_labels[knn_indices]
            if not any(not_lone_query_mask):
                lib.LOGGER.warning("None of the query labels are in the reference set.")
            kwargs["label_counts"] = label_counts
            kwargs["knn_labels"] = knn_labels
            kwargs["knn_distances"] = knn_distances
            kwargs["lone_query_labels"] = lone_query_labels
            kwargs["not_lone_query_mask"] = not_lone_query_mask

        if any(x in self.requires_clustering() for x in self.get_curr_metrics()):
            kwargs["cluster_labels"] = self.get_cluster_labels(**kwargs)

        if return_indices:
            return knn_indices, self._get_accuracy(self.curr_function_dict, **kwargs)
        return self._get_accuracy(self.curr_function_dict, **kwargs)
177 |
178 |
def get_accuracy_calculator(
    exclude_ranks=None,
    k=2047,
    with_AP=False,
    **kwargs,
):
    """Build the project's :class:`CustomCalculator`.

    Args:
        exclude_ranks: iterable of ints; each ``r`` excludes ``recall_at_{r}``.
        k: number of neighbours retrieved for the k-NN based metrics.
        with_AP: when False, the (slow) mAP metrics are excluded as well.
        **kwargs: forwarded to CustomCalculator; may carry an initial
            'exclude' list of metric names.

    Returns:
        CustomCalculator
    """
    # BUGFIX: copy the caller's list — the previous code extended the list
    # passed as `exclude=` in place, mutating the caller's argument.
    exclude = list(kwargs.pop('exclude', []))
    if with_AP:
        exclude.extend(['NMI', 'AMI'])
    else:
        exclude.extend(['NMI', 'AMI', 'mean_average_precision', 'mean_average_precision_at_r'])

    if exclude_ranks:
        for r in exclude_ranks:
            exclude.append(f'recall_at_{r}')

    return CustomCalculator(
        exclude=exclude,
        k=k,
        **kwargs,
    )
200 |
--------------------------------------------------------------------------------
/roadmap/engine/landmark_evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from:
3 | https://github.com/filipradenovic/cnnimageretrieval-pytorch/blob/master/cirtorch/utils/evaluate.py
4 | """
5 | import os
6 |
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import DataLoader
10 | from tqdm import tqdm
11 |
12 | import roadmap.utils as lib
13 |
14 |
def compute_ap(ranks, nres):
    """Average precision from the zero-based ranks of the positive images.

    Arguments
    ---------
    ranks : zero-based ranks of the positive images (increasing order)
    nres : total number of positive images for this query

    Returns
    -------
    ap : average precision (trapezoidal approximation of the PR curve)
    """
    # each positive adds one recall step; its trapezoid in the PR plot is
    # bounded by the precision just before and just after retrieving it
    recall_step = 1. / nres
    ap = 0

    for num_before, rank in enumerate(ranks):
        # precision just before this positive is retrieved
        if rank == 0:
            precision_before = 1.
        else:
            precision_before = float(num_before) / rank

        # precision just after this positive is retrieved
        precision_after = float(num_before + 1) / (rank + 1)

        ap += (precision_before + precision_after) * recall_step / 2.

    return ap
50 |
51 |
def compute_map(ranks, gnd, kappas=None):
    """Mean average precision over a set of queries.

    Arguments
    ---------
    ranks : array of shape (db_size, nb_queries); ranks[r, i] is the index of
        the database image ranked r-th for query i (0-based).
    gnd : list of dicts, one per query, with keys:
        'ok'   : indices of the positive database images,
        'junk' : indices to ignore, e.g. the query itself (optional).
    kappas : kept for backward compatibility with the removed precision@k
        computation; unused.

    Returns
    -------
    mean_ap : mean average precision; queries without any positive image are
        excluded from the average.

    Notes
    -----
    Raises ZeroDivisionError when no query at all has a positive image.
    """
    # BUGFIX(idiom): default was a mutable `[]`; also renamed the `map`
    # accumulator, which shadowed the builtin.
    if kappas is None:
        kappas = []

    mean_ap = 0.
    nq = len(gnd)  # number of queries
    aps = np.zeros(nq)
    nempty = 0  # queries with no positive image

    for i in np.arange(nq):
        qgnd = np.array(gnd[i]['ok'])

        # no positive images, skip from the average
        if qgnd.shape[0] == 0:
            aps[i] = float('nan')
            nempty += 1
            continue

        # junk entries are optional per query
        try:
            qgndj = np.array(gnd[i]['junk'])
        except Exception:
            qgndj = np.empty(0)

        # sorted positions of positive and junk images (0 based)
        pos = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgnd)]
        junk = np.arange(ranks.shape[0])[np.in1d(ranks[:, i], qgndj)]

        if len(junk):
            # decrease positions of positives based on the number of
            # junk images appearing before them
            k = 0
            ij = 0
            ip = 0
            while (ip < len(pos)):
                while (ij < len(junk) and pos[ip] > junk[ij]):
                    k += 1
                    ij += 1
                pos[ip] = pos[ip] - k
                ip += 1

        # compute ap
        ap = compute_ap(pos, len(qgnd))
        mean_ap = mean_ap + ap
        aps[i] = ap

    mean_ap = mean_ap / (nq - nempty)

    return mean_ap
124 |
125 |
def compute_map_M_and_H(ranks, gnd, kappas=None):
    """Medium and Hard mAP protocols of the revisited landmark benchmarks.

    Medium: positives = easy + hard, junk = junk.
    Hard:   positives = hard,        junk = junk + easy.

    Arguments
    ---------
    ranks : as in :func:`compute_map`.
    gnd : list of per-query dicts with 'easy', 'hard' and 'junk' index arrays.
    kappas : kept for backward compatibility; forwarded to compute_map
        (unused there).

    Returns
    -------
    dict with keys "mapM" and "mapH".
    """
    if kappas is None:
        kappas = []

    def regroup(ok_keys, junk_keys):
        # build a compute_map ground truth whose ok/junk sets are the
        # concatenation of the requested difficulty buckets
        return [
            {
                'ok': np.concatenate([g[key] for key in ok_keys]),
                'junk': np.concatenate([g[key] for key in junk_keys]),
            }
            for g in gnd
        ]

    mapM = compute_map(ranks, regroup(['easy', 'hard'], ['junk']), kappas)
    mapH = compute_map(ranks, regroup(['hard'], ['junk', 'easy']), kappas)

    return {"mapM": mapM, "mapH": mapH}
153 |
154 |
def evaluate_a_city(net, query, gallery, batch_size, num_workers):
    """Embed one city's query and gallery sets with `net`, rank the gallery by
    dot-product similarity and return the Medium/Hard mAP.

    NOTE(review): `query` (the dataset object) is passed as the ground truth
    to compute_map_M_and_H, so it is expected to be indexable with entries
    exposing 'easy'/'hard'/'junk' — confirm against the dataset class.
    """
    def collate_fn(batch):
        # stack the images and concatenate the labels of a batch of dicts
        out = {}
        out["image"] = torch.stack([b["image"] for b in batch], dim=0)
        out["label"] = torch.cat([b["label"] for b in batch])
        return out

    features_query = []
    features_gallery = []
    # NOTE(review): only the query loader uses the custom collate_fn;
    # presumably gallery samples collate fine with the default — verify.
    loader_gallery = DataLoader(gallery, batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    loader_query = DataLoader(query, batch_size=batch_size, num_workers=num_workers, pin_memory=True, collate_fn=collate_fn)

    lib.LOGGER.info("Computing embeddings")
    for batch in tqdm(loader_gallery, disable=os.getenv('TQDM_DISABLE')):
        with torch.no_grad():
            X = net(batch["image"].cuda())
            features_gallery.append(X)

    features_gallery = torch.cat(features_gallery)

    for batch in tqdm(loader_query, disable=os.getenv('TQDM_DISABLE')):
        with torch.no_grad():
            X = net(batch["image"].cuda())
            features_query.append(X)

    features_query = torch.cat(features_query)

    features_gallery = features_gallery.cpu().numpy()
    features_query = features_query.cpu().numpy()

    # search, rank, and print
    # scores: (gallery_size, query_size); argsort(-scores) ranks the gallery
    # per query by decreasing similarity, as compute_map expects
    scores = np.dot(features_gallery, features_query.T)
    ranks = np.argsort(-scores, axis=0)

    return compute_map_M_and_H(ranks, query)
190 |
191 |
def landmark_evaluation(
    net,
    datasets,
    batch_size,
    num_workers,
):
    """Run the landmark retrieval evaluation for every city.

    `datasets` is a sequence of dicts, each holding the "query_<city>" and
    "gallery_<city>" datasets of one city (in either order). Returns a dict
    mapping each city name to its Medium/Hard mAP metrics.
    """
    metrics = {}
    for city_datasets in datasets:
        keys = list(city_datasets.keys())
        values = list(city_datasets.values())
        # the city name is the suffix of the dataset key
        city_name = keys[0].split('_')[-1]
        query_set = values[0] if keys[0].startswith("query") else values[1]
        gallery_set = values[0] if keys[0].startswith("gallery") else values[1]
        metrics[city_name] = evaluate_a_city(
            net=net,
            query=query_set,
            gallery=gallery_set,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    return metrics
212 |
--------------------------------------------------------------------------------
/roadmap/losses/smooth_rank_ap.py:
--------------------------------------------------------------------------------
1 | """
2 | inspired from
3 | https://github.com/Andrew-Brown1/Smooth_AP
4 | """
5 | from functools import partial
6 |
7 | import torch
8 | import torch.nn as nn
9 | from tqdm.auto import tqdm
10 |
11 | import roadmap.utils as lib
12 |
13 |
def heaviside(tens, val=1., target=None, general=None):
    """Element-wise Heaviside step of `tens`; `val` is the value at exactly 0.

    `target` and `general` are accepted (and ignored) so the signature is
    interchangeable with the other rank-approximation functions.
    """
    step_value = torch.tensor(val, device=tens.device, dtype=tens.dtype)
    return torch.heaviside(tens, values=step_value)
16 |
17 |
def tau_sigmoid(tensor, tau, target=None, general=None):
    """Temperature-controlled sigmoid: sigmoid(tensor / tau).

    The (negated, scaled) input is clamped to [-50, 50] before
    exponentiation for numerical stability. `target` and `general` are
    ignored; they keep the signature compatible with the other rank
    approximations.
    """
    clamped = (-tensor / tau).clamp(-50, 50)
    return 1.0 / (1. + clamped.exp())
27 |
28 |
def step_rank(tens, tau, rho, offset=None, delta=None, start=0.5, target=None, general=False):
    """Piecewise rank approximation used by SupAP; modifies `tens` IN PLACE.

    Entries of `tens` are pairwise score differences. Each entry is replaced
    depending on whether its pair is positive (`target`) and on its sign:
      * target pairs              -> a hard Heaviside step,
      * non-target, diff <= 0     -> a temperature sigmoid (tau_n),
      * non-target, diff > 0      -> an affine penalty (slope rho), or, when
        `delta` is given, a sigmoid below the `delta` margin and an affine
        branch above it (with `offset` chosen so the two pieces join at
        `delta` when not provided).

    `tau` may be a scalar or a string "taun_taup" giving separate
    temperatures for the negative/positive branches.
    NOTE(review): with a string `tau`, tau_n/tau_p stay strings when passed
    to tau_sigmoid — confirm callers convert to float upstream.
    NOTE(review): the masked in-place assignments below depend on their exact
    order; do not reorder them.
    """
    target = target.squeeze()
    if general:
        # `target` is a row of relevance flags shared by every row of `tens`
        target = target.view(1, -1).repeat(tens.size(0), 1)
    else:
        # build the pairwise same-label matrix from the label vector
        mask = target.unsqueeze(target.ndim - 1).bool()
        target = lib.create_label_matrix(target).bool() * mask
    pos_mask = (tens > 0).bool()
    neg_mask = ~pos_mask

    if isinstance(tau, str):
        tau_n, tau_p = tau.split("_")
    else:
        tau_n = tau_p = tau

    if delta is None:
        # affine penalty for positive differences of non-target pairs
        tens[~target & pos_mask] = rho * tens[~target & pos_mask] + offset
    else:
        # below the margin: shifted sigmoid; above: affine continuation
        margin_mask = tens > delta
        tens[~target & pos_mask & ~margin_mask] = start + tau_sigmoid(tens[~target & pos_mask & ~margin_mask], tau_p).type(tens.dtype)
        if offset is None:
            # offset that makes the affine branch continuous at `delta`
            offset = tau_sigmoid(torch.tensor([delta], device=tens.device), tau_p).type(tens.dtype) + start
        tens[~target & pos_mask & margin_mask] = rho * (tens[~target & pos_mask & margin_mask] - delta) + offset

    # negative differences of non-target pairs: plain temperature sigmoid
    tens[~target & neg_mask] = tau_sigmoid(tens[~target & neg_mask], tau_n).type(tens.dtype)

    # target pairs: exact (non-smoothed) Heaviside step
    tens[target] = torch.heaviside(tens[target], values=torch.tensor(1., device=tens.device, dtype=tens.dtype))

    return tens
58 |
59 |
class SmoothRankAP(nn.Module):
    """Smooth (differentiable) Average Precision loss.

    The non-differentiable rank of each positive is approximated by applying
    `rank_approximation` (e.g. a temperature sigmoid) to matrices of pairwise
    score differences.

    Args:
        rank_approximation: callable(diff_matrix, target=..., ...) returning
            the smoothed step values.
        return_type: one of '1-mAP', '1-AP', 'AP', 'mAP' — whether forward
            returns a loss (1 - ...) and whether it averages over queries.
    """
    def __init__(
        self,
        rank_approximation,
        return_type='1-mAP',
    ):
        super().__init__()
        self.rank_approximation = rank_approximation
        self.return_type = return_type
        assert return_type in ["1-mAP", "1-AP", "AP", 'mAP']

    def general_forward(self, scores, target, verbose=False):
        """Per-query smoothed AP for a general (possibly rectangular) setting.

        Args:
            scores: (batch_size, nb_instances) similarity scores.
            target: (batch_size, nb_instances) binary relevance matrix.
            verbose: show a tqdm progress bar over the queries.

        Returns:
            (batch_size,) tensor of smoothed AP values.
        """
        batch_size = target.size(0)
        nb_instances = target.size(1)
        device = scores.device

        ap_score = []
        # zero the diagonal: an instance is not its own retrieval result
        mask = (1 - torch.eye(nb_instances, device=device))
        iterator = range(batch_size)
        if verbose:
            iterator = tqdm(iterator, leave=None)
        for idx in iterator:
            # shape M
            query = scores[idx]
            pos_mask = target[idx].bool()

            # shape P x M: score difference of every instance to each positive
            query = query.view(1, -1) - query[pos_mask].view(-1, 1)
            query = self.rank_approximation(query, target=pos_mask, general=True) * mask[pos_mask]

            # shape P: smoothed rank of each positive among all instances
            rk = 1 + query.sum(-1)

            # shape P: smoothed rank of each positive among the positives only
            pos_rk = 1 + (query * pos_mask.view(1, -1)).sum(-1)

            # shape 1: smoothed AP of this query
            ap = (pos_rk / rk).sum(-1) / pos_mask.sum()
            ap_score.append(ap)

        # shape N
        ap_score = torch.stack(ap_score)

        return ap_score

    def quick_forward(self, scores, target):
        """Vectorized smoothed AP for the square case (batch vs. itself).

        Args:
            scores: (batch_size, batch_size) pairwise similarity scores.
            target: (batch_size, batch_size) binary relevance matrix.

        Returns:
            (batch_size,) tensor of smoothed AP values.
        """
        batch_size = target.size(0)
        device = scores.device

        # ------ differentiable ranking of all retrieval set ------
        # compute the mask which ignores the relevance score of the query to itself
        mask = 1.0 - torch.eye(batch_size, device=device).unsqueeze(0)
        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
        # compute the difference matrix
        sim_diff = scores.unsqueeze(1) - scores.unsqueeze(1).permute(0, 2, 1)

        # pass through the sigmoid
        sim_diff_sigmoid = self.rank_approximation(sim_diff, target=target)

        sim_sg = sim_diff_sigmoid * mask
        # compute the rankings
        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1

        # ------ differentiable ranking of only positive set in retrieval set ------
        # compute the mask which only gives non-zero weights to the positive set
        pos_mask = (target - torch.eye(batch_size).to(device))
        sim_pos_sg = sim_diff_sigmoid * pos_mask
        sim_pos_rk = (torch.sum(sim_pos_sg, dim=-1) + target) * target
        # compute the rankings of the positive set

        ap = ((sim_pos_rk / sim_all_rk).sum(1) * (1 / target.sum(1)))
        return ap

    def forward(self, scores, target, force_general=False, verbose=False):
        """Dispatch to quick_forward (square scores) or general_forward and
        convert the per-query AP according to `return_type`."""
        assert scores.shape == target.shape
        assert len(scores.shape) == 2

        if (scores.size(0) == scores.size(1)) and not force_general:
            ap = self.quick_forward(scores, target)
        else:
            ap = self.general_forward(scores, target, verbose=verbose)

        if self.return_type == 'AP':
            return ap
        elif self.return_type == 'mAP':
            return ap.mean()
        elif self.return_type == '1-AP':
            return 1 - ap
        elif self.return_type == '1-mAP':
            return 1 - ap.mean()

    @property
    def my_repr(self,):
        # repr fragment reused by the subclasses' extra_repr
        repr = f"return_type={self.return_type}"
        return repr
155 |
156 |
class HeavisideAP(SmoothRankAP):
    """SmoothRankAP with the exact (non-differentiable) Heaviside step.

    Here for testing purposes only.
    """

    def __init__(self, **kwargs):
        super().__init__(partial(heaviside), **kwargs)

    def extra_repr(self,):
        return self.my_repr
167 |
168 |
class SmoothAP(SmoothRankAP):
    """SmoothRankAP using a temperature-controlled sigmoid (Smooth-AP)."""

    def __init__(self, tau=0.01, **kwargs):
        super().__init__(partial(tau_sigmoid, tau=tau), **kwargs)
        self.tau = tau

    def extra_repr(self,):
        return f"tau={self.tau}, {self.my_repr}"
179 |
180 |
class SupAP(SmoothRankAP):
    """SmoothRankAP using the piecewise step_rank approximation (ROADMAP)."""

    def __init__(self, tau=0.01, rho=100, offset=None, delta=0.05, start=0.5, **kwargs):
        step_fn = partial(step_rank, tau=tau, rho=rho, offset=offset, delta=delta, start=start)
        super().__init__(step_fn, **kwargs)
        self.tau = tau
        self.rho = rho
        self.offset = offset
        self.delta = delta

    def extra_repr(self,):
        return f"tau={self.tau}, rho={self.rho}, offset={self.offset}, delta={self.delta}, {self.my_repr}"
194 |
--------------------------------------------------------------------------------
/roadmap/models/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torchvision.models as models
5 | import timm
6 |
7 | import roadmap.utils as lib
8 |
9 | from .create_projection_head import create_projection_head
10 |
11 |
def _load_distilled_deit(plain_names, distilled_names, pretrained):
    """Create a plain DeiT and load a distilled DeiT checkpoint into it.

    Each argument is a pair of timm model names tried in order (timm renamed
    the DeiT entry points across versions, so the first name may raise
    RuntimeError). The distillation token is dropped from the positional
    embedding and the weights are copied with strict=False, since the plain
    model has no distillation head.
    """
    try:
        deit = timm.create_model(plain_names[0])
        deit_distilled = timm.create_model(distilled_names[0], pretrained=pretrained)
    except RuntimeError:
        deit = timm.create_model(plain_names[1])
        deit_distilled = timm.create_model(distilled_names[1], pretrained=pretrained)
    # remove the distillation token (position 1) from the positional embedding
    deit_distilled.pos_embed = torch.nn.Parameter(
        torch.cat((deit_distilled.pos_embed[:, :1], deit_distilled.pos_embed[:, 2:]), dim=1)
    )
    deit.load_state_dict(deit_distilled.state_dict(), strict=False)
    return deit


def get_backbone(name, pretrained=True):
    """Return (backbone, pooling, out_dim) for a backbone architecture name.

    Args:
        name: one of 'resnet18', 'resnet50', 'resnet101', 'vit', 'vit_deit',
            'vit_deit_distilled', 'vit_deit_base', 'vit_deit_base_distilled',
            'vit_deit_base_384', 'vit_deit_base_384_distilled'.
        pretrained: load pretrained weights.

    Returns:
        tuple (backbone module, default pooling module, feature dimension).

    Raises:
        ValueError: if `name` is not recognized.
    """
    if name == 'resnet18':
        lib.LOGGER.info("using ResNet-18")
        out_dim = 512
        backbone = models.resnet18(pretrained=pretrained)
        # drop the final avgpool + fc; pooling is applied separately
        backbone = nn.Sequential(*list(backbone.children())[:-2])
        pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    elif name == 'resnet50':
        lib.LOGGER.info("using ResNet-50")
        out_dim = 2048
        backbone = models.resnet50(pretrained=pretrained)
        backbone = nn.Sequential(*list(backbone.children())[:-2])
        pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    elif name == 'resnet101':
        lib.LOGGER.info("using ResNet-101")
        out_dim = 2048
        backbone = models.resnet101(pretrained=pretrained)
        backbone = nn.Sequential(*list(backbone.children())[:-2])
        pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
    elif name == 'vit':
        lib.LOGGER.info("using ViT-S")
        # NOTE(review): 768 matches the historical timm 'vit_small' config —
        # confirm against the installed timm version.
        out_dim = 768
        backbone = timm.create_model('vit_small_patch16_224', pretrained=pretrained)
        backbone.reset_classifier(-1)
        pooling = nn.Identity()
    elif name == 'vit_deit':
        lib.LOGGER.info("using DeiT-S")
        out_dim = 384
        try:
            backbone = timm.create_model('vit_deit_small_patch16_224', pretrained=pretrained)
        except RuntimeError:
            backbone = timm.create_model('deit_small_patch16_224', pretrained=pretrained)
        backbone.reset_classifier(-1)
        pooling = nn.Identity()
    elif name == 'vit_deit_distilled':
        lib.LOGGER.info("using DeiT-S distilled")
        backbone = _load_distilled_deit(
            ('vit_deit_small_patch16_224', 'deit_small_patch16_224'),
            ('vit_deit_small_distilled_patch16_224', 'deit_small_distilled_patch16_224'),
            pretrained,
        )
        backbone.reset_classifier(-1)
        out_dim = 384
        pooling = nn.Identity()
    elif name == 'vit_deit_base':
        lib.LOGGER.info("using DeiT-B")
        out_dim = 768
        try:
            backbone = timm.create_model('vit_deit_base_patch16_224', pretrained=pretrained)
        except RuntimeError:
            backbone = timm.create_model('deit_base_patch16_224', pretrained=pretrained)
        backbone.reset_classifier(-1)
        pooling = nn.Identity()
    elif name == 'vit_deit_base_distilled':
        lib.LOGGER.info("using DeiT-B distilled")
        # BUGFIX: both try/except branches previously created the same
        # 'deit_base_*' names, so the legacy-name fallback could never work;
        # now consistent with the 'vit_deit_distilled' branch.
        backbone = _load_distilled_deit(
            ('vit_deit_base_patch16_224', 'deit_base_patch16_224'),
            ('vit_deit_base_distilled_patch16_224', 'deit_base_distilled_patch16_224'),
            pretrained,
        )
        backbone.reset_classifier(-1)
        out_dim = 768
        pooling = nn.Identity()
    elif name == 'vit_deit_base_384':
        lib.LOGGER.info("using DeiT-B 384")
        out_dim = 768
        try:
            backbone = timm.create_model('vit_deit_base_patch16_384', pretrained=pretrained)
        except RuntimeError:
            backbone = timm.create_model('deit_base_patch16_384', pretrained=pretrained)
        backbone.reset_classifier(-1)
        pooling = nn.Identity()
    elif name == 'vit_deit_base_384_distilled':
        lib.LOGGER.info("using DeiT-B 384 distilled")
        # BUGFIX: same duplicated-name issue as 'vit_deit_base_distilled'.
        backbone = _load_distilled_deit(
            ('vit_deit_base_patch16_384', 'deit_base_patch16_384'),
            ('vit_deit_base_distilled_patch16_384', 'deit_base_distilled_patch16_384'),
            pretrained,
        )
        backbone.reset_classifier(-1)
        out_dim = 768
        pooling = nn.Identity()
    else:
        raise ValueError(f"{name} is not recognized")

    return (backbone, pooling, out_dim)
110 |
111 |
class RetrievalNet(nn.Module):
    """Image retrieval network: backbone -> pooling -> (LayerNorm) ->
    projection head -> L2 normalization.

    Args:
        backbone_name: architecture name understood by :func:`get_backbone`.
        embed_dim: output dimension of the projection head.
        norm_features: apply a non-affine LayerNorm on the pooled features.
        without_fc: skip the projection head entirely.
        with_autocast: run backbone + pooling under torch.cuda.amp.autocast.
        pooling: 'default' (the backbone's own), 'none', 'max' or 'avg'.
        projection_normalization_layer: forwarded to create_projection_head.
        pretrained: load pretrained backbone weights.

    Raises:
        ValueError: if `pooling` is not one of the supported values.
    """

    def __init__(
        self,
        backbone_name,
        embed_dim=512,
        norm_features=False,
        without_fc=False,
        with_autocast=False,
        pooling='default',
        projection_normalization_layer='none',
        pretrained=True,
    ):
        super().__init__()

        # these flags may arrive as strings (e.g. from a config file)
        norm_features = lib.str_to_bool(norm_features)
        without_fc = lib.str_to_bool(without_fc)
        with_autocast = lib.str_to_bool(with_autocast)

        assert isinstance(without_fc, bool)
        assert isinstance(norm_features, bool)
        assert isinstance(with_autocast, bool)
        self.norm_features = norm_features
        self.without_fc = without_fc
        self.with_autocast = with_autocast
        if with_autocast:
            lib.LOGGER.info("Using mixed precision")

        self.backbone, default_pooling, out_features = get_backbone(backbone_name, pretrained=pretrained)
        if pooling == 'default':
            self.pooling = default_pooling
        elif pooling == 'none':
            self.pooling = nn.Identity()
        elif pooling == 'max':
            self.pooling = nn.AdaptiveMaxPool2d(output_size=(1, 1))
        elif pooling == 'avg':
            self.pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        else:
            # BUGFIX: an unknown value previously left self.pooling unset and
            # only crashed later in forward(); fail fast instead
            raise ValueError(f"{pooling} is not a valid pooling")
        lib.LOGGER.info(f"Pooling is {self.pooling}")

        if self.norm_features:
            lib.LOGGER.info("Using a LayerNorm layer")
            self.standardize = nn.LayerNorm(out_features, elementwise_affine=False)
        else:
            self.standardize = nn.Identity()

        if not self.without_fc:
            self.fc = create_projection_head(out_features, embed_dim, projection_normalization_layer)
            lib.LOGGER.info(f"Projection head : \n{self.fc}")
        else:
            self.fc = nn.Identity()
            lib.LOGGER.info("Not using a linear projection layer")

    def forward(self, X):
        """Return L2-normalized embeddings, shape (batch, embed_dim)."""
        # backbone + pooling may run in mixed precision
        with torch.cuda.amp.autocast(enabled=self.with_autocast):
            X = self.backbone(X)
            X = self.pooling(X)

        # flatten, optionally standardize, project, then L2-normalize
        X = X.view(X.size(0), -1)
        X = self.standardize(X)
        X = self.fc(X)
        X = F.normalize(X, p=2, dim=1)
        return X
174 |
--------------------------------------------------------------------------------