├── torch_hub └── __init__.py ├── layers ├── __init__.py ├── transformer_layers.py └── independent_mlp.py ├── engine ├── __init__.py ├── losses │ ├── __init__.py │ ├── pixel_wise_entropy_loss.py │ ├── orthogonality_loss.py │ ├── concentration_loss.py │ ├── equivarance_loss.py │ ├── enforced_presence_loss.py │ ├── total_variation.py │ └── presence_loss.py └── eval_fg_bg.py ├── models ├── __init__.py ├── individual_landmark_convnext.py ├── individual_landmark_resnet.py └── vit_baseline.py ├── utils ├── training_utils │ ├── __init__.py │ ├── snapshot_class.py │ ├── scheduler_params.py │ ├── linear_lr_scheduler.py │ ├── engine_utils.py │ ├── optimizer_params.py │ └── ddp_utils.py ├── data_utils │ ├── __init__.py │ ├── class_balanced_sampler.py │ ├── reversible_affine_transform.py │ ├── class_balanced_distributed_sampler.py │ ├── transform_utils.py │ └── dataset_utils.py ├── __init__.py ├── wandb_params.py ├── get_landmark_coordinates.py ├── misc_utils.py └── visualize_att_maps.py ├── data_sets ├── __init__.py ├── flowers102seg.py ├── imagenet_with_ood_eval.py ├── plantnet.py ├── part_imagenet.py └── celeba.py ├── .gitignore ├── hubconf.py ├── LICENSE ├── evaluation_instructions.md ├── prepare_partimagenet_ood.py ├── train_net.py ├── load_losses.py ├── inference_benchmark_models.py ├── README.md ├── environment.yml ├── evaluate_parts.py ├── load_model.py ├── load_dataset.py └── model_zoo.md /torch_hub/__init__.py: -------------------------------------------------------------------------------- 1 | from .load_pretrained_models import * -------------------------------------------------------------------------------- /layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer_layers import * 2 | from .independent_mlp import * -------------------------------------------------------------------------------- /engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributed_trainer_pdisco import * 2 | from .eval_interpretability_nmi_ari_keypoint import * 3 | from .eval_fg_bg import * 4 | from .losses import * -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .individual_landmark_resnet import * 2 | from .individual_landmark_convnext import * 3 | from .vit_baseline import * 4 | from .individual_landmark_vit import * -------------------------------------------------------------------------------- /utils/training_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .linear_lr_scheduler import * 2 | from .ddp_utils import * 3 | from .engine_utils import * 4 | from .optimizer_params import * 5 | from .scheduler_params import * 6 | -------------------------------------------------------------------------------- /data_sets/__init__.py: -------------------------------------------------------------------------------- 1 | from .fg_bird_dataset import * 2 | from .celeba import * 3 | from .imagenet_with_ood_eval import * 4 | from .part_imagenet import * 5 | from .plantnet import * 6 | from .flowers102seg import * -------------------------------------------------------------------------------- /utils/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_utils import * 2 | from .reversible_affine_transform import * 3 | from 
.transform_utils import * 4 | from .class_balanced_distributed_sampler import * 5 | from .class_balanced_sampler import * -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_utils import * 2 | from .training_utils import * 3 | 4 | from .wandb_params import * 5 | from .visualize_att_maps import * 6 | from .misc_utils import * 7 | from .get_landmark_coordinates import * 8 | 9 | 10 | -------------------------------------------------------------------------------- /engine/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .total_variation import * 2 | from .concentration_loss import * 3 | from .presence_loss import * 4 | from .orthogonality_loss import * 5 | from .equivarance_loss import * 6 | from .enforced_presence_loss import * 7 | from .pixel_wise_entropy_loss import * 8 | -------------------------------------------------------------------------------- /utils/training_utils/snapshot_class.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from dataclasses import dataclass 3 | from typing import Dict, Any, List 4 | 5 | import torch 6 | 7 | 8 | @dataclass 9 | class Snapshot: 10 | model_state: 'OrderedDict[str, torch.Tensor]' 11 | optimizer_state: Dict[str, Any] 12 | scaler_state: Dict[str, Any] 13 | finished_epoch: int 14 | epoch_test_accuracies: List[float] = None 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # editor settings 2 | .idea 3 | .vscode 4 | _darcs 5 | 6 | # compilation and distribution 7 | __pycache__ 8 | _ext 9 | *.pyc 10 | *.pyd 11 | *.so 12 | *.dll 13 | *.egg-info/ 14 | build/ 15 | dist/ 16 | wheels/ 17 | 18 | # pytorch/python/numpy formats 19 | *.pth 20 | *.pkl 21 | *.npy 22 | *.ts 23 | *.pt 24 | 25 | # ipython/jupyter notebooks 26 | *.ipynb 27 | **/.ipynb_checkpoints/ 28 | 29 | # Editor temporaries 30 | *.swn 31 | *.swo 32 | *.swp 33 | *~ 34 | 35 | # Results temporary 36 | *.png 37 | *.txt 38 | *.tsv 39 | wandb/ 40 | exps/ 41 | -------------------------------------------------------------------------------- /engine/losses/pixel_wise_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # This file contains the pixel-wise entropy loss function 2 | import torch 3 | 4 | 5 | def pixel_wise_entropy_loss(maps): 6 | """ 7 | Calculate pixel-wise entropy loss for a feature map 8 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 9 | :return: value of the pixel-wise entropy loss 10 | """ 11 | # Calculate entropy for each pixel with numerical stability 12 | entropy = torch.distributions.categorical.Categorical(probs=maps.permute(0, 2, 3, 1).contiguous()).entropy() 13 | # Take the mean of the entropy 14 | return entropy.mean() 15 | -------------------------------------------------------------------------------- /engine/losses/orthogonality_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def orthogonality_loss(all_features): 5 | """ 6 | Calculate orthogonality loss for a feature map 7 | Ref: 
https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/train.py#L44 8 | :param all_features: The feature map with shape (batch_size, feature_dim, num_landmarks + 1) 9 | :return: 10 | """ 11 | normed_feature = torch.nn.functional.normalize(all_features, dim=1) 12 | total_landmarks = all_features.shape[-1] 13 | similarity_fg = torch.matmul(normed_feature.permute(0, 2, 1).contiguous(), normed_feature) 14 | similarity_fg = torch.sub(similarity_fg, torch.eye(total_landmarks, device=all_features.device)) 15 | orth_loss = torch.mean(torch.square(similarity_fg)) 16 | return orth_loss 17 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | from torch_hub import (pdiscoformer_cub_k_8, pdiscoformer_cub_k_4, pdiscoformer_cub_k_16, 2 | pdiscoformer_flowers_k_2, pdiscoformer_flowers_k_4, pdiscoformer_flowers_k_8, 3 | pdiscoformer_pimagenet_k_25, pdiscoformer_pimagenet_k_50, pdiscoformer_pimagenet_k_8, 4 | pdiscoformer_pimagenet_seg_k_8, pdiscoformer_pimagenet_seg_k_16, pdiscoformer_pimagenet_seg_k_25, 5 | pdiscoformer_pimagenet_seg_k_41, pdiscoformer_pimagenet_seg_k_50, 6 | pdiscoformer_nabirds_k_4, pdiscoformer_nabirds_k_8, pdiscoformer_nabirds_k_11, 7 | pdisconet_vit_nabirds_k_4, pdisconet_vit_nabirds_k_8, pdisconet_vit_nabirds_k_11, 8 | pdisconet_resnet_nabirds_k_4, pdisconet_resnet_nabirds_k_8, pdisconet_resnet_nabirds_k_11) 9 | 10 | dependencies = ['torch', 'torchvision', 'timm'] 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Ananthu Aniraj, Cassio F.Dantas, Dino Ienco and Diego Marcos 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/utils/wandb_params.py:
--------------------------------------------------------------------------------
1 | import wandb
2 | import copy
3 | 
4 | 
5 | def init_wandb(args):
6 |     wandb.login()
7 |     if isinstance(args, dict):
8 |         args_dict = args
9 |     else:
10 |         args_dict = vars(args)
11 |     if args_dict["resume_training"]:
12 |         if args_dict["wandb_resume_id"] is not None:
13 |             run = wandb.init(project=args_dict["wandb_project"], entity=args_dict["wandb_entity"],
14 |                              job_type=args_dict["job_type"],
15 |                              group=args_dict["group"], mode=args_dict["wandb_mode"],
16 |                              config=args_dict, id=args_dict["wandb_resume_id"], resume="must")
17 |         else:
18 |             raise ValueError("wandb_resume_id is None")
19 |     else:
20 |         run = wandb.init(project=args_dict["wandb_project"], entity=args_dict["wandb_entity"],
21 |                          job_type=args_dict["job_type"],
22 |                          group=args_dict["group"], mode=args_dict["wandb_mode"],
23 |                          config=args_dict)
24 |     return run
25 | 
26 | 
27 | def get_train_loggers(args):
28 |     """Get the train loggers for the experiment"""
29 |     train_loggers = []
30 |     if args.wandb:
31 |         wandb_logger_settings = copy.deepcopy(vars(args))
32 |         train_loggers.append(wandb_logger_settings)
33 |     return train_loggers
34 | 
--------------------------------------------------------------------------------
/utils/data_utils/class_balanced_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | 
4 | 
5 | class ClassBalancedRandomSampler(torch.utils.data.Sampler):
6 |     """
7 |     A custom sampler that sub-samples a given dataset based on class labels. Based on the RandomSampler class.
8 |     This is essentially the non-DDP version of ClassBalancedDistributedSampler.
9 |     Ref: https://github.com/pytorch/pytorch/blob/abe3c55a6a01c5b625eeb4fc9aab1421a5965cd2/torch/utils/data/sampler.py#L117
10 |     """
11 | 
12 |     def __init__(self, dataset: Dataset, num_samples_per_class=100, seed: int = 0) -> None:
13 |         self.dataset = dataset
14 |         self.seed = seed
15 |         # Calculate the number of samples
16 |         self.generator = torch.Generator()
17 |         self.generator.manual_seed(self.seed)
18 |         self.num_samples_per_class = num_samples_per_class
19 |         indices = dataset.generate_class_balanced_indices(self.generator,
20 |                                                            num_samples_per_class=num_samples_per_class)
21 |         self.num_samples = len(indices)
22 | 
23 |     def __iter__(self):
24 |         # Change the seed on every call so that each epoch draws a different class-balanced subset
25 |         seed = int(torch.empty((), dtype=torch.int64).random_().item())
26 |         self.generator.manual_seed(seed)
27 |         indices = self.dataset.generate_class_balanced_indices(self.generator, num_samples_per_class=self.num_samples_per_class)
28 |         return iter(indices)
29 | 
30 |     def __len__(self) -> int:
31 |         return self.num_samples
32 | 
--------------------------------------------------------------------------------
/utils/get_landmark_coordinates.py:
--------------------------------------------------------------------------------
1 | # This file contains the function to generate the center coordinates as tensor for the current net.
2 | import torch
3 | 
4 | 
5 | def landmark_coordinates(maps, grid_x=None, grid_y=None):
6 |     """
7 |     Generate the center coordinates as tensor for the current net.
8 | Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L19 9 | Parameters 10 | ---------- 11 | maps: torch.Tensor 12 | Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 13 | grid_x: torch.Tensor 14 | The grid x coordinates 15 | grid_y: torch.Tensor 16 | The grid y coordinates 17 | Returns 18 | ---------- 19 | loc_x: Tensor 20 | The centroid x coordinates 21 | loc_y: Tensor 22 | The centroid y coordinates 23 | grid_x: Tensor 24 | grid_y: Tensor 25 | """ 26 | return_grid = False 27 | if grid_x is None or grid_y is None: 28 | return_grid = True 29 | grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]), 30 | torch.arange(maps.shape[3]), indexing='ij') 31 | grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True) 32 | grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True) 33 | map_sums = maps.sum(3).sum(2).detach() 34 | maps_x = grid_x * maps 35 | maps_y = grid_y * maps 36 | loc_x = maps_x.sum(3).sum(2) / map_sums 37 | loc_y = maps_y.sum(3).sum(2) / map_sums 38 | if return_grid: 39 | return loc_x, loc_y, grid_x, grid_y 40 | else: 41 | return loc_x, loc_y 42 | -------------------------------------------------------------------------------- /engine/losses/concentration_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils.get_landmark_coordinates import landmark_coordinates 3 | 4 | 5 | class ConcentrationLoss(torch.nn.Module): 6 | """ 7 | This class defines the concentration loss. 8 | Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/train.py#L15 9 | """ 10 | 11 | def __init__(self): 12 | super(ConcentrationLoss, self).__init__() 13 | self.grid_x = None 14 | self.grid_y = None 15 | 16 | def forward(self, maps): 17 | """ 18 | Forward function for the concentration loss. 
19 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 20 | :return: The concentration loss 21 | """ 22 | if self.grid_x is None or self.grid_y is None: 23 | grid_x, grid_y = torch.meshgrid(torch.arange(maps.shape[2]), 24 | torch.arange(maps.shape[3]), indexing='ij') 25 | grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True) 26 | grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(maps.device, non_blocking=True) 27 | self.grid_x = grid_x 28 | self.grid_y = grid_y 29 | 30 | # Get landmark coordinates 31 | loc_x, loc_y = landmark_coordinates(maps, self.grid_x, self.grid_y) 32 | # Concentration loss 33 | loss_conc_x = ((loc_x.unsqueeze(-1).unsqueeze(-1).contiguous() - self.grid_x) / self.grid_x.shape[-1]) ** 2 34 | loss_conc_y = ((loc_y.unsqueeze(-1).unsqueeze(-1).contiguous() - self.grid_y) / self.grid_y.shape[-2]) ** 2 35 | loss_conc = (loss_conc_x + loss_conc_y) * maps 36 | return loss_conc[:, 0:-1, :, :].mean() 37 | -------------------------------------------------------------------------------- /utils/training_utils/scheduler_params.py: -------------------------------------------------------------------------------- 1 | from timm.scheduler.cosine_lr import CosineLRScheduler 2 | from timm.scheduler.step_lr import StepLRScheduler 3 | from .linear_lr_scheduler import LinearLRScheduler 4 | 5 | 6 | def build_scheduler(args, optimizer): 7 | """ 8 | Function to build the scheduler 9 | :param args: arguments from the command line 10 | :param optimizer: optimizer used for training 11 | :return: scheduler 12 | """ 13 | # initialize scheduler hyperparameters 14 | total_steps = args.epochs 15 | type_lr_schedule = args.scheduler_type 16 | warmup_steps = args.scheduler_warmup_epochs 17 | decay_steps = args.scheduler_step_size 18 | warmup_lr_init = args.warmup_lr 19 | 20 | restart_factor = args.scheduler_restart_factor 21 | gamma = args.scheduler_gamma 22 | 23 | min_lr = args.min_lr 24 | if type_lr_schedule == 'cosine': 25 | return CosineLRScheduler( 26 | optimizer, 27 | t_initial=total_steps, 28 | cycle_decay=restart_factor, 29 | lr_min=min_lr, 30 | warmup_t=warmup_steps, 31 | cycle_limit=args.cosine_cycle_limit, 32 | warmup_lr_init=warmup_lr_init, 33 | t_in_epochs=True 34 | ) 35 | elif type_lr_schedule == 'steplr': 36 | return StepLRScheduler( 37 | optimizer, 38 | decay_t=decay_steps, 39 | decay_rate=gamma, 40 | warmup_t=warmup_steps, 41 | warmup_lr_init=warmup_lr_init, 42 | t_in_epochs=True 43 | ) 44 | elif type_lr_schedule == 'linearlr': 45 | return LinearLRScheduler( 46 | optimizer, 47 | t_initial=total_steps, 48 | lr_min_rate=0.01, 49 | warmup_t=warmup_steps, 50 | warmup_lr_init=warmup_lr_init, 51 | t_in_epochs=True 52 | ) 53 | else: 54 | raise NotImplementedError 55 | 56 | -------------------------------------------------------------------------------- /layers/transformer_layers.py: -------------------------------------------------------------------------------- 1 | # Attention Block with option to return the mean of k over heads from attention 2 | 3 | import torch 4 | from timm.models.vision_transformer import Attention, Block 5 | import torch.nn.functional as F 6 | from typing import Tuple 7 | 8 | 9 | class AttentionWQKVReturn(Attention): 10 | """ 11 | Modifications: 12 | - Return the qkv tensors from the attention 13 | """ 14 | 15 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 16 | B, N, C = x.shape 17 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, 
self.head_dim).permute(2, 0, 3, 1, 4) 18 | q, k, v = qkv.unbind(0) 19 | q, k = self.q_norm(q), self.k_norm(k) 20 | 21 | if self.fused_attn: 22 | x = F.scaled_dot_product_attention( 23 | q, k, v, 24 | dropout_p=self.attn_drop.p if self.training else 0., 25 | ) 26 | else: 27 | q = q * self.scale 28 | attn = q @ k.transpose(-2, -1) 29 | attn = attn.softmax(dim=-1) 30 | attn = self.attn_drop(attn) 31 | x = attn @ v 32 | 33 | x = x.transpose(1, 2).reshape(B, N, C) 34 | x = self.proj(x) 35 | x = self.proj_drop(x) 36 | return x, torch.stack((q, k, v), dim=0) 37 | 38 | 39 | class BlockWQKVReturn(Block): 40 | """ 41 | Modifications: 42 | - Use AttentionWQKVReturn instead of Attention 43 | - Return the qkv tensors from the attention 44 | """ 45 | 46 | def forward(self, x: torch.Tensor, return_qkv: bool = False) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: 47 | # Note: this is copied from timm.models.vision_transformer.Block with modifications. 48 | x_attn, qkv = self.attn(self.norm1(x)) 49 | x = x + self.drop_path1(self.ls1(x_attn)) 50 | x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) 51 | if return_qkv: 52 | return x, qkv 53 | else: 54 | return x 55 | -------------------------------------------------------------------------------- /engine/losses/equivarance_loss.py: -------------------------------------------------------------------------------- 1 | # Code for the Equivariance Loss 2 | 3 | import torch 4 | from utils.data_utils.reversible_affine_transform import rigid_transform 5 | 6 | 7 | def equivariance_loss(maps, equiv_maps, source, num_landmarks, translate, angle, scale, shear=0.0): 8 | """ 9 | This function calculates the equivariance loss 10 | Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/train.py#L67 11 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 12 | :param equiv_maps: Attention maps for same images after an affine transformation and then passed through the model 13 | :param source: Original mini-batch of images 14 | :param num_landmarks: Number of landmarks/parts 15 | :param translate: Translation parameters for the affine transformation 16 | :param angle: Angle parameter for the affine transformation 17 | :param scale: Scale parameter for the affine transformation 18 | :param shear: Shear parameter for the affine transformation 19 | :return: 20 | """ 21 | 22 | translate = [(t * maps.shape[-1] / source.shape[-1]) for t in translate] 23 | rot_back = rigid_transform(img=equiv_maps, angle=angle, translate=translate, 24 | scale=scale, shear=shear, invert=True) 25 | num_elements_per_map = maps.shape[-2] * maps.shape[-1] 26 | orig_attmap_vector = torch.reshape(maps[:, :-1, :, :], 27 | (-1, num_landmarks, 28 | num_elements_per_map)) 29 | transf_attmap_vector = torch.reshape(rot_back[:, 0:-1, :, :], 30 | (-1, num_landmarks, 31 | num_elements_per_map)) 32 | cos_sim_equiv = torch.nn.functional.cosine_similarity(orig_attmap_vector, 33 | transf_attmap_vector, -1) 34 | loss_equiv = (1 - torch.mean(cos_sim_equiv)) 35 | 36 | return loss_equiv 37 | -------------------------------------------------------------------------------- /utils/training_utils/linear_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Reference:https://github.com/microsoft/Swin-Transformer/blob/main/lr_scheduler.py 2 | import torch 3 | from timm.scheduler.scheduler import Scheduler 4 | 5 | 6 | class 
LinearLRScheduler(Scheduler): 7 | def __init__(self, 8 | optimizer: torch.optim.Optimizer, 9 | t_initial: int, 10 | lr_min_rate: float, 11 | warmup_t=0, 12 | warmup_lr_init=0., 13 | t_in_epochs=True, 14 | noise_range_t=None, 15 | noise_pct=0.67, 16 | noise_std=1.0, 17 | noise_seed=42, 18 | initialize=True, 19 | ) -> None: 20 | super().__init__( 21 | optimizer, param_group_field="lr", 22 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 23 | initialize=initialize) 24 | 25 | self.t_initial = t_initial 26 | self.lr_min_rate = lr_min_rate 27 | self.warmup_t = warmup_t 28 | self.warmup_lr_init = warmup_lr_init 29 | self.t_in_epochs = t_in_epochs 30 | if self.warmup_t: 31 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 32 | super().update_groups(self.warmup_lr_init) 33 | else: 34 | self.warmup_steps = [1 for _ in self.base_values] 35 | 36 | def _get_lr(self, t): 37 | if t < self.warmup_t: 38 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 39 | else: 40 | t = t - self.warmup_t 41 | total_t = self.t_initial - self.warmup_t 42 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 43 | return lrs 44 | 45 | def get_epoch_values(self, epoch: int): 46 | if self.t_in_epochs: 47 | return self._get_lr(epoch) 48 | else: 49 | return None 50 | 51 | def get_update_values(self, num_updates: int): 52 | if not self.t_in_epochs: 53 | return self._get_lr(num_updates) 54 | else: 55 | return None 56 | -------------------------------------------------------------------------------- /data_sets/flowers102seg.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://pytorch.org/vision/stable/_modules/torchvision/datasets/flowers102.html#Flowers102 2 | from torchvision import datasets 3 | import PIL 4 | from typing import Tuple, Any 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class Flowers102Seg(datasets.Flowers102): 11 | """ 12 | This class is a subclass of the torchvision.datasets.Flowers102 class that adds the segmentation images to the 13 | __getitem__ method. 
14 | """ 15 | 16 | def __init__(self, root, split='train', transform=None, target_transform=None, download=False, mask_transform=None): 17 | """ 18 | Args: 19 | :param root: 20 | :param split: 21 | :param transform: 22 | :param target_transform: 23 | :param download: 24 | :param mask_transform: The transform to apply to the segmentation mask 25 | """ 26 | super().__init__(root, split=split, transform=transform, target_transform=target_transform, download=download) 27 | self._seg_folder = self._base_folder / 'segmim' 28 | self.seg_files = [] 29 | for image_file in self._image_files: 30 | image_name = image_file.name.split('_')[-1] 31 | seg_name = 'segmim_' + image_name 32 | seg_file = self._seg_folder / seg_name 33 | self.seg_files.append(seg_file) 34 | self.mask_transform = mask_transform 35 | self.num_classes = len(set(self._labels)) 36 | 37 | def __getitem__(self, idx: int) -> Tuple[Any, Any, Any]: 38 | image_file, label = self._image_files[idx], self._labels[idx] 39 | image = PIL.Image.open(image_file).convert("RGB") 40 | seg_image = PIL.Image.open(self.seg_files[idx]).convert("RGB") 41 | seg_image = np.array(seg_image) 42 | # Convert RGB to BGR 43 | seg_image = seg_image[:, :, ::-1].copy() 44 | # Convert to binary mask 45 | binary_mask = ((seg_image[:, :, 0] / (seg_image[:, :, 1] + seg_image[:, :, 2] + 1e-6)) > 100).astype(np.uint8) 46 | binary_mask = 1 - binary_mask 47 | if len(np.unique(binary_mask)) > 1: 48 | binary_mask = cv2.medianBlur(binary_mask, 5) 49 | seg_image = torch.as_tensor(binary_mask, dtype=torch.float32).unsqueeze(0) 50 | if self.transform: 51 | image = self.transform(image) 52 | 53 | if self.mask_transform: 54 | seg_image = self.mask_transform(seg_image) 55 | 56 | if self.target_transform: 57 | label = self.target_transform(label) 58 | 59 | seg_image = seg_image.squeeze(0) 60 | 61 | return image, label, seg_image 62 | -------------------------------------------------------------------------------- /data_sets/imagenet_with_ood_eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | from collections import defaultdict 5 | from .classes_mapping_imagenet import imagenet_idx_to_class_names, imagenet_class_names_to_idx, IMAGENET2012_CLASSES 6 | from utils.data_utils.dataset_utils import pil_loader 7 | 8 | 9 | class ImageNetWithOODEval(torch.utils.data.Dataset): 10 | """ 11 | Class to train models on ImageNet with Eval on OOD sets 12 | Variables 13 | base_folder, str: Root directory of the dataset. 14 | image_sub_path, str: Path to the folder containing the images. 15 | transform, callable: A function/transform that takes in a PIL.Image and transforms it. 
16 | """ 17 | 18 | def __init__(self, base_folder, image_sub_path, transform=None): 19 | self.class_to_idx = imagenet_class_names_to_idx 20 | self.idx_to_class = imagenet_idx_to_class_names 21 | 22 | self.images_folder = os.path.join(base_folder, image_sub_path) 23 | 24 | self.num_classes = len(imagenet_idx_to_class_names) 25 | 26 | self.classes = list(self.class_to_idx.keys()) 27 | 28 | self.wordnet_to_class_name = IMAGENET2012_CLASSES 29 | 30 | self.transform = transform 31 | 32 | self.loader = pil_loader 33 | 34 | self.image_paths = glob.glob(os.path.join(self.images_folder, "**/*.jpg"), recursive=True) 35 | self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.jpeg"), recursive=True) 36 | self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.png"), recursive=True) 37 | self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.bmp"), recursive=True) 38 | self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.ppm"), recursive=True) 39 | self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.JPEG"), recursive=True) 40 | 41 | self.image_paths = sorted(self.image_paths) 42 | self.per_class_count = defaultdict(int) 43 | self.labels = [self.class_to_idx[self.wordnet_to_class_name[os.path.basename(os.path.dirname(image_path))]] for 44 | image_path in self.image_paths] 45 | for label in self.labels: 46 | self.per_class_count[self.idx_to_class[label]] += 1 47 | self.cls_num_list = [self.per_class_count[self.idx_to_class[idx]] for idx in range(self.num_classes)] 48 | 49 | def __len__(self): 50 | return len(self.labels) 51 | 52 | def __getitem__(self, idx): 53 | image_path = self.image_paths[idx] 54 | image = self.loader(image_path) 55 | label = self.labels[idx] 56 | if self.transform is not None: 57 | image = self.transform(image) 58 | return image, label 59 | 60 | -------------------------------------------------------------------------------- /utils/training_utils/engine_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enum import Enum 3 | from .snapshot_class import Snapshot 4 | 5 | 6 | class Summary(Enum): 7 | NONE = 0 8 | AVERAGE = 1 9 | SUM = 2 10 | COUNT = 3 11 | 12 | 13 | class AverageMeter(object): 14 | """computes and stores the average and current value""" 15 | 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.val = 0 21 | self.avg = 0 22 | self.sum = 0 23 | self.count = 0 24 | 25 | def update(self, val, n=1): 26 | self.val = val 27 | self.sum += val * n 28 | self.count += n 29 | self.avg = self.sum / self.count 30 | 31 | 32 | def _get_batch_fmtstr(num_batches): 33 | num_digits = len(str(num_batches // 1)) 34 | fmt = '{:' + str(num_digits) + 'd}' 35 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 36 | 37 | 38 | class ProgressMeter(object): 39 | """ 40 | Customized progress meter 41 | Ref: https://github.com/pytorch/examples/blob/main/imagenet/main.py 42 | """ 43 | 44 | def __init__(self, num_batches, meters, prefix=""): 45 | self.batch_fmtstr = _get_batch_fmtstr(num_batches) 46 | self.meters = meters 47 | self.prefix = prefix 48 | 49 | def display(self, batch): 50 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 51 | entries += [str(meter) for meter in self.meters] 52 | print('\t'.join(entries)) 53 | 54 | def display_summary(self): 55 | entries = [" *"] 56 | entries += [meter.summary() for meter in self.meters] 57 | print(' '.join(entries)) 58 | 59 | 60 | def load_state_dict_pdisco(snapshot_data): 61 | 
"""Load state dict of a snapshot. 62 | 63 | Args: 64 | snapshot_data (dict): dictionary containing the state dict of a snapshot 65 | """ 66 | snapshot = Snapshot(**snapshot_data) 67 | state_dict = snapshot.model_state 68 | return snapshot, state_dict 69 | 70 | 71 | def change_key(ordered_dict_obj, old, new): 72 | for _ in range(len(ordered_dict_obj)): 73 | k, v = ordered_dict_obj.popitem(False) 74 | ordered_dict_obj[new if old == k else k] = v 75 | 76 | 77 | def accuracy(output, target, topk=(1,)): 78 | """Computes the accuracy over the k top predictions for the specified values of k""" 79 | with torch.no_grad(): 80 | maxk = max(topk) 81 | batch_size = target.size(0) 82 | 83 | _, pred = output.topk(maxk, 1, True, True) 84 | pred = pred.t() 85 | correct = pred.eq(target.contiguous().view(1, -1).expand_as(pred)) 86 | 87 | res = [] 88 | for k in topk: 89 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 90 | res.append(correct_k.mul_(100.0 / batch_size)) 91 | return res 92 | -------------------------------------------------------------------------------- /layers/independent_mlp.py: -------------------------------------------------------------------------------- 1 | # This file contains the implementation of the IndependentMLPs class 2 | import torch 3 | 4 | 5 | class IndependentMLPs(torch.nn.Module): 6 | """ 7 | This class implements the MLP used for classification with the option to use an additional independent MLP layer 8 | """ 9 | 10 | def __init__(self, part_dim, latent_dim, bias=False, num_lin_layers=1, act_layer=True, out_dim=None, stack_dim=-1): 11 | """ 12 | 13 | :param part_dim: Number of parts 14 | :param latent_dim: Latent dimension 15 | :param bias: Whether to use bias 16 | :param num_lin_layers: Number of linear layers 17 | :param act_layer: Whether to use activation layer 18 | :param out_dim: Output dimension (default: None) 19 | :param stack_dim: Dimension to stack the outputs (default: -1) 20 | """ 21 | 22 | super().__init__() 23 | 24 | self.bias = bias 25 | self.latent_dim = latent_dim 26 | if out_dim is None: 27 | out_dim = latent_dim 28 | self.out_dim = out_dim 29 | self.part_dim = part_dim 30 | self.stack_dim = stack_dim 31 | 32 | layer_stack = torch.nn.ModuleList() 33 | for i in range(part_dim): 34 | layer_stack.append(torch.nn.Sequential()) 35 | for j in range(num_lin_layers): 36 | layer_stack[i].add_module(f"fc_{j}", torch.nn.Linear(latent_dim, self.out_dim, bias=bias)) 37 | if act_layer: 38 | layer_stack[i].add_module(f"act_{j}", torch.nn.GELU()) 39 | self.feature_layers = layer_stack 40 | self.reset_weights() 41 | 42 | def __repr__(self): 43 | return f"IndependentMLPs(part_dim={self.part_dim}, latent_dim={self.latent_dim}), bias={self.bias}" 44 | 45 | def reset_weights(self): 46 | """ Initialize weights with a identity matrix""" 47 | for layer in self.feature_layers: 48 | for m in layer.modules(): 49 | if isinstance(m, torch.nn.Linear): 50 | # Initialize weights with a truncated normal distribution 51 | torch.nn.init.trunc_normal_(m.weight, std=0.02) 52 | if m.bias is not None: 53 | torch.nn.init.zeros_(m.bias) 54 | 55 | def forward(self, x): 56 | """ Input X has the dimensions batch x latent_dim x part_dim """ 57 | 58 | outputs = [] 59 | for i, layer in enumerate(self.feature_layers): 60 | if self.stack_dim == -1: 61 | in_ = x[..., i] 62 | else: 63 | in_ = x[:, i, ...] 
# Select feature i
64 |             out = layer(in_)  # Apply MLP to feature i
65 |             outputs.append(out)
66 | 
67 |         x = torch.stack(outputs, dim=self.stack_dim)  # Stack the outputs
68 | 
69 |         return x
70 | 
--------------------------------------------------------------------------------
/engine/eval_fg_bg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics.classification import BinaryJaccardIndex
3 | from tqdm import tqdm
4 | import matplotlib.pyplot as plt
5 | import os
6 | 
7 | 
8 | def plot_iou_figs(iou_values, iou_values_bg, model_path):
9 |     plt.figure()
10 |     plt.plot(iou_values)
11 |     plt.xlabel('Batch')
12 |     plt.ylabel('IoU')
13 |     plt.title('Foreground IoU')
14 |     plt.ylim(0, 1)
15 |     plt.grid()
16 |     plt.savefig(os.path.join(os.path.dirname(model_path), 'fg_iou.png'), bbox_inches='tight')
17 |     plt.figure()
18 |     plt.plot(iou_values_bg)
19 |     plt.xlabel('Batch')
20 |     plt.ylabel('IoU')
21 |     plt.title('Background IoU')
22 |     plt.ylim(0, 1)
23 |     plt.grid()
24 |     plt.savefig(os.path.join(os.path.dirname(model_path), 'bg_iou.png'), bbox_inches='tight')
25 |     plt.close()
26 | 
27 | 
28 | class FgBgIoU:
29 |     """
30 |     Class to calculate the IoU for the foreground and background classes
31 |     """
32 | 
33 |     def __init__(self, model, data_loader, device):
34 |         """
35 |         Initialize the class
36 |         :param device: Device to run the evaluation on
37 |         :param model: Model that predicts the part assignment maps
38 |         """
39 |         self.metric_fg = BinaryJaccardIndex().to(device, non_blocking=True)
40 |         self.metric_bg = BinaryJaccardIndex().to(device, non_blocking=True)
41 |         self.model = model
42 |         self.device = device
43 |         self.num_parts = model.num_landmarks
44 |         self.data_loader = data_loader
45 | 
46 |     def calculate_iou(self, model_path):
47 |         """
48 |         Calculate the foreground and background IoU over the data loader and save the per-batch IoU plots
49 |         :param model_path: Path to the model checkpoint; the plots are written to its parent directory
50 |         """
51 |         iou_values = []
52 |         iou_values_bg = []
53 |         self.metric_fg.reset()
54 |         self.metric_bg.reset()
55 |         self.model.eval()
56 |         for (img, _, mask) in tqdm(self.data_loader, desc='Testing'):
57 |             img = img.to(self.device, non_blocking=True)
58 |             mask = mask.to(self.device, non_blocking=True)
59 |             with torch.inference_mode():
60 |                 assign = self.model(img)[1]
61 | 
62 |             map_argmax = torch.nn.functional.interpolate(assign, size=(mask.shape[-2], mask.shape[-1]),
63 |                                                          mode='bilinear',
64 |                                                          align_corners=True).argmax(dim=1)
65 | 
66 |             map_argmax[map_argmax != self.num_parts] = 1
67 |             map_argmax[map_argmax == self.num_parts] = 0
68 |             mask = mask.float()
69 |             map_argmax = map_argmax.float()
70 |             inv_mask = 1 - mask
71 |             inv_map_argmax = 1 - map_argmax
72 |             iou = self.metric_fg(map_argmax, mask)
73 |             iou_values.append(iou.item())
74 |             iou_bg = self.metric_bg(inv_map_argmax, inv_mask)
75 |             iou_values_bg.append(iou_bg.item())
76 |         plot_iou_figs(iou_values, iou_values_bg, model_path)
77 | 
78 | 
--------------------------------------------------------------------------------
/evaluation_instructions.md:
--------------------------------------------------------------------------------
1 | # Evaluation Instructions
2 | - We recommend evaluating on one GPU. The code technically runs on multiple GPUs as well, but we have not implemented the final averaging of the evaluation metrics across GPUs.
3 | - Additionally, we observe that it is best to use a batch size that is a factor of (i.e. one that evenly divides) the total number of examples in the test set, so that no incomplete batch is created.
Here are the dataset sizes:
4 |   - CUB: 5794
5 |   - Oxford Flowers: 6149
6 |   - PartImageNet OOD: 1658
7 |   - PartImageNet Seg: 2405
8 |   - NABirds: 24633
9 | - We provide a function called [factors](utils/misc_utils.py) which can be used to find the factors of these dataset sizes (see the short example at the end of this document).
10 | 
11 | ## Classification
12 | - For classification evaluation, simply adapt the command from the [training instructions](training_instructions.md) by adding the `--eval_only` flag.
13 | - The command should look like this:
14 | ```
15 | python train_net.py \
16 | --eval_only \
17 | --snapshot_dir <path to the model checkpoint directory> \
18 | --dataset <dataset name> \
19 | <other arguments used during training>
20 | ```
21 | - There is no need to specify the `--wandb` flag for evaluation. All the metrics will be printed to the console.
22 | 
23 | ## Part Discovery
24 | - For part discovery evaluation, use the following command:
25 | ```
26 | python evaluate_parts.py \
27 | --model_path <path to the model checkpoint> \
28 | --dataset <dataset name> \
29 | --center_crop \
30 | --eval_mode <evaluation mode> \
31 | --num_parts <number of foreground parts> \
32 | <other dataset-specific arguments>
33 | 
34 | ```
35 | ### Specific Arguments
36 | - `--eval_mode`: There are 3 options: `nmi_ari`, `keypoint`, `fg_bg`.
37 |   - `nmi_ari`: This mode evaluates the model's part discovery performance using the Normalized Mutual Information (NMI) and Adjusted Rand Index (ARI) metrics. This mode is used for the CUB, NABirds, PartImageNet OOD and PartImageNet Seg datasets.
38 |   - `keypoint`: This mode evaluates the model's part discovery performance using keypoint detection metrics. This mode is used for the CUB and NABirds datasets.
39 |   - `fg_bg`: This mode evaluates the model's part discovery performance using foreground-background segmentation metrics. This mode is used only for the Oxford Flowers dataset.
40 | - `--num_parts`: The number of foreground parts predicted by the model. This is the same value that was used during training.
41 | - `--center_crop`: This flag is necessary for evaluation on Vision Transformers. It crops the center of the image to the required size before evaluation, which is necessary because the Vision Transformer model requires a fixed input size. Additionally, if you want to evaluate with a batch size > 1, you need to use the `--center_crop` flag.
42 | - `--model_path`: The path to the model checkpoint.
43 | - `--dataset`: The name of the dataset. This is used to load the dataset and the corresponding evaluation metrics. The options are: `cub`, `part_imagenet` and `flowers102seg`. Note: For NABirds, use `cub` as the dataset name; as the dataset is similar to CUB, the evaluation metrics and dataset loading functions are the same.
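
### Example: choosing a valid evaluation batch size

The snippet below illustrates the idea behind the `factors` utility referenced above. It is a minimal sketch: the helper name `candidate_batch_sizes` is hypothetical, and the actual `factors` implementation in `utils/misc_utils.py` may differ in name, signature and efficiency. It simply lists every batch size that evenly divides a given test-set size, so that evaluation never produces an incomplete batch.

```
def candidate_batch_sizes(dataset_size: int) -> list:
    """Return every batch size that divides the test set evenly (i.e. the factors of its size)."""
    return [d for d in range(1, dataset_size + 1) if dataset_size % d == 0]


# Example: the CUB test set has 5794 images.
print(candidate_batch_sizes(5794))  # [1, 2, 2897, 5794]
```

For CUB this leaves few practical choices (1 or 2), whereas for Oxford Flowers (6149 = 11 × 13 × 43) it also yields batch sizes such as 143 or 559.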
44 | 45 | -------------------------------------------------------------------------------- /prepare_partimagenet_ood.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import copy 3 | import os 4 | import argparse 5 | from utils.data_utils.dataset_utils import load_json, save_json 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Prepare PartImagenet OOD dataset' 11 | ) 12 | parser.add_argument('--anno_path', type=str, required=True) 13 | parser.add_argument('--train_test_split_file', type=str, required=True) 14 | parser.add_argument('--output_path', type=str, required=True) 15 | return parser.parse_args() 16 | 17 | 18 | def prepare_pimagenet_ood(args): 19 | coco_json_path = args.anno_path 20 | train_test_split_path = args.train_test_split_file 21 | data = load_json(coco_json_path) 22 | 23 | columns = ["image_id", "is_test", "label_id", "label_name", "image_name"] 24 | 25 | train_test_csv = pd.read_csv(train_test_split_path, sep="\t", header=None, names=columns) 26 | image_id_to_test = {} 27 | for index, row in train_test_csv.iterrows(): 28 | image_id_to_test[row["image_id"]] = row["is_test"] 29 | 30 | train_data = [] 31 | test_data = [] 32 | for image in data["images"]: 33 | label_name = image["file_name"].split("_")[0] 34 | image["file_name"] = os.path.join(label_name, image["file_name"]) 35 | if image_id_to_test[image["id"]]: 36 | test_data.append(image) 37 | else: 38 | train_data.append(image) 39 | 40 | train_annotations = [] 41 | test_annotations = [] 42 | for ann in data["annotations"]: 43 | if image_id_to_test[ann["image_id"]]: 44 | test_annotations.append(ann) 45 | else: 46 | train_annotations.append(ann) 47 | 48 | # Now adjust the image ids in the annotations 49 | train_img_count = 0 50 | original_img_id_to_new_img_id = {} 51 | for image in train_data: 52 | original_img_id = image["id"] 53 | image["id"] = train_img_count 54 | original_img_id_to_new_img_id[original_img_id] = train_img_count 55 | train_img_count += 1 56 | 57 | for ann in train_annotations: 58 | ann["image_id"] = original_img_id_to_new_img_id[ann["image_id"]] 59 | 60 | test_img_count = 0 61 | original_img_id_to_new_img_id_test = {} 62 | for image in test_data: 63 | original_img_id = image["id"] 64 | image["id"] = test_img_count 65 | original_img_id_to_new_img_id_test[original_img_id] = test_img_count 66 | test_img_count += 1 67 | 68 | for ann in test_annotations: 69 | ann["image_id"] = original_img_id_to_new_img_id_test[ann["image_id"]] 70 | 71 | # Save the new json files 72 | train_json = copy.deepcopy(data) 73 | test_json = copy.deepcopy(data) 74 | train_json["images"] = train_data 75 | train_json["annotations"] = train_annotations 76 | test_json["images"] = test_data 77 | test_json["annotations"] = test_annotations 78 | save_json(train_json, os.path.join(args.output_path, "train_train.json")) 79 | save_json(test_json, os.path.join(args.output_path, "train_test.json")) 80 | 81 | 82 | if __name__ == '__main__': 83 | arguments = parse_args() 84 | prepare_pimagenet_ood(arguments) 85 | -------------------------------------------------------------------------------- /utils/data_utils/reversible_affine_transform.py: -------------------------------------------------------------------------------- 1 | # Description: This file contains the code for the reversible affine transform 2 | import torchvision.transforms as transforms 3 | import torch 4 | from typing import List, Optional, Tuple, Any 5 | 6 | 7 | def 
generate_affine_trans_params( 8 | degrees: List[float], 9 | translate: Optional[List[float]], 10 | scale_ranges: Optional[List[float]], 11 | shears: Optional[List[float]], 12 | img_size: List[int], 13 | ) -> Tuple[float, Tuple[int, int], float, Any]: 14 | """Get parameters for affine transformation 15 | 16 | Returns: 17 | params to be passed to the affine transformation 18 | """ 19 | angle = float(torch.empty(1).uniform_(float(degrees[0]), float(degrees[1])).item()) 20 | if translate is not None: 21 | max_dx = float(translate[0] * img_size[0]) 22 | max_dy = float(translate[1] * img_size[1]) 23 | tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) 24 | ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) 25 | translations = (tx, ty) 26 | else: 27 | translations = (0, 0) 28 | 29 | if scale_ranges is not None: 30 | scale = float(torch.empty(1).uniform_(scale_ranges[0], scale_ranges[1]).item()) 31 | else: 32 | scale = 1.0 33 | 34 | shear_x = shear_y = 0.0 35 | if shears is not None: 36 | shear_x = float(torch.empty(1).uniform_(shears[0], shears[1]).item()) 37 | if len(shears) == 4: 38 | shear_y = float(torch.empty(1).uniform_(shears[2], shears[3]).item()) 39 | 40 | shear = (shear_x, shear_y) 41 | if shear_x == 0.0 and shear_y == 0.0: 42 | shear = 0.0 43 | 44 | return angle, translations, scale, shear 45 | 46 | 47 | def rigid_transform(img, angle, translate, scale, invert=False, shear=0, 48 | interpolation=transforms.InterpolationMode.BILINEAR): 49 | """ 50 | Affine transforms input image 51 | Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/lib.py#L54 52 | Parameters 53 | ---------- 54 | img: Tensor 55 | Input image 56 | angle: int 57 | Rotation angle between -180 and 180 degrees 58 | translate: [int] 59 | Sequence of horizontal/vertical translations 60 | scale: float 61 | How to scale the image 62 | invert: bool 63 | Whether to invert the transformation 64 | shear: float 65 | Shear angle in degrees 66 | interpolation: InterpolationMode 67 | Interpolation mode to calculate output values 68 | Returns 69 | ---------- 70 | img: Tensor 71 | Transformed image 72 | 73 | """ 74 | if not invert: 75 | img = transforms.functional.affine(img, angle=angle, translate=translate, scale=scale, shear=shear, 76 | interpolation=interpolation) 77 | else: 78 | translate = [-t for t in translate] 79 | img = transforms.functional.affine(img=img, angle=0, translate=translate, scale=1, shear=shear) 80 | img = transforms.functional.affine(img=img, angle=-angle, translate=[0, 0], scale=1 / scale, shear=shear) 81 | 82 | return img 83 | -------------------------------------------------------------------------------- /engine/losses/enforced_presence_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class EnforcedPresenceLoss(torch.nn.Module): 5 | """ 6 | This class defines the Enforced Presence loss. 7 | """ 8 | 9 | def __init__(self, loss_type: str = "log", eps: float = 1e-10): 10 | super(EnforcedPresenceLoss, self).__init__() 11 | self.loss_type = loss_type 12 | self.eps = eps 13 | self.grid_x = None 14 | self.grid_y = None 15 | self.mask = None 16 | 17 | def forward(self, maps): 18 | """ 19 | Forward function for the Enforced Presence loss. 
20 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 21 | :return: The Enforced Presence loss 22 | """ 23 | if self.loss_type == "enforced_presence": 24 | avg_pooled_maps = torch.nn.functional.avg_pool2d( 25 | maps, 3, stride=1) 26 | if self.grid_x is None or self.grid_y is None: 27 | grid_x, grid_y = torch.meshgrid(torch.arange(avg_pooled_maps.shape[2]), 28 | torch.arange(avg_pooled_maps.shape[3]), indexing='ij') 29 | grid_x = grid_x.unsqueeze(0).unsqueeze(0).contiguous().to(avg_pooled_maps.device, 30 | non_blocking=True) 31 | grid_y = grid_y.unsqueeze(0).unsqueeze(0).contiguous().to(avg_pooled_maps.device, 32 | non_blocking=True) 33 | grid_x = (grid_x / grid_x.max()) * 2 - 1 34 | grid_y = (grid_y / grid_y.max()) * 2 - 1 35 | 36 | mask = grid_x ** 2 + grid_y ** 2 37 | mask = mask / mask.max() 38 | self.grid_x = grid_x 39 | self.grid_y = grid_y 40 | self.mask = mask 41 | 42 | masked_part_activation = avg_pooled_maps * self.mask 43 | masked_bg_part_activation = masked_part_activation[:, -1, :, :] 44 | 45 | max_pooled_maps = torch.nn.functional.adaptive_max_pool2d(masked_bg_part_activation, 1).flatten(start_dim=0) 46 | # Turn off AMP for this line 47 | with torch.amp.autocast(device_type='cuda', enabled=False): 48 | loss_area = torch.nn.functional.binary_cross_entropy(max_pooled_maps, torch.ones_like(max_pooled_maps)) 49 | else: 50 | part_activation_sums = torch.nn.functional.adaptive_avg_pool2d(maps, 1).flatten(start_dim=1) 51 | background_part_activation = part_activation_sums[:, -1] 52 | if self.loss_type == "log": 53 | with torch.amp.autocast(device_type='cuda', enabled=False): 54 | loss_area = torch.nn.functional.binary_cross_entropy(background_part_activation, 55 | torch.ones_like(background_part_activation)) 56 | 57 | elif self.loss_type == "linear": 58 | loss_area = (1 - background_part_activation).mean() 59 | 60 | elif self.loss_type == "mse": 61 | loss_area = torch.nn.functional.mse_loss(background_part_activation, 62 | torch.ones_like(background_part_activation)) 63 | else: 64 | raise ValueError(f"Invalid loss type: {self.loss_type}") 65 | 66 | return loss_area 67 | -------------------------------------------------------------------------------- /data_sets/plantnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import glob 4 | from collections import defaultdict 5 | from utils.data_utils.dataset_utils import pil_loader, load_json 6 | 7 | 8 | class PlantNet(torch.utils.data.Dataset): 9 | """ 10 | Class to train models on PlantNet300K 11 | Variables 12 | base_folder, str: Root directory of the dataset. 13 | image_sub_path, str: Path to the folder containing the images. 14 | transform, callable: A function/transform that takes in a PIL.Image and transforms it. 
15 | 
16 |     """
17 | 
18 |     def __init__(self, base_folder, image_sub_path, transform=None, metadata_path=None, species_id_to_name_file=None):
19 |         self.images_folder = os.path.join(base_folder, image_sub_path)
20 | 
21 |         self.transform = transform
22 | 
23 |         self.loader = pil_loader
24 | 
25 |         self.metadata = load_json(metadata_path)
26 |         self.species_id_to_name = load_json(species_id_to_name_file)
27 | 
28 |         self.image_paths = glob.glob(os.path.join(self.images_folder, "**/*.jpg"), recursive=True)
29 |         self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.jpeg"), recursive=True)
30 |         self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.png"), recursive=True)
31 |         self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.bmp"), recursive=True)
32 |         self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.ppm"), recursive=True)
33 |         self.image_paths += glob.glob(os.path.join(self.images_folder, "**/*.JPEG"), recursive=True)
34 | 
35 |         self.image_paths = sorted(self.image_paths)
36 | 
37 |         self.class_names = self.species_id_to_name.keys()
38 |         self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.class_names)}
39 | 
40 |         self.labels = [self.class_to_idx[os.path.basename(os.path.dirname(image_path))] for image_path in
41 |                        self.image_paths]
42 |         self.idx_to_class = {idx: class_name for class_name, idx in self.class_to_idx.items()}
43 |         self.num_classes = len(self.class_names)
44 |         self.idx_to_species_name = {idx: self.species_id_to_name[self.idx_to_class[idx]] for idx in
45 |                                     range(self.num_classes)}
46 |         self.per_class_count = defaultdict(int)
47 |         self.class_to_img_ids = defaultdict(list)
48 |         for idx, label in enumerate(self.labels):
49 |             self.per_class_count[self.idx_to_class[label]] += 1
50 |             self.class_to_img_ids[self.idx_to_class[label]].append(idx)
51 |         # For top-K loss (class distribution)
52 |         self.cls_num_list = [self.per_class_count[self.idx_to_class[idx]] for idx in range(self.num_classes)]
53 | 
54 |     def __len__(self):
55 |         return len(self.labels)
56 | 
57 |     def __getitem__(self, idx):
58 |         image_path = self.image_paths[idx]
59 |         image = self.loader(image_path)
60 |         label = self.labels[idx]
61 |         image_name = image_path.split("/")[-1].split(".")[0]
62 |         metadata = self.metadata[image_name]
63 |         if self.transform is not None:
64 |             image = self.transform(image)
65 |         return image, label, metadata
66 | 
67 |     def generate_class_balanced_indices(self, generator: torch.Generator, num_samples_per_class=10):
68 |         indices = []
69 |         for class_name, img_ids in self.class_to_img_ids.items():
70 |             # Randomly sample up to num_samples_per_class dataset indices from each class
71 |             shuffled_positions = torch.randperm(len(img_ids), generator=generator).tolist()
72 |             if len(img_ids) > num_samples_per_class:
73 |                 shuffled_positions = shuffled_positions[:num_samples_per_class]
74 |             indices.extend([img_ids[pos] for pos in shuffled_positions])
75 |         return indices
76 | 
77 | 
78 | 
--------------------------------------------------------------------------------
/engine/losses/total_variation.py:
--------------------------------------------------------------------------------
1 | # Copyright The Lightning team.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Optional, Tuple, Union 15 | import torch 16 | from torch import Tensor 17 | from typing_extensions import Literal 18 | 19 | 20 | def _total_variation_update(img: Tensor) -> Tuple[Tensor, int]: 21 | """Compute total variation statistics on current batch.""" 22 | if img.ndim != 4: 23 | raise RuntimeError(f"Expected input `img` to be an 4D tensor, but got {img.shape}") 24 | diff1 = img[..., 1:, :] - img[..., :-1, :] 25 | diff2 = img[..., :, 1:] - img[..., :, :-1] 26 | 27 | res1 = diff1.abs().sum([1, 2, 3]) 28 | res2 = diff2.abs().sum([1, 2, 3]) 29 | score = res1 + res2 30 | return score 31 | 32 | 33 | def _total_variation_compute( 34 | score: Tensor, num_elements: Union[int, Tensor], reduction: Optional[Literal["mean", "sum", "none"]] 35 | ) -> Tensor: 36 | """Compute final total variation score.""" 37 | if reduction == "mean": 38 | return score.sum() / num_elements 39 | if reduction == "sum": 40 | return score.sum() 41 | if reduction is None or reduction == "none": 42 | return score 43 | raise ValueError("Expected argument `reduction` to either be 'sum', 'mean', 'none' or None") 44 | 45 | 46 | def total_variation(img: Tensor, reduction: Optional[Literal["mean", "sum", "none"]] = "sum", num_elements: int = 0) -> Tensor: 47 | """Compute total variation loss. 48 | 49 | Args: 50 | img: A `Tensor` of shape `(N, C, H, W)` consisting of images 51 | reduction: a method to reduce metric score over samples. 52 | 53 | - ``'mean'``: takes the mean over samples 54 | - ``'sum'``: takes the sum over samples 55 | - ``None`` or ``'none'``: return the score per sample 56 | num_elements: The number of elements in the input tensor 57 | 58 | Returns: 59 | A loss scalar value containing the total variation 60 | 61 | Raises: 62 | ValueError: 63 | If ``reduction`` is not one of ``'sum'``, ``'mean'``, ``'none'`` or ``None`` 64 | RuntimeError: 65 | If ``img`` is not 4D tensor 66 | 67 | """ 68 | # code adapted from: 69 | # from kornia.losses import total_variation as kornia_total_variation 70 | score = _total_variation_update(img) 71 | return _total_variation_compute(score, num_elements, reduction) 72 | 73 | 74 | class TotalVariationLoss(torch.nn.Module): 75 | """ 76 | Compute total variation loss. 77 | args: 78 | reduction: a method to reduce metric score over samples. 
79 | 80 | - ``'mean'``: takes the mean over samples 81 | - ``'sum'``: takes the sum over samples 82 | - ``None`` or ``'none'``: return the score per sample 83 | """ 84 | 85 | def __init__(self, reduction: Optional[Literal["mean", "sum", "none"]] = "mean") -> None: 86 | super(TotalVariationLoss, self).__init__() 87 | self.reduction = reduction 88 | self.num_elements = None 89 | 90 | def forward(self, img: Tensor) -> Tensor: 91 | if self.num_elements is None: 92 | self.num_elements = img.shape[0] * img.shape[2] * img.shape[3] 93 | return total_variation(img=img, reduction=self.reduction, num_elements=self.num_elements) 94 | 95 | -------------------------------------------------------------------------------- /engine/losses/presence_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def presence_loss_soft_constraint(maps: torch.Tensor, beta: float = 0.1): 5 | """ 6 | Calculate presence loss for a feature map 7 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 8 | :param beta: Weight of soft constraint 9 | :return: value of the presence loss 10 | """ 11 | loss_max = torch.nn.functional.adaptive_max_pool2d(torch.nn.functional.avg_pool2d( 12 | maps, 3, stride=1), 1).flatten(start_dim=1).max(dim=0)[0] 13 | loss_max_detach = loss_max.detach().clone() 14 | loss_max_p1 = 1 - loss_max 15 | loss_max_p2 = ((1 - beta) * loss_max_detach) + beta 16 | loss_max_final = (loss_max_p1 * loss_max_p2).mean() 17 | return loss_max_final 18 | 19 | 20 | def presence_loss_tanh(maps: torch.Tensor): 21 | """ 22 | Calculate presence loss for a feature map with tanh formulation from the paper PIP-NET 23 | Ref: https://github.com/M-Nauta/PIPNet/blob/68054822ee405b5f292369ca846a9c6233f2df69/pipnet/train.py#L111 24 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 25 | :return: 26 | """ 27 | pooled_maps = torch.tanh(torch.sum(torch.nn.functional.adaptive_max_pool2d(torch.nn.functional.avg_pool2d( 28 | maps, 3, stride=1), 1).flatten(start_dim=1), dim=0)) 29 | with torch.amp.autocast(device_type='cuda', enabled=False): 30 | loss_max = torch.nn.functional.binary_cross_entropy(pooled_maps, target=torch.ones_like(pooled_maps)) 31 | 32 | return loss_max 33 | 34 | 35 | def presence_loss_soft_tanh(maps: torch.Tensor): 36 | """ 37 | Calculate presence loss for a feature map with tanh formulation (non-log/softer version) 38 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 39 | :return: 40 | """ 41 | pooled_maps = torch.tanh(torch.sum(torch.nn.functional.adaptive_max_pool2d(torch.nn.functional.avg_pool2d( 42 | maps, 3, stride=1), 1).flatten(start_dim=1), dim=0)) 43 | 44 | loss_max = 1 - pooled_maps 45 | 46 | return loss_max.mean() 47 | 48 | 49 | def presence_loss_original(maps: torch.Tensor): 50 | """ 51 | Calculate presence loss for a feature map 52 | Modified from: https://github.com/robertdvdk/part_detection/blob/eec53f2f40602113f74c6c1f60a2034823b0fcaf/train.py#L181 53 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 54 | :return: value of the presence loss 55 | """ 56 | 57 | loss_max = torch.nn.functional.adaptive_max_pool2d(torch.nn.functional.avg_pool2d( 58 | maps, 3, stride=1), 1).flatten(start_dim=1).max(dim=0)[0].mean() 59 | 60 | return 1 - loss_max 61 | 62 | 63 | class 
PresenceLoss(torch.nn.Module): 64 | """ 65 | This class defines the presence loss. 66 | """ 67 | 68 | def __init__(self, loss_type: str = "original", beta: float = 0.1): 69 | super(PresenceLoss, self).__init__() 70 | self.loss_type = loss_type 71 | self.beta = beta 72 | 73 | def forward(self, maps): 74 | """ 75 | Forward function for the presence loss. 76 | :param maps: Attention map with shape (batch_size, channels, height, width) where channels is the landmark probability 77 | :return: The presence loss 78 | """ 79 | if self.loss_type == "original": 80 | return presence_loss_original(maps) 81 | elif self.loss_type == "soft_constraint": 82 | return presence_loss_soft_constraint(maps, beta=self.beta) 83 | elif self.loss_type == "tanh": 84 | return presence_loss_tanh(maps) 85 | elif self.loss_type == "soft_tanh": 86 | return presence_loss_soft_tanh(maps) 87 | else: 88 | raise NotImplementedError(f"Presence loss {self.loss_type} not implemented") 89 | -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timeit import default_timer as timer 3 | 4 | from argument_parser_train import parse_args 5 | from utils.data_utils.transform_utils import load_transforms 6 | from utils.training_utils.optimizer_params import build_optimizer, layer_group_matcher_pdisco 7 | from utils.training_utils.scheduler_params import build_scheduler 8 | from utils.misc_utils import sync_bn_conversion, check_snapshot 9 | from utils.training_utils.ddp_utils import multi_gpu_check 10 | from utils.wandb_params import get_train_loggers 11 | from engine.distributed_trainer_pdisco import launch_pdisco_trainer 12 | from load_dataset import get_dataset 13 | from load_model import load_model_pdisco 14 | from load_losses import load_classification_loss, load_loss_hyper_params 15 | 16 | torch.backends.cudnn.benchmark = True 17 | 18 | 19 | def pdisco_train_eval(): 20 | args = parse_args() 21 | 22 | train_loggers = get_train_loggers(args) 23 | 24 | # Create directory to save training checkpoints, otherwise load the existing checkpoint 25 | check_snapshot(args) 26 | 27 | # Get the transforms and load the dataset 28 | train_transforms, test_transforms = load_transforms(args) 29 | 30 | # Load the dataset 31 | dataset_train, dataset_test, num_cls = get_dataset(args, train_transforms, test_transforms) 32 | 33 | # Load the model 34 | model = load_model_pdisco(args, num_cls) 35 | 36 | # Check if there are multiple GPUs 37 | use_ddp = multi_gpu_check() 38 | # Convert BatchNorm to SyncBatchNorm if there is more than 1 GPU 39 | if use_ddp: 40 | model = sync_bn_conversion(model) 41 | 42 | # Load the loss function 43 | loss_fn, mixup_fn = load_classification_loss(args, num_cls) 44 | 45 | # Load the loss hyperparameters 46 | loss_hyperparams, eq_affine_transform_params = load_loss_hyper_params(args) 47 | 48 | # Define the optimizer and scheduler 49 | param_groups = layer_group_matcher_pdisco(args, model) 50 | optimizer = build_optimizer(args, param_groups, dataset_train) 51 | scheduler = build_scheduler(args, optimizer) 52 | 53 | # Load averaging parameters 54 | averaging_params = {'type': args.averaging_type, 'decay': args.model_ema_decay, 55 | 'use_warmup': not args.no_model_ema_warmup, 56 | 'device': 'cpu' if args.model_ema_force_cpu else None} 57 | 58 | # Start the timer 59 | start_time = timer() 60 | 61 | # Setup training and save the results 62 | launch_pdisco_trainer(model=model, 63 | 
train_dataset=dataset_train, 64 | test_dataset=dataset_test, 65 | batch_size=args.batch_size, 66 | optimizer=optimizer, 67 | scheduler=scheduler, 68 | loss_fn=loss_fn, 69 | epochs=args.epochs, 70 | save_every=args.save_every_n_epochs, 71 | loggers=train_loggers, 72 | log_freq=args.log_interval, 73 | use_amp=args.use_amp, 74 | snapshot_path=args.snapshot_dir, 75 | grad_norm_clip=args.grad_norm_clip, 76 | num_workers=args.num_workers, 77 | mixup_fn=mixup_fn, 78 | seed=args.seed, 79 | eval_only=args.eval_only, 80 | loss_hyperparams=loss_hyperparams, 81 | eq_affine_transform_params=eq_affine_transform_params, 82 | use_ddp=use_ddp, 83 | sub_path_test=args.image_sub_path_test, 84 | dataset_name=args.dataset, 85 | amap_saving_prob=args.amap_saving_prob, 86 | grad_accumulation_steps=args.grad_accumulation_steps, 87 | averaging_params=averaging_params, 88 | ) 89 | 90 | # End the timer and print out how long it took 91 | end_time = timer() 92 | print(f"[INFO] Total training time: {end_time - start_time:.3f} seconds") 93 | 94 | 95 | if __name__ == "__main__": 96 | pdisco_train_eval() 97 | -------------------------------------------------------------------------------- /load_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from timm.data.mixup import Mixup 3 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy 4 | 5 | 6 | def load_mixup_fn(args, num_classes): 7 | """Load the mixup function""" 8 | mixup_fn = None 9 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None 10 | if mixup_active: 11 | mixup_args = dict( 12 | mixup_alpha=args.mixup, 13 | cutmix_alpha=args.cutmix, 14 | cutmix_minmax=args.cutmix_minmax, 15 | prob=args.mixup_prob, 16 | switch_prob=args.mixup_switch_prob, 17 | mode=args.mixup_mode, 18 | label_smoothing=args.smoothing, 19 | num_classes=num_classes, 20 | ) 21 | mixup_fn = Mixup(**mixup_args) 22 | return mixup_fn 23 | 24 | 25 | def load_classification_loss(args, num_cls): 26 | """ 27 | Load the loss function for classification 28 | :param args: Arguments from the argument parser 29 | :param num_cls: Number of classes in the dataset 30 | :return: 31 | loss_fn: List of loss functions for training and evaluation 32 | """ 33 | # Mixup/Cutmix 34 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 35 | mixup_fn = load_mixup_fn(args, num_cls) 36 | # Set up loss function for training 37 | if mixup_active: 38 | if args.use_bce_loss: 39 | loss_fn_train = BinaryCrossEntropy( 40 | target_threshold=args.bce_target_thresh, 41 | sum_classes=args.bce_sum, 42 | pos_weight=args.bce_pos_weight, 43 | ) 44 | else: 45 | loss_fn_train = SoftTargetCrossEntropy() 46 | elif args.smoothing: 47 | if args.use_bce_loss: 48 | loss_fn_train = BinaryCrossEntropy( 49 | smoothing=args.smoothing, 50 | target_threshold=args.bce_target_thresh, 51 | sum_classes=args.bce_sum, 52 | pos_weight=args.bce_pos_weight, 53 | ) 54 | else: 55 | loss_fn_train = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 56 | else: 57 | if args.use_bce_loss: 58 | loss_fn_train = BinaryCrossEntropy( 59 | smoothing=0.0, 60 | target_threshold=args.bce_target_thresh, 61 | sum_classes=args.bce_sum, 62 | pos_weight=args.bce_pos_weight, 63 | ) 64 | else: 65 | loss_fn_train = torch.nn.CrossEntropyLoss() 66 | 67 | loss_fn_eval = torch.nn.CrossEntropyLoss() 68 | loss_fn = [loss_fn_train, loss_fn_eval] 69 | return loss_fn, mixup_fn 70 | 71 | 72 | def load_loss_hyper_params(args): 73 | """ 74 | Load the hyperparameters for the loss functions and affine transform parameters for equivariance 75 | :param args: Arguments from the argument parser 76 | :return: 77 | loss_hyperparams: Dictionary of loss hyperparameters 78 | eq_affine_transform_params: Dictionary of affine transform parameters for equivariance 79 | """ 80 | loss_hyperparams = {'l_class_att': args.classification_loss, 'l_presence': args.presence_loss, 81 | 'l_presence_beta': args.presence_loss_beta, 'l_presence_type': args.presence_loss_type, 82 | 'l_equiv': args.equivariance_loss, 'l_conc': args.concentration_loss, 83 | 'l_orth': args.orthogonality_loss_landmarks, 'l_tv': args.total_variation_loss, 84 | 'l_enforced_presence': args.enforced_presence_loss, 'l_pixel_wise_entropy': args.pixel_wise_entropy_loss, 85 | 'l_enforced_presence_loss_type': args.enforced_presence_loss_type} 86 | 87 | # Affine transform parameters for equivariance 88 | degrees = [-args.degrees, args.degrees] 89 | translate = [args.translate_x, args.translate_y] 90 | scale = [args.scale_l, args.scale_u] 91 | shear_x = args.shear_x 92 | shear_y = args.shear_y 93 | shear = [shear_x, shear_y] 94 | if shear_x == 0.0 and shear_y == 0.0: 95 | shear = None 96 | 97 | eq_affine_transform_params = {'degrees': degrees, 'translate': translate, 'scale_ranges': scale, 'shear': shear} 98 | 99 | return loss_hyperparams, eq_affine_transform_params 100 | -------------------------------------------------------------------------------- /utils/data_utils/class_balanced_distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | from typing import Optional 4 | import math 5 | import torch.distributed as dist 6 | 7 | 8 | class ClassBalancedDistributedSampler(torch.utils.data.Sampler): 9 | """ 10 | A custom sampler that sub-samples a given dataset based on class labels. 
Based on the DistributedSampler class 11 | Ref: https://github.com/pytorch/pytorch/blob/04c1df651aa58bea50977f4efcf19b09ce27cefd/torch/utils/data/distributed.py#L13 12 | """ 13 | 14 | def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None, rank: Optional[int] = None, 15 | shuffle: bool = True, seed: int = 0, drop_last: bool = False, num_samples_per_class=100) -> None: 16 | 17 | if not shuffle: 18 | raise ValueError("ClassBalancedDatasetSubSampler requires shuffling, otherwise use DistributedSampler") 19 | 20 | # Check if the dataset has a generate_class_balanced_indices method 21 | if not hasattr(dataset, 'generate_class_balanced_indices'): 22 | raise ValueError("Dataset does not have a generate_class_balanced_indices method") 23 | 24 | self.shuffle = shuffle 25 | self.seed = seed 26 | if num_replicas is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | num_replicas = dist.get_world_size() 30 | if rank is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available") 33 | rank = dist.get_rank() 34 | if rank >= num_replicas or rank < 0: 35 | raise ValueError( 36 | f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") 37 | self.dataset = dataset 38 | self.num_replicas = num_replicas 39 | self.rank = rank 40 | self.epoch = 0 41 | self.drop_last = drop_last 42 | 43 | # Calculate the number of samples 44 | g = torch.Generator() 45 | g.manual_seed(self.seed + self.epoch) 46 | self.num_samples_per_class = num_samples_per_class 47 | indices = dataset.generate_class_balanced_indices(torch.Generator(), 48 | num_samples_per_class=num_samples_per_class) 49 | dataset_size = len(indices) 50 | 51 | # If the dataset length is evenly divisible by # of replicas, then there 52 | # is no need to drop any data, since the dataset will be split equally. 53 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] 54 | # Split to nearest available length that is evenly divisible. 55 | # This is to ensure each rank receives the same amount of data when 56 | # using this Sampler. 57 | self.num_samples = math.ceil( 58 | (dataset_size - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 59 | ) 60 | else: 61 | self.num_samples = math.ceil(dataset_size / self.num_replicas) # type: ignore[arg-type] 62 | self.total_size = self.num_samples * self.num_replicas 63 | 64 | def __iter__(self): 65 | # deterministically shuffle based on epoch and seed, here shuffle is assumed to be True 66 | g = torch.Generator() 67 | g.manual_seed(self.seed + self.epoch) 68 | indices = self.dataset.generate_class_balanced_indices(g, num_samples_per_class=self.num_samples_per_class) 69 | 70 | if not self.drop_last: 71 | # add extra samples to make it evenly divisible 72 | padding_size = self.total_size - len(indices) 73 | if padding_size <= len(indices): 74 | indices += indices[:padding_size] 75 | else: 76 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 77 | else: 78 | # remove tail of data to make it evenly divisible. 79 | indices = indices[:self.total_size] 80 | 81 | # subsample 82 | indices = indices[self.rank:self.total_size:self.num_replicas] 83 | 84 | return iter(indices) 85 | 86 | def __len__(self) -> int: 87 | return self.num_samples 88 | 89 | def set_epoch(self, epoch: int) -> None: 90 | r""" 91 | Set the epoch for this sampler. 
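In a DDP training loop this is typically called once at the start of every epoch, before iterating the DataLoader, so that each epoch draws a fresh class-balanced shuffle; a minimal sketch (``sampler``, ``loader`` and ``num_epochs`` are placeholder names)::

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # re-seeds the class-balanced index generation for this epoch
        for batch in loader:
            ...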
92 | 93 | When :attr:`shuffle=True`, this ensures all replicas 94 | use a different random ordering for each epoch. Otherwise, the next iteration of this 95 | sampler will yield the same ordering. 96 | 97 | Args: 98 | epoch (int): Epoch number. 99 | """ 100 | self.epoch = epoch 101 | -------------------------------------------------------------------------------- /utils/data_utils/transform_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as transforms 3 | from torchvision.transforms import Compose 4 | 5 | from timm.data.constants import \ 6 | IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 7 | from timm.data import create_transform 8 | 9 | 10 | def make_train_transforms(args): 11 | train_transforms: Compose = transforms.Compose([ 12 | transforms.Resize(size=args.image_size, antialias=True), 13 | transforms.RandomHorizontalFlip(p=args.hflip), 14 | transforms.RandomVerticalFlip(p=args.vflip), 15 | transforms.ColorJitter(), 16 | transforms.RandomAffine(degrees=90, translate=(0.2, 0.2), scale=(0.8, 1.2)), 17 | transforms.RandomCrop(args.image_size), 18 | transforms.ToTensor(), 19 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD) 20 | 21 | ]) 22 | return train_transforms 23 | 24 | 25 | def make_test_transforms(args): 26 | test_transforms: Compose = transforms.Compose([ 27 | transforms.Resize(size=args.image_size, antialias=True), 28 | transforms.CenterCrop(args.image_size), 29 | transforms.ToTensor(), 30 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD) 31 | 32 | ]) 33 | return test_transforms 34 | 35 | 36 | def build_transform_timm(args, is_train=True): 37 | resize_im = args.image_size > 32 38 | imagenet_default_mean_and_std = args.imagenet_default_mean_and_std 39 | mean = IMAGENET_INCEPTION_MEAN if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_MEAN 40 | std = IMAGENET_INCEPTION_STD if not imagenet_default_mean_and_std else IMAGENET_DEFAULT_STD 41 | 42 | if is_train: 43 | # this should always dispatch to transforms_imagenet_train 44 | transform = create_transform( 45 | input_size=args.image_size, 46 | is_training=True, 47 | color_jitter=args.color_jitter, 48 | hflip=args.hflip, 49 | vflip=args.vflip, 50 | auto_augment=args.aa, 51 | interpolation=args.train_interpolation, 52 | re_prob=args.reprob, 53 | re_mode=args.remode, 54 | re_count=args.recount, 55 | mean=mean, 56 | std=std, 57 | ) 58 | if not resize_im: 59 | transform.transforms[0] = transforms.RandomCrop( 60 | args.image_size, padding=4) 61 | return transform 62 | 63 | t = [] 64 | if resize_im: 65 | # warping (no cropping) when evaluated at 384 or larger 66 | if args.image_size >= 384: 67 | t.append( 68 | transforms.Resize((args.image_size, args.image_size), 69 | interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), 70 | ) 71 | print(f"Warping {args.image_size} size input images...") 72 | else: 73 | if args.crop_pct is None: 74 | args.crop_pct = 224 / 256 75 | size = int(args.image_size / args.crop_pct) 76 | t.append( 77 | # to maintain same ratio w.r.t. 
224 images 78 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True), 79 | ) 80 | t.append(transforms.CenterCrop(args.image_size)) 81 | 82 | t.append(transforms.ToTensor()) 83 | t.append(transforms.Normalize(mean, std)) 84 | return transforms.Compose(t) 85 | 86 | 87 | def inverse_normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD): 88 | mean = torch.as_tensor(mean) 89 | std = torch.as_tensor(std) 90 | un_normalize = transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()) 91 | return un_normalize 92 | 93 | 94 | def normalize_only(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD): 95 | normalize = transforms.Normalize(mean=mean, std=std) 96 | return normalize 97 | 98 | 99 | def inverse_normalize_w_resize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, 100 | resize_resolution=(256, 256)): 101 | mean = torch.as_tensor(mean) 102 | std = torch.as_tensor(std) 103 | resize_unnorm = transforms.Compose([ 104 | transforms.Normalize((-mean / std).tolist(), (1.0 / std).tolist()), 105 | transforms.Resize(size=resize_resolution, antialias=True)]) 106 | return resize_unnorm 107 | 108 | 109 | def load_transforms(args): 110 | # Get the transforms and load the dataset 111 | if args.augmentations_to_use == 'timm': 112 | train_transforms = build_transform_timm(args, is_train=True) 113 | elif args.augmentations_to_use == 'cub_original': 114 | train_transforms = make_train_transforms(args) 115 | else: 116 | raise ValueError('Augmentations not supported.') 117 | test_transforms = make_test_transforms(args) 118 | return train_transforms, test_transforms 119 | -------------------------------------------------------------------------------- /utils/misc_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import reduce 3 | 4 | import torch 5 | import numpy as np 6 | import os 7 | from pathlib import Path 8 | 9 | 10 | def factors(n): 11 | return reduce(list.__add__, 12 | ([i, n // i] for i in range(1, int(n ** 0.5) + 1) if n % i == 0)) 13 | 14 | 15 | def file_line_count(filename: str) -> int: 16 | """Count the number of lines in a file""" 17 | with open(filename, 'rb') as f: 18 | return sum(1 for _ in f) 19 | 20 | 21 | def compute_attention(qkv, scale=None): 22 | """ 23 | Compute attention matrix (same as in the pytorch scaled dot product attention) 24 | Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html 25 | :param qkv: Query, key and value tensors concatenated along the first dimension 26 | :param scale: Scale factor for the attention computation 27 | :return: 28 | """ 29 | if isinstance(qkv, torch.Tensor): 30 | query, key, value = qkv.unbind(0) 31 | else: 32 | query, key, value = qkv 33 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 34 | L, S = query.size(-2), key.size(-2) 35 | attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) 36 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 37 | attn_weight += attn_bias 38 | attn_weight = torch.softmax(attn_weight, dim=-1) 39 | attn_out = attn_weight @ value 40 | return attn_weight, attn_out 41 | 42 | 43 | def compute_dot_product_similarity(a, b): 44 | scores = a @ b.transpose(-1, -2) 45 | return scores 46 | 47 | 48 | def compute_cross_entropy(p, q): 49 | q = torch.nn.functional.log_softmax(q, dim=-1) 50 | loss = torch.sum(p * q, dim=-1) 51 | return - loss.mean() 52 | 53 | 54 | def rollout(attentions, discard_ratio=0.9, 
head_fusion="max", device=torch.device("cuda")): 55 | """ 56 | Perform attention rollout, 57 | Ref: https://github.com/jacobgil/vit-explain/blob/15a81d355a5aa6128ea4e71bbd56c28888d0f33b/vit_rollout.py#L9C1-L42C16 58 | Parameters 59 | ---------- 60 | attentions : list 61 | List of attention matrices, one for each transformer layer 62 | discard_ratio : float 63 | Ratio of lowest attention values to discard 64 | head_fusion : str 65 | Type of fusion to use for attention heads. One of "mean", "max", "min" 66 | device : torch.device 67 | Device to use for computation 68 | Returns 69 | ------- 70 | mask : np.ndarray 71 | Mask of shape (width, width), where width is the square root of the number of patches 72 | """ 73 | result = torch.eye(attentions[0].size(-1), device=device) 74 | attentions = [attention.to(device) for attention in attentions] 75 | with torch.no_grad(): 76 | for attention in attentions: 77 | if head_fusion == "mean": 78 | attention_heads_fused = attention.mean(axis=1) 79 | elif head_fusion == "max": 80 | attention_heads_fused = attention.max(axis=1).values 81 | elif head_fusion == "min": 82 | attention_heads_fused = attention.min(axis=1).values 83 | else: 84 | raise "Attention head fusion type Not supported" 85 | 86 | # Drop the lowest attentions, but 87 | # don't drop the class token 88 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1) 89 | _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False) 90 | indices = indices[indices != 0] 91 | flat[0, indices] = 0 92 | 93 | I = torch.eye(attention_heads_fused.size(-1), device=device) 94 | a = (attention_heads_fused + 1.0 * I) / 2 95 | a = a / a.sum(dim=-1) 96 | 97 | result = torch.matmul(a, result) 98 | 99 | # Normalize the result by max value in each row 100 | result = result / result.max(dim=-1, keepdim=True)[0] 101 | return result 102 | 103 | 104 | def sync_bn_conversion(model: torch.nn.Module): 105 | """ 106 | Convert BatchNorm to SyncBatchNorm (used for DDP) 107 | :param model: PyTorch model 108 | :return: 109 | model: PyTorch model with SyncBatchNorm layers 110 | """ 111 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 112 | return model 113 | 114 | 115 | def check_snapshot(args): 116 | """ 117 | Create directory to save training checkpoints, otherwise load the existing checkpoint. 118 | Additionally, if it is an array training job, create a new directory for each training job. 119 | :param args: Arguments from the argument parser 120 | :return: 121 | """ 122 | # Check if it is an array training job (i.e. 
training with multiple random seeds on the same settings) 123 | if args.array_training_job and not args.resume_training: 124 | args.snapshot_dir = os.path.join(args.snapshot_dir, str(args.seed)) 125 | if not os.path.exists(args.snapshot_dir): 126 | save_dir = Path(args.snapshot_dir) 127 | save_dir.mkdir(parents=True, exist_ok=True) 128 | else: 129 | # Create directory to save training checkpoints, otherwise load the existing checkpoint 130 | if not os.path.exists(args.snapshot_dir): 131 | if ".pt" not in args.snapshot_dir or ".pth" not in args.snapshot_dir: 132 | save_dir = Path(args.snapshot_dir) 133 | save_dir.mkdir(parents=True, exist_ok=True) 134 | else: 135 | raise ValueError('Snapshot checkpoint does not exist.') 136 | -------------------------------------------------------------------------------- /inference_benchmark_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torchvision.datasets as datasets 4 | from data_sets import FineGrainedBirdClassificationDataset, PartImageNetDataset 5 | from load_model import load_model_pdisco 6 | import argparse 7 | from tqdm import tqdm 8 | import copy 9 | from utils.training_utils.engine_utils import load_state_dict_pdisco 10 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 11 | 12 | # fix all the randomness for reproducibility 13 | torch.backends.cudnn.enabled = True 14 | torch.manual_seed(0) 15 | torch.cuda.manual_seed(0) 16 | 17 | torch.set_float32_matmul_precision('high') 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser(description='Inference benchmark models') 22 | parser.add_argument('--model_arch', default='resnet50', type=str, 23 | help='pick model architecture') 24 | parser.add_argument('--use_torchvision_resnet_model', default=False, action='store_true') 25 | 26 | # Data 27 | parser.add_argument('--data_path', 28 | help='directory that contains cub files, must' 29 | 'contain folder "./images"', required=True) 30 | parser.add_argument('--image_sub_path', default='images', type=str, required=False) 31 | parser.add_argument('--dataset', default='cub', type=str) 32 | parser.add_argument('--anno_path_test', default='', type=str, required=False) 33 | # Model params 34 | parser.add_argument('--num_parts', help='number of parts to predict', 35 | default=8, type=int) 36 | parser.add_argument('--image_size', default=448, type=int) 37 | parser.add_argument('--output_stride', default=32, type=int) 38 | parser.add_argument('--batch_size', default=1, type=int) 39 | parser.add_argument('--num_workers', default=8, type=int) 40 | # Modulation 41 | parser.add_argument('--modulation_type', default="original", 42 | choices=["original", "layer_norm", "parallel_mlp", "parallel_mlp_no_bias", 43 | "parallel_mlp_no_act", "parallel_mlp_no_act_no_bias", "none"], 44 | type=str) 45 | parser.add_argument('--modulation_orth', default=False, action='store_true', 46 | help='use orthogonality loss on modulated features') 47 | # Part Dropout 48 | parser.add_argument('--part_dropout', default=0.0, type=float) 49 | 50 | # Add noise to vit output features 51 | parser.add_argument('--noise_variance', default=0.0, type=float) 52 | 53 | # Gumbel Softmax 54 | parser.add_argument('--gumbel_softmax', default=False, action='store_true') 55 | parser.add_argument('--gumbel_softmax_temperature', default=1.0, type=float) 56 | parser.add_argument('--gumbel_softmax_hard', default=False, action='store_true') 57 | 58 | # 
Model path 59 | parser.add_argument('--model_path', default=None, type=str) 60 | 61 | # Classifier type 62 | parser.add_argument('--classifier_type', default="linear", 63 | choices=["linear", "independent_mlp"], type=str) 64 | 65 | args = parser.parse_args() 66 | return args 67 | 68 | 69 | def benchmark(args): 70 | args.eval_only = True 71 | args.pretrained_start_weights = True 72 | height = args.image_size 73 | test_transforms = transforms.Compose([ 74 | transforms.Resize(size=height), 75 | transforms.CenterCrop(size=height), 76 | transforms.ToTensor(), 77 | transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD) 78 | ]) 79 | # define dataset path 80 | if args.dataset == 'cub': 81 | cub_path = args.data_path 82 | # define dataset and loader 83 | eval_data = FineGrainedBirdClassificationDataset(cub_path, split=1, transform=test_transforms, mode='test') 84 | num_cls = eval_data.num_classes 85 | elif args.dataset == 'part_imagenet': 86 | # define dataset and loader 87 | eval_data = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path, 88 | transform=test_transforms, 89 | annotation_file_path=args.anno_path_test, 90 | ) 91 | num_cls = eval_data.num_classes 92 | elif args.dataset == 'flowers102': 93 | # define dataset and loader 94 | eval_data = datasets.Flowers102(root=args.data_path, split='test', transform=test_transforms) 95 | num_cls = len(set(eval_data._labels)) 96 | else: 97 | raise ValueError('Dataset not supported.') 98 | # Load the model 99 | model = load_model_pdisco(args, num_cls) 100 | snapshot_data = torch.load(args.model_path, map_location=torch.device('cpu'), weights_only=True) 101 | if 'model_state' in snapshot_data: 102 | _, state_dict = load_state_dict_pdisco(snapshot_data) 103 | else: 104 | state_dict = copy.deepcopy(snapshot_data) 105 | model.load_state_dict(state_dict, strict=True) 106 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 107 | model = model.eval().to(device) 108 | model = torch.compile(model, mode="reduce-overhead") 109 | test_loader = torch.utils.data.DataLoader( 110 | eval_data, 111 | batch_size=args.batch_size, shuffle=False, 112 | num_workers=args.num_workers, pin_memory=True, drop_last=True) 113 | 114 | # Warmup 115 | for i, batch in tqdm(enumerate(test_loader), total=len(test_loader), desc='Warmup'): 116 | images = batch[0].to(device) 117 | with torch.no_grad(): 118 | output = model(images) 119 | if i == 100: 120 | break 121 | 122 | # Benchmark 123 | for idx in tqdm(range(100), desc="Inference benchmark"): 124 | with torch.no_grad(): 125 | output = model(images) 126 | 127 | print("Inference benchmark done!") 128 | 129 | torch._dynamo.reset() 130 | 131 | 132 | if __name__ == '__main__': 133 | args = parse_args() 134 | benchmark(args) 135 | -------------------------------------------------------------------------------- /utils/data_utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from torch import Tensor 3 | from typing import List, Optional 4 | import numpy as np 5 | import torchvision 6 | import json 7 | 8 | 9 | def load_json(path: str): 10 | """ 11 | Load json file from path and return the data 12 | :param path: Path to the json file 13 | :return: 14 | data: Data in the json file 15 | """ 16 | with open(path, 'r') as f: 17 | data = json.load(f) 18 | return data 19 | 20 | 21 | def save_json(data: dict, path: str): 22 | """ 23 | Save data to a json file 24 | :param data: Data to be saved 25 | :param 
path: Path to save the data 26 | :return: 27 | """ 28 | with open(path, "w") as f: 29 | json.dump(data, f) 30 | 31 | 32 | def pil_loader(path): 33 | """ 34 | Load image from path using PIL 35 | :param path: Path to the image 36 | :return: 37 | img: PIL Image 38 | """ 39 | with open(path, 'rb') as f: 40 | img = Image.open(f) 41 | return img.convert('RGB') 42 | 43 | 44 | def get_dimensions(image: Tensor): 45 | """ 46 | Get the dimensions of the image 47 | :param image: Tensor or PIL Image or np.ndarray 48 | :return: 49 | h: Height of the image 50 | w: Width of the image 51 | """ 52 | if isinstance(image, Tensor): 53 | _, h, w = image.shape 54 | elif isinstance(image, np.ndarray): 55 | h, w, _ = image.shape 56 | elif isinstance(image, Image.Image): 57 | w, h = image.size 58 | else: 59 | raise ValueError(f"Invalid image type: {type(image)}") 60 | return h, w 61 | 62 | 63 | def center_crop_boxes_kps(img: Tensor, output_size: Optional[List[int]] = 448, parts: Optional[Tensor] = None, 64 | boxes: Optional[Tensor] = None, num_keypoints: int = 15): 65 | """ 66 | Calculate the center crop parameters for the bounding boxes and landmarks and update them 67 | :param img: Image 68 | :param output_size: Output size of the cropped image 69 | :param parts: Locations of the landmarks of following format: 70 | :param boxes: Bounding boxes of the landmarks of following format: 71 | :param num_keypoints: Number of keypoints 72 | :return: 73 | cropped_img: Center cropped image 74 | parts: Updated locations of the landmarks 75 | boxes: Updated bounding boxes of the landmarks 76 | """ 77 | if isinstance(output_size, int): 78 | output_size = (output_size, output_size) 79 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: 80 | output_size = (output_size[0], output_size[0]) 81 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 2: 82 | output_size = output_size 83 | else: 84 | raise ValueError(f"Invalid output size: {output_size}") 85 | 86 | crop_height, crop_width = output_size 87 | image_height, image_width = get_dimensions(img) 88 | img = torchvision.transforms.functional.center_crop(img, output_size) 89 | 90 | crop_top, crop_left = _get_center_crop_params_(image_height, image_width, output_size) 91 | 92 | if parts is not None: 93 | for j in range(num_keypoints): 94 | # Skip if part is invisible 95 | if parts[j][-1] == 0: 96 | continue 97 | parts[j][1] -= crop_left 98 | parts[j][2] -= crop_top 99 | 100 | # Skip if part is outside the crop 101 | if parts[j][1] > crop_width or parts[j][2] > crop_height: 102 | parts[j][-1] = 0 103 | if parts[j][1] < 0 or parts[j][2] < 0: 104 | parts[j][-1] = 0 105 | 106 | parts[j][1] = min(crop_width, parts[j][1]) 107 | parts[j][2] = min(crop_height, parts[j][2]) 108 | parts[j][1] = max(0, parts[j][1]) 109 | parts[j][2] = max(0, parts[j][2]) 110 | 111 | if boxes is not None: 112 | boxes[1] -= crop_left 113 | boxes[2] -= crop_top 114 | boxes[1] = max(0, boxes[1]) 115 | boxes[2] = max(0, boxes[2]) 116 | boxes[1] = min(crop_width, boxes[1]) 117 | boxes[2] = min(crop_height, boxes[2]) 118 | 119 | return img, parts, boxes 120 | 121 | 122 | def _get_center_crop_params_(image_height: int, image_width: int, output_size: Optional[List[int]] = 448): 123 | """ 124 | Get the parameters for center cropping the image 125 | :param image_height: Height of the image 126 | :param image_width: Width of the image 127 | :param output_size: Output size of the cropped image 128 | :return: 129 | crop_top: Top coordinate of the cropped image 130 | crop_left: Left 
coordinate of the cropped image 131 | """ 132 | if isinstance(output_size, int): 133 | output_size = (output_size, output_size) 134 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: 135 | output_size = (output_size[0], output_size[0]) 136 | elif isinstance(output_size, (tuple, list)) and len(output_size) == 2: 137 | output_size = output_size 138 | else: 139 | raise ValueError(f"Invalid output size: {output_size}") 140 | 141 | crop_height, crop_width = output_size 142 | 143 | if crop_width > image_width or crop_height > image_height: 144 | padding_ltrb = [ 145 | (crop_width - image_width) // 2 if crop_width > image_width else 0, 146 | (crop_height - image_height) // 2 if crop_height > image_height else 0, 147 | (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, 148 | (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, 149 | ] 150 | crop_top, crop_left = padding_ltrb[1], padding_ltrb[0] 151 | return crop_top, crop_left 152 | 153 | if crop_width == image_width and crop_height == image_height: 154 | crop_top = 0 155 | crop_left = 0 156 | return crop_top, crop_left 157 | 158 | crop_top = int(round((image_height - crop_height) / 2.0)) 159 | crop_left = int(round((image_width - crop_width) / 2.0)) 160 | 161 | return crop_top, crop_left 162 | -------------------------------------------------------------------------------- /utils/training_utils/optimizer_params.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from timm.optim.lars import Lars 4 | from timm.optim.lamb import Lamb 5 | from utils.training_utils.ddp_utils import calculate_effective_batch_size 6 | 7 | 8 | def build_optimizer(args, params_groups, dataset_train): 9 | """ 10 | Function to build the optimizer 11 | :param args: arguments from the command line 12 | :param params_groups: parameters to be optimized 13 | :param dataset_train: training dataset 14 | :return: optimizer 15 | """ 16 | grad_averaging = not args.turn_off_grad_averaging 17 | weight_decay = calculate_weight_decay(args, dataset_train) 18 | if args.optimizer_type == 'adamw': 19 | return torch.optim.AdamW(params=params_groups, betas=(args.betas1, args.betas2), lr=args.lr, 20 | weight_decay=weight_decay) 21 | elif args.optimizer_type == 'sgd': 22 | return torch.optim.SGD(params=params_groups, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, 23 | nesterov=True) 24 | elif args.optimizer_type == 'adam': 25 | return torch.optim.Adam(params=params_groups, betas=(args.betas1, args.betas2), lr=args.lr, 26 | weight_decay=weight_decay) 27 | elif args.optimizer_type == 'nadam': 28 | return torch.optim.NAdam(params=params_groups, betas=(args.betas1, args.betas2), lr=args.lr, 29 | weight_decay=weight_decay) 30 | elif args.optimizer_type == 'lars': 31 | return Lars(params=params_groups, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, 32 | dampening=args.dampening, trust_coeff=args.trust_coeff, trust_clip=False, 33 | always_adapt=args.always_adapt) 34 | elif args.optimizer_type == 'nlars': 35 | return Lars(params=params_groups, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, 36 | dampening=args.dampening, nesterov=True, trust_coeff=args.trust_coeff, trust_clip=False, 37 | always_adapt=args.always_adapt) 38 | elif args.optimizer_type == 'larc': 39 | return Lars(params=params_groups, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, 40 | dampening=args.dampening, trust_coeff=args.trust_coeff, 
trust_clip=True, 41 | always_adapt=args.always_adapt) 42 | elif args.optimizer_type == 'nlarc': 43 | return Lars(params=params_groups, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay, 44 | dampening=args.dampening, nesterov=True, trust_coeff=args.trust_coeff, trust_clip=True, 45 | always_adapt=args.always_adapt) 46 | elif args.optimizer_type == 'lamb': 47 | return Lamb(params=params_groups, lr=args.lr, betas=(args.betas1, args.betas2), weight_decay=weight_decay, 48 | grad_averaging=grad_averaging, max_grad_norm=args.max_grad_norm, trust_clip=False, 49 | always_adapt=args.always_adapt) 50 | elif args.optimizer_type == 'lambc': 51 | return Lamb(params=params_groups, lr=args.lr, betas=(args.betas1, args.betas2), weight_decay=weight_decay, 52 | grad_averaging=grad_averaging, max_grad_norm=args.max_grad_norm, trust_clip=True, 53 | always_adapt=args.always_adapt) 54 | else: 55 | raise NotImplementedError(f'Optimizer {args.optimizer_type} not implemented.') 56 | 57 | 58 | def calculate_weight_decay(args, dataset_train): 59 | """ 60 | Function to calculate the weight decay 61 | Implementation of normalized weight decay as per the paper "Decoupled Weight Decay Regularization": https://arxiv.org/pdf/1711.05101.pdf 62 | :param args: Arguments from the command line 63 | :param dataset_train: Training dataset 64 | :return: weight_decay: Weight decay 65 | """ 66 | batch_size = calculate_effective_batch_size(args) 67 | num_iterations = len(dataset_train) // batch_size # Since we set drop_last=True 68 | norm_weight_decay = args.weight_decay 69 | weight_decay = norm_weight_decay * math.sqrt(1 / (num_iterations * args.epochs)) 70 | return weight_decay 71 | 72 | 73 | def layer_group_matcher_pdisco(args, model): 74 | """ 75 | Function to group the parameters of the model into different groups 76 | :param args: Arguments from the command line 77 | :param model: Model to be trained 78 | :return: param_groups: Parameters grouped into different groups 79 | """ 80 | scratch_layers = ["fc_class_landmarks"] 81 | modulation_layers = ["modulation", "modulation_parts", "modulation_instances"] 82 | finer_layers = ["fc_landmarks", "fc_landmarks_instances", "decoder", "landmark_tokens"] 83 | unfrozen_layers = ["cls_token", "pos_embed", "reg_token"] 84 | scratch_parameters = [] 85 | modulation_parameters = [] 86 | backbone_parameters_wd = [] 87 | no_weight_decay_params = [] 88 | finer_parameters = [] 89 | 90 | for name, p in model.named_parameters(): 91 | if any(x in name for x in scratch_layers): 92 | scratch_parameters.append(p) 93 | p.requires_grad = True 94 | 95 | elif any(x in name for x in modulation_layers): 96 | modulation_parameters.append(p) 97 | p.requires_grad = True 98 | 99 | elif any(x in name for x in finer_layers): 100 | finer_parameters.append(p) 101 | p.requires_grad = True 102 | 103 | elif any(x in name for x in unfrozen_layers): 104 | no_weight_decay_params.append(p) 105 | if args.freeze_params: 106 | p.requires_grad = False 107 | else: 108 | p.requires_grad = True 109 | 110 | else: 111 | if args.freeze_backbone: 112 | p.requires_grad = False 113 | else: 114 | p.requires_grad = True 115 | 116 | if p.ndim == 1: 117 | no_weight_decay_params.append(p) 118 | else: 119 | backbone_parameters_wd.append(p) 120 | 121 | param_groups = [{'params': backbone_parameters_wd, 'lr': args.lr}, 122 | {'params': no_weight_decay_params, 'lr': args.lr, 'weight_decay': 0.0}, 123 | {'params': finer_parameters, 'lr': args.lr * args.finer_lr_factor, 'weight_decay': 0.0}, 124 | {'params': modulation_parameters, 
'lr': args.lr * args.modulation_lr_factor, 'weight_decay': 0.0}, 125 | {'params': scratch_parameters, 'lr': args.lr * args.scratch_lr_factor}] 126 | 127 | return param_groups 128 | -------------------------------------------------------------------------------- /models/individual_landmark_convnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from torch.nn import Parameter 4 | from typing import Any 5 | from layers.independent_mlp import IndependentMLPs 6 | 7 | 8 | # Baseline model, a modified convnext with reduced downsampling for a spatially larger feature tensor in the last layer 9 | class IndividualLandmarkConvNext(torch.nn.Module): 10 | def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, 11 | num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, part_dropout: float = 0.3, 12 | modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False, 13 | gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False, 14 | classifier_type: str = "linear", noise_variance: float = 0.0) -> None: 15 | super().__init__() 16 | 17 | self.num_landmarks = num_landmarks 18 | self.num_classes = num_classes 19 | self.noise_variance = noise_variance 20 | self.stem = init_model.stem 21 | self.stages = init_model.stages 22 | self.feature_dim = sl_channels + fl_channels 23 | self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False) 24 | self.gumbel_softmax = gumbel_softmax 25 | self.gumbel_softmax_temperature = gumbel_softmax_temperature 26 | self.gumbel_softmax_hard = gumbel_softmax_hard 27 | self.modulation_type = modulation_type 28 | if modulation_type == "layer_norm": 29 | self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1]) 30 | elif modulation_type == "original": 31 | self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1)) 32 | elif modulation_type == "parallel_mlp": 33 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 34 | num_lin_layers=1, act_layer=True, bias=True) 35 | elif modulation_type == "parallel_mlp_no_bias": 36 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 37 | num_lin_layers=1, act_layer=True, bias=False) 38 | elif modulation_type == "parallel_mlp_no_act": 39 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 40 | num_lin_layers=1, act_layer=False, bias=True) 41 | elif modulation_type == "parallel_mlp_no_act_no_bias": 42 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 43 | num_lin_layers=1, act_layer=False, bias=False) 44 | elif modulation_type == "none": 45 | self.modulation = torch.nn.Identity() 46 | else: 47 | raise ValueError("modulation_type not implemented") 48 | self.modulation_orth = modulation_orth 49 | self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout) 50 | self.classifier_type = classifier_type 51 | if classifier_type == "independent_mlp": 52 | self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim, 53 | num_lin_layers=1, act_layer=False, out_dim=num_classes, 54 | bias=False, stack_dim=1) 55 | elif classifier_type == "linear": 56 | self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes, 57 | bias=False) 58 | else: 59 | raise 
ValueError("classifier_type not implemented") 60 | 61 | def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]: 62 | # Pretrained ConvNeXt part of the model 63 | x = self.stem(x) 64 | x = self.stages[0](x) 65 | x = self.stages[1](x) 66 | l3 = self.stages[2](x) 67 | x = self.stages[3](l3) 68 | x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False) 69 | x = torch.cat((x, l3), dim=1) 70 | 71 | # Compute per landmark attention maps 72 | # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel 73 | batch_size = x.shape[0] 74 | ab = self.fc_landmarks(x) 75 | b_sq = x.pow(2).sum(1, keepdim=True) 76 | b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous() 77 | a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2], 78 | x.shape[-1]).contiguous() 79 | a_sq = a_sq.permute(1, 0, 2, 3).contiguous() 80 | 81 | dist = b_sq - 2 * ab + a_sq 82 | maps = -dist 83 | 84 | # Softmax so that the attention maps for each pixel add up to 1 85 | if self.gumbel_softmax: 86 | maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature, 87 | hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W] 88 | else: 89 | maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W] 90 | 91 | # Use maps to get weighted average features per landmark 92 | all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous() 93 | if self.noise_variance > 0.0: 94 | all_features += torch.randn_like(all_features, 95 | device=all_features.device) * x.std().detach() * self.noise_variance 96 | 97 | # Modulate the features 98 | if self.modulation_type == "original": 99 | all_features_mod = all_features * self.modulation 100 | else: 101 | all_features_mod = self.modulation(all_features) 102 | 103 | # Classification based on the landmark features 104 | scores = self.fc_class_landmarks( 105 | self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2, 106 | 1).contiguous() 107 | if self.modulation_orth: 108 | return all_features_mod, maps, scores, dist 109 | else: 110 | return all_features, maps, scores, dist 111 | -------------------------------------------------------------------------------- /utils/training_utils/ddp_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.distributed import init_process_group 4 | import torch.distributed as dist 5 | import numpy as np 6 | import subprocess 7 | import socket 8 | 9 | 10 | def get_local_rank(): 11 | use_ddp = multi_gpu_check() 12 | is_slurm_job = "SLURM_NODEID" in os.environ 13 | if is_slurm_job: 14 | local_rank = int(os.environ['SLURM_LOCALID']) 15 | else: 16 | if not use_ddp: 17 | local_rank = 0 18 | else: 19 | local_rank = int(os.environ["LOCAL_RANK"]) 20 | return local_rank 21 | 22 | 23 | def is_enabled() -> bool: 24 | """ 25 | Returns: 26 | True if distributed training is enabled 27 | """ 28 | return dist.is_available() and dist.is_initialized() 29 | 30 | 31 | def get_global_rank() -> int: 32 | """ 33 | Returns: 34 | The rank of the current process within the global process group. 35 | """ 36 | return dist.get_rank() if is_enabled() else 0 37 | 38 | 39 | def is_main_process() -> bool: 40 | """ 41 | Returns: 42 | True if the current process is the main one. 
43 | """ 44 | return get_global_rank() == 0 45 | 46 | 47 | def save_on_master(*args, **kwargs): 48 | if is_main_process(): 49 | torch.save(*args, **kwargs) 50 | # print("Saved checkpoint on master process.") 51 | 52 | 53 | def unwrap_model(model): 54 | if hasattr(model, 'module'): 55 | return unwrap_model(model.module) 56 | elif hasattr(model, '_orig_mod'): 57 | return unwrap_model(model._orig_mod) 58 | else: 59 | return model 60 | 61 | 62 | def get_state_dict(model, unwrap_fn=unwrap_model): 63 | return unwrap_fn(model).state_dict() 64 | 65 | 66 | def ddp_setup(): 67 | is_slurm_job = "SLURM_NODEID" in os.environ 68 | if is_slurm_job: 69 | # Define the process group based on SLURM env variables 70 | # number of nodes / node ID 71 | n_nodes = int(os.environ['SLURM_JOB_NUM_NODES']) 72 | node_id = int(os.environ['SLURM_NODEID']) 73 | 74 | # local rank on the current node / global rank 75 | local_rank = int(os.environ['SLURM_LOCALID']) 76 | global_rank = int(os.environ['SLURM_PROCID']) 77 | 78 | # number of processes / GPUs per node 79 | world_size = int(os.environ['SLURM_NTASKS']) 80 | n_gpu_per_node = world_size // n_nodes 81 | 82 | # define master address and master port 83 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 84 | master_addr = hostnames.split()[0].decode('utf-8') 85 | 86 | # set environment variables for 'env://' 87 | os.environ['MASTER_ADDR'] = master_addr 88 | os.environ['MASTER_PORT'] = str(29500) 89 | os.environ['WORLD_SIZE'] = str(world_size) 90 | os.environ['RANK'] = str(global_rank) 91 | 92 | # define whether this is the master process / if we are in distributed mode 93 | is_master = node_id == 0 and local_rank == 0 94 | multi_node = n_nodes > 1 95 | multi_gpu = world_size > 1 96 | 97 | # summary 98 | prefix = "%i - " % global_rank 99 | if local_rank == 0: 100 | print(prefix + "Number of nodes: %i" % n_nodes) 101 | print(prefix + "Node ID : %i" % node_id) 102 | print(prefix + "Local rank : %i" % local_rank) 103 | print(prefix + "Global rank : %i" % global_rank) 104 | print(prefix + "World size : %i" % world_size) 105 | print(prefix + "GPUs per node : %i" % n_gpu_per_node) 106 | print(prefix + "Master : %s" % str(is_master)) 107 | print(prefix + "Multi-node : %s" % str(multi_node)) 108 | print(prefix + "Multi-GPU : %s" % str(multi_gpu)) 109 | print(prefix + "Hostname : %s" % socket.gethostname()) 110 | else: 111 | local_rank = int(os.environ["LOCAL_RANK"]) 112 | if local_rank == 0: 113 | print("Initializing PyTorch distributed ...") 114 | init_process_group(init_method='env://', backend="nccl") 115 | torch.cuda.set_device(local_rank) 116 | return 117 | 118 | 119 | def set_seeds(seed_value: int = 42): 120 | # Set the manual seeds 121 | torch.manual_seed(seed_value) 122 | torch.cuda.manual_seed(seed_value) 123 | np.random.seed(seed_value) 124 | 125 | 126 | def reduce_tensor(tensor: torch.Tensor, world_size: int): 127 | """Reduce tensor across all nodes.""" 128 | rt = tensor.clone() 129 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 130 | rt /= world_size 131 | return rt 132 | 133 | 134 | def to_python_float(t: torch.Tensor): 135 | if hasattr(t, 'item'): 136 | return t.item() 137 | else: 138 | return t[0] 139 | 140 | 141 | def is_dist_avail_and_initialized(): 142 | if not dist.is_available(): 143 | return False 144 | if not dist.is_initialized(): 145 | return False 146 | return True 147 | 148 | 149 | def get_rank(): 150 | if not is_dist_avail_and_initialized(): 151 | return 0 152 | return dist.get_rank() 153 | 154 | 155 
| @torch.no_grad() 156 | def concat_all_gather(tensor): 157 | """ 158 | Performs all_gather operation on the provided tensors. 159 | *** Warning ***: torch.distributed.all_gather has no gradient. 160 | """ 161 | tensors_gather = [torch.ones_like(tensor) 162 | for _ in range(torch.distributed.get_world_size())] 163 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 164 | 165 | output = torch.cat(tensors_gather, dim=0) 166 | return output 167 | 168 | 169 | def multi_gpu_check(): 170 | """ 171 | Check if there are multiple GPUs available for DDP 172 | :return: 173 | use_ddp: bool, whether to use DDP or not 174 | """ 175 | torchrun_active = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 176 | slurm_active = "SLURM_NODEID" in os.environ # is this a slurm job? 177 | if slurm_active: 178 | # Check device count 179 | slurm_active = torch.cuda.device_count() > 1 180 | use_ddp = torchrun_active or slurm_active 181 | return use_ddp 182 | 183 | 184 | def calculate_effective_batch_size(args): 185 | """ 186 | Calculate the effective batch size for DDP 187 | :param args: Arguments from the argument parser 188 | :return: 189 | effective_batch_size: int, effective batch size 190 | """ 191 | batch_size = args.batch_size 192 | use_ddp = multi_gpu_check() 193 | is_slurm_job = "SLURM_NODEID" in os.environ 194 | if is_slurm_job: 195 | # number of processes / GPUs per node 196 | world_size = int(os.environ['SLURM_NTASKS']) 197 | else: 198 | if use_ddp: 199 | world_size = int(os.environ['WORLD_SIZE']) 200 | else: 201 | world_size = 1 202 | 203 | effective_batch_size = batch_size * world_size 204 | return effective_batch_size 205 | -------------------------------------------------------------------------------- /utils/visualize_att_maps.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | 3 | matplotlib.use('Agg') 4 | import matplotlib.pyplot as plt 5 | from mpl_toolkits.axes_grid1 import make_axes_locatable 6 | import colorcet as cc 7 | import numpy as np 8 | import skimage 9 | from pathlib import Path 10 | import os 11 | import torch 12 | 13 | from utils.data_utils.transform_utils import inverse_normalize_w_resize 14 | from utils.misc_utils import factors 15 | 16 | # Define the colors to use for the attention maps 17 | colors = cc.glasbey_category10 18 | 19 | 20 | class VisualizeAttentionMaps: 21 | def __init__(self, snapshot_dir="", save_resolution=(256, 256), alpha=0.5, sub_path_test="", 22 | dataset_name="", bg_label=0, batch_size=32, num_parts=15, plot_ims_separately=False, 23 | plot_landmark_amaps=False): 24 | """ 25 | Plot attention maps and optionally landmark centroids on images. 
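A rough usage sketch, assuming a PDisco-style model (e.g. ``IndividualLandmarkConvNext``) whose forward pass returns ``(features, maps, scores, dist)`` and a batch of normalized ``images``; the keyword values shown are illustrative, not defaults::

    visualizer = VisualizeAttentionMaps(snapshot_dir='exps/vis', batch_size=images.shape[0], num_parts=8)
    _, maps, _, _ = model(images)  # maps: [batch_size, num_parts + 1, height_map, width_map]
    visualizer.show_maps(images, maps, epoch=0, curr_iter=0)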
26 | :param snapshot_dir: Directory to save the visualization results 27 | :param save_resolution: Size of the images to save 28 | :param alpha: The transparency of the attention maps 29 | :param sub_path_test: The sub-path of the test dataset 30 | :param dataset_name: The name of the dataset 31 | :param bg_label: The background label index in the attention maps 32 | :param batch_size: The batch size 33 | :param num_parts: The number of parts in the attention maps 34 | :param plot_ims_separately: Whether to plot the images separately 35 | :param plot_landmark_amaps: Whether to plot the landmark attention maps 36 | """ 37 | self.save_resolution = save_resolution 38 | self.alpha = alpha 39 | self.sub_path_test = sub_path_test 40 | self.dataset_name = dataset_name 41 | self.bg_label = bg_label 42 | self.snapshot_dir = snapshot_dir 43 | if self.snapshot_dir == "": 44 | matplotlib.use('Qt5Agg') 45 | self.resize_unnorm = inverse_normalize_w_resize(resize_resolution=self.save_resolution) 46 | self.batch_size = batch_size 47 | self.nrows = factors(self.batch_size)[-1] 48 | self.ncols = factors(self.batch_size)[-2] 49 | self.num_parts = num_parts 50 | self.plot_ims_separately = plot_ims_separately 51 | self.plot_landmark_amaps = plot_landmark_amaps 52 | if self.nrows == 1 and self.ncols == 1: 53 | self.figs_size = (10, 10) 54 | else: 55 | self.figs_size = (self.ncols * 2, self.nrows * 2) 56 | 57 | def recalculate_nrows_ncols(self): 58 | self.nrows = factors(self.batch_size)[-1] 59 | self.ncols = factors(self.batch_size)[-2] 60 | if self.nrows == 1 and self.ncols == 1: 61 | self.figs_size = (10, 10) 62 | else: 63 | self.figs_size = (self.ncols * 2, self.nrows * 2) 64 | 65 | @torch.no_grad() 66 | def show_maps(self, ims, maps, epoch=0, curr_iter=0, extra_info=""): 67 | """ 68 | Plot images, attention maps and landmark centroids. 
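Internally, ``maps`` is bilinearly resized to ``save_resolution`` and arg-maxed over the part dimension, and the resulting part labels are overlaid on the un-normalized images with ``skimage.color.label2rgb`` before being saved (or shown, when ``snapshot_dir`` is empty).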
69 | Parameters 70 | ---------- 71 | ims: Tensor, [batch_size, 3, width_im, height_im] 72 | Input images on which to show the attention maps 73 | maps: Tensor, [batch_size, number of parts + 1, width_map, height_map] 74 | The attention maps to display 75 | epoch: int 76 | The epoch number 77 | curr_iter: int 78 | The current iteration number 79 | extra_info: str 80 | Any extra information to add to the file name 81 | """ 82 | ims = self.resize_unnorm(ims) 83 | if ims.shape[0] != self.batch_size: 84 | self.batch_size = ims.shape[0] 85 | self.recalculate_nrows_ncols() 86 | fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size) 87 | ims = (ims.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8) 88 | map_argmax = torch.nn.functional.interpolate(maps.clone().detach(), size=self.save_resolution, 89 | mode='bilinear', 90 | align_corners=False, antialias=True).argmax(dim=1).cpu().numpy() 91 | 92 | # Select colors for parts which are present 93 | parts_present = np.unique(map_argmax).tolist() 94 | if self.bg_label in parts_present: 95 | parts_present.remove(self.bg_label) 96 | colors_present = [colors[i] for i in parts_present] 97 | for i, ax in enumerate(axs.ravel()): 98 | curr_map = skimage.color.label2rgb(label=map_argmax[i], image=ims[i], colors=colors_present, 99 | bg_label=self.bg_label, alpha=self.alpha) 100 | ax.imshow(curr_map) 101 | ax.axis('off') 102 | save_dir = Path(os.path.join(self.snapshot_dir, 'results_vis_' + self.sub_path_test)) 103 | save_dir.mkdir(parents=True, exist_ok=True) 104 | save_path = os.path.join(save_dir, f'{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png') 105 | fig.tight_layout() 106 | if self.snapshot_dir != "": 107 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 108 | else: 109 | plt.show() 110 | plt.close('all') 111 | 112 | if self.plot_ims_separately: 113 | fig, axs = plt.subplots(nrows=self.nrows, ncols=self.ncols, squeeze=False, figsize=self.figs_size) 114 | for i, ax in enumerate(axs.ravel()): 115 | ax.imshow(ims[i]) 116 | ax.axis('off') 117 | save_path = os.path.join(save_dir, f'image_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.jpg') 118 | fig.tight_layout() 119 | if self.snapshot_dir != "": 120 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 121 | else: 122 | plt.show() 123 | plt.close('all') 124 | 125 | if self.plot_landmark_amaps: 126 | if self.batch_size > 1: 127 | raise ValueError('Not implemented for batch size > 1') 128 | for i in range(self.num_parts): 129 | fig, ax = plt.subplots(1, 1, figsize=self.figs_size) 130 | divider = make_axes_locatable(ax) 131 | cax = divider.append_axes('right', size='5%', pad=0.05) 132 | im = ax.imshow(maps[0, i, ...].detach().cpu().numpy(), cmap='cet_gouldian') 133 | fig.colorbar(im, cax=cax, orientation='vertical') 134 | ax.axis('off') 135 | save_path = os.path.join(save_dir, 136 | f'landmark_{i}_{epoch}_{curr_iter}_{self.dataset_name}{extra_info}.png') 137 | fig.tight_layout() 138 | if self.snapshot_dir != "": 139 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 140 | else: 141 | plt.show() 142 | plt.close() 143 | 144 | plt.close('all') 145 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PDiscoFormer: Relaxing Part Discovery Constraints with Vision Transformers 2 | Official implementation of the paper "PDiscoFormer: Relaxing Part Discovery Constraints with Vision Transformers", accepted 
as an Oral presentation at ECCV 2024. 3 | 4 | 5 | 6 | 7 | [[`Oral`]](https://eccv.ecva.net/virtual/2024/oral/125)[[`🤗 Space`]](https://huggingface.co/spaces/ananthu-aniraj/pdiscoformer)[[`Paper`]](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/11397.pdf) [[`Supp.`]](https://www.ecva.net/papers/eccv_2024/papers_ECCV/papers/11397-supp.pdf) [[`Arxiv`]](https://arxiv.org/abs/2407.04538) [[`🤗 Page`]](https://huggingface.co/papers/2407.04538) 8 | 9 | 10 | # Abstract 11 | Computer vision methods that explicitly detect object parts and reason on them are a step towards inherently interpretable models. Existing approaches that perform part discovery driven by a fine-grained classification task make very restrictive assumptions on the geometric properties of the discovered parts; they should be small and compact. Although this prior is useful in some cases, in this paper we show that pre-trained transformer-based vision models, such as self-supervised DINOv2 ViT, enable the relaxation of these constraints. In particular, we find that a total variation (TV) prior, which allows for multiple connected components of any size, substantially outperforms previous work. We test our approach on three fine-grained classification benchmarks: CUB, PartImageNet and Oxford Flowers, and compare our results to previously published methods as well as a re-implementation of the state-of-the-art method PDiscoNet with a transformer-based backbone. We consistently obtain substantial improvements across the board, both on part discovery metrics and the downstream classification task, showing that the strong inductive biases in self-supervised ViT models require to rethink the geometric priors that can be used for unsupervised part discovery. 12 | 13 | 14 | # Model Architecture 15 | ![image](https://github.com/ananthu-aniraj/pdiscoformer/assets/50333505/73c30fb1-2f2c-408a-81dd-4447f9091f86) 16 | 17 | # Updates 18 | 1. The code has been updated to support the NABirds dataset. The corresponding evaluation metrics and pre-trained models have also been added. 19 | 2. The models are available via torch hub. The details can be found in the [model zoo](model_zoo.md) file. 20 | 3. PDiscoformer has been accepted as an Oral presentation at ECCV 2024 :tada: 21 | 4. Models are now available via HuggingFace. Thanks to [Niels Rogge](https://github.com/NielsRogge) and [Merve Noyan](https://github.com/merveenoyan). 22 | 23 | # Setup 24 | To install the required packages, run the following command: 25 | ```conda env create -f environment.yml``` 26 | 27 | Otherwise, you can also individually install the following packages: 28 | 1. [PyTorch](https://pytorch.org/get-started/locally/): Tested up to version 2.3, please raise an issue if you face any problems with more recent versions. 29 | 2. [Colorcet](https://colorcet.holoviz.org/getting_started/index.html) 30 | 3. [Matplotlib](https://matplotlib.org/stable/users/installing.html) 31 | 4. [OpenCV](https://pypi.org/project/opencv-python-headless/) 32 | 5. [Pandas](https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html) 33 | 6. [Scikit-Image](https://scikit-image.org/docs/stable/install.html) 34 | 7. [Scikit-Learn](https://scikit-learn.org/stable/install.html) 35 | 8. [TorchMetrics](https://torchmetrics.readthedocs.io/en/latest/pages/install.html) 36 | 9. [timm](https://pypi.org/project/timm/) 37 | 10. [wandb](https://pypi.org/project/wandb/): It is recommended to create an account and use it for tracking the experiments. 
Use the '--wandb' flag when running the training script to enable this feature. 38 | 11. [pycocotools](https://pypi.org/project/pycocotools/) 39 | 12. [pytopk](https://pypi.org/project/pytopk/) 40 | 13. [huggingface-hub](https://pypi.org/project/huggingface-hub/) 41 | 42 | 43 | # Datasets 44 | ### CUB 45 | The dataset can be downloaded from [here](https://www.vision.caltech.edu/datasets/cub_200_2011/). 46 | 47 | The folder structure should look like this: 48 | 49 | ``` 50 | CUB_200_2011 51 | ├── attributes 52 | ├── bounding_boxes.txt 53 | ├── classes.txt 54 | ├── images 55 | ├── image_class_labels.txt 56 | ├── images.txt 57 | ├── parts 58 | ├── README 59 | └── train_test_split.txt 60 | ``` 61 | 62 | ### PartImageNet OOD 63 | The dataset can be downloaded from [here](https://github.com/TACJu/PartImageNet). 64 | After downloading the dataset, use the pre-processing script (prepare_partimagenet_ood.py) and train-test split (data_sets/train_test_split_pimagenet_ood.txt) to generate the required annotation files for training and evaluation. 65 | The command to run the pre-processing script is as follows: 66 | 67 | ```python prepare_partimagenet_ood.py --anno_path <path to the PartImageNet annotation file> --output_dir <directory for the generated annotation files> --train_test_split_file data_sets/train_test_split_pimagenet_ood.txt``` 68 | 69 | ### Oxford Flowers 70 | The dataset is automatically downloaded by the training script with the required folder structure (except for the segmentation masks). 71 | If you want to evaluate the foreground segmentation on the dataset, please download the segmentations from [here](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/). 72 | The final folder structure should look like this: 73 | 74 | ``` 75 | (root folder) 76 | ├── flowers-102 (folder containing the dataset created automatically by the training script) 77 |    ├── segmim (folder containing the segmentation masks) 78 |    ├── jpg 79 |    ├── imagelabels.mat 80 |    └── setid.mat 81 | ``` 82 | ### PartImageNet Seg 83 | The dataset can be downloaded from [here](https://github.com/TACJu/PartImageNet). No additional pre-processing is required. 84 | 85 | ### NABirds 86 | The dataset can be downloaded from [here](https://dl.allaboutbirds.org/nabirds). 87 | The experiments on this dataset are not present in the paper as they were conducted after the paper was submitted. 88 | The folder structure should look like this (essentially the same as CUB except for the attributes): 89 | 90 | ``` 91 | nabirds 92 | ├── bounding_boxes.txt 93 | ├── classes.txt 94 | ├── images 95 | ├── image_class_labels.txt 96 | ├── images.txt 97 | ├── parts 98 | ├── hierarchy.txt 99 | ├── README 100 | └── train_test_split.txt 101 | ``` 102 | 103 | # Training 104 | The details of running the training script can be found in the [training instructions](training_instructions.md) file. 105 | 106 | # Evaluation 107 | The details of running the evaluation metrics for both classification and part discovery can be found in the [evaluation instructions](evaluation_instructions.md) file. 108 | 109 | # Model Zoo 110 | The trained models can be found in the [model zoo](model_zoo.md) file. A minimal example of loading one of these models is shown below. 111 | 112 | 113 | # Issues and Questions 114 | Feel free to raise an issue if you face any problems with the code or have any questions about the paper. 
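# Example: Loading a Pre-Trained Model
The snippet below is a minimal sketch of how a pre-trained PDiscoFormer ViT model can be instantiated through the helper functions in `load_model.py`. The checkpoint URL prefix used here is a placeholder and must be replaced with the corresponding entry from the [model zoo](model_zoo.md); the image size should be set to the one the checkpoint was trained with.

```python
import torch

from load_model import pdiscoformer_vit

# Placeholder; the real checkpoint URL prefixes are listed in model_zoo.md.
# The helper appends "<k>_parts_snapshot_best.pt" to this prefix before downloading.
weight_url_prefix = "<checkpoint url prefix from model_zoo.md>"

model = pdiscoformer_vit(
    pretrained=True,              # download and load the checkpoint weights
    model_dataset="cub",          # tag used for the torch hub cache sub-directory
    k=8,                          # number of discovered parts of the checkpoint
    model_url=weight_url_prefix,
    img_size=224,                 # set to the training image size of the checkpoint
    num_cls=200,                  # CUB has 200 classes
)
model.eval()

# Dummy forward pass; the input resolution must match img_size.
with torch.no_grad():
    outputs = model(torch.randn(1, 3, 224, 224))
```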
115 | 116 | # Citation 117 | If you find our work useful in your research, please consider citing: 118 | 119 | ``` 120 | @inproceedings{aniraj2024pdiscoformer, 121 | title={PDiscoFormer: Relaxing Part Discovery Constraints with Vision Transformers}, 122 | author={Aniraj, Ananthu and Dantas, Cassio F and Ienco, Dino and Marcos, Diego}, 123 | booktitle={European Conference on Computer Vision}, 124 | pages={256--272}, 125 | year={2024}, 126 | organization={Springer} 127 | } 128 | ``` 129 | 130 | -------------------------------------------------------------------------------- /data_sets/part_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.utils.data 3 | from pycocotools.coco import COCO 4 | import copy 5 | import torch 6 | import torch.utils.data 7 | from collections import defaultdict 8 | from utils.data_utils.dataset_utils import pil_loader 9 | from .classes_mapping_imagenet import IMAGENET2012_CLASSES 10 | 11 | 12 | class PartImageNetDataset(torch.utils.data.Dataset): 13 | """PartImageNet dataset""" 14 | 15 | def __init__(self, data_path: str, transform=None, 16 | get_masks=False, image_sub_path='train', annotation_file_path="train.json", class_names_to_idx=None, 17 | class_idx_to_names=None, class_names=None, mask_transform=None): 18 | """ 19 | Args: 20 | data_path (string): path to the dataset 21 | transform (callable, optional): Optional transform to be applied 22 | on a sample. 23 | get_masks (bool): whether to return the masks along with the images 24 | image_sub_path (str): sub path to the images 25 | annotation_file_path (str): path to the annotation file 26 | class_names_to_idx (dict): dictionary mapping class names to indices 27 | class_idx_to_names (dict): dictionary mapping class indices to names 28 | class_names (list): list of class names 29 | mask_transform (callable, optional): Optional transform to be applied 30 | on the masks. 
31 | """ 32 | self.data_path = data_path 33 | self.transform = transform 34 | self.get_masks = get_masks 35 | self.loader = pil_loader 36 | self.image_sub_path = image_sub_path 37 | self.coco = COCO(annotation_file_path) 38 | self._preprocess_annotations() 39 | self.image_ids = [img_dict['id'] for img_dict in self.coco.imgs.values()] 40 | # Number of key-points in the dataset (Ground truth parts) 41 | self.num_kps = len(self.coco.cats) 42 | # Coarse-grained classes in the dataset 43 | self.super_categories = list(dict.fromkeys([self.coco.cats[cat]['supercategory'] for cat in self.coco.cats])) 44 | self.super_categories.sort() 45 | self.mask_transform = mask_transform 46 | self.img_id_to_label = {} 47 | self.image_id_to_name = {} 48 | self.img_id_to_supercat = {} 49 | 50 | if class_names is None and class_names_to_idx is None and class_idx_to_names is None: 51 | self.class_names = [] 52 | 53 | for img_dict in self.coco.imgs.values(): 54 | img_name = os.path.basename(img_dict['file_name']) 55 | class_name_wordnet = img_name.split('_')[0] 56 | self.class_names.append(IMAGENET2012_CLASSES[class_name_wordnet]) 57 | 58 | self.class_names = list(dict.fromkeys(self.class_names)) 59 | self.class_names.sort() 60 | self.num_classes = len(self.class_names) 61 | self.class_names_to_idx = {class_name: idx for idx, class_name in enumerate(self.class_names)} 62 | self.class_idx_to_names = {idx: class_name for idx, class_name in enumerate(self.class_names)} 63 | else: 64 | self.class_names_to_idx = class_names_to_idx 65 | self.class_idx_to_names = class_idx_to_names 66 | self.class_names = class_names 67 | self.num_classes = len(self.class_names) 68 | filtered_img_iterator = 0 69 | self.filtered_img_id_to_orig_img_id = {} 70 | self.img_ids_filtered = [] 71 | # Number of instances per class 72 | self.per_class_count = defaultdict(int) 73 | for image_id in self.image_ids: 74 | annIds = self.coco.getAnnIds(imgIds=image_id, iscrowd=None) 75 | anns = self.coco.loadAnns(annIds) 76 | img_name = self.coco.loadImgs(image_id)[0]['file_name'] 77 | 78 | if anns: 79 | cats = [ann['category_id'] for ann in anns if ann['area'] > 0] 80 | supercat_img = list(dict.fromkeys([self.coco.cats[cat]['supercategory'] for cat in cats]))[0] 81 | class_name_wordnet = os.path.basename(img_name).split('_')[0] 82 | class_idx = self.class_names_to_idx[IMAGENET2012_CLASSES[class_name_wordnet]] 83 | self.image_id_to_name[filtered_img_iterator] = os.path.join(self.data_path, self.image_sub_path, 84 | img_name) 85 | self.img_ids_filtered.append(filtered_img_iterator) 86 | self.img_id_to_label[filtered_img_iterator] = class_idx 87 | self.filtered_img_id_to_orig_img_id[filtered_img_iterator] = image_id 88 | self.img_id_to_supercat[filtered_img_iterator] = supercat_img 89 | self.per_class_count[self.class_idx_to_names[class_idx]] += 1 90 | filtered_img_iterator += 1 91 | # For top-K loss (class distribution) 92 | self.cls_num_list = [self.per_class_count[self.class_idx_to_names[idx]] for idx in range(self.num_classes)] 93 | 94 | def __len__(self): 95 | return len(self.img_ids_filtered) 96 | 97 | def _preprocess_annotations(self): 98 | json_dict = copy.deepcopy(self.coco.dataset) 99 | for ann in json_dict['annotations']: 100 | if ann["area"] == 0 or ann["iscrowd"] == 1: 101 | continue 102 | for poly_num, seg in enumerate(ann['segmentation']): 103 | if len(seg) == 4: 104 | x1, y1, w, h = ann['bbox'] 105 | x2 = x1 + w 106 | y2 = y1 + h 107 | seg_poly = [x1, y1, x1, y2, x2, y2, x2, y1] 108 | ann['segmentation'][poly_num] = seg_poly 109 | 
self.coco.dataset = copy.deepcopy(json_dict) 110 | self.coco.createIndex() 111 | 112 | def __getitem__(self, idx): 113 | img_id = self.img_ids_filtered[idx] 114 | img_path = self.image_id_to_name[img_id] 115 | im = self.loader(img_path) 116 | label = self.img_id_to_label[img_id] 117 | if self.transform: 118 | im = self.transform(im) 119 | if not self.get_masks: 120 | return im, label 121 | mask = self.getmasks(img_id) 122 | if self.mask_transform: 123 | mask = self.mask_transform(mask) 124 | return im, label, mask 125 | 126 | def getmasks(self, img_id): 127 | coco = self.coco 128 | original_img_id = self.filtered_img_id_to_orig_img_id[img_id] 129 | anns = coco.imgToAnns[original_img_id] 130 | img = coco.imgs[original_img_id] 131 | mask_tensor = torch.zeros(size=(self.num_kps, img['height'], img['width'])) 132 | for i, ann in enumerate(anns): 133 | if ann["area"] == 0 or ann["iscrowd"] == 1: 134 | continue 135 | cat = ann['category_id'] 136 | mask = torch.as_tensor(coco.annToMask(ann), dtype=torch.float32) 137 | mask_tensor[cat] += mask 138 | return mask_tensor 139 | 140 | def generate_supercat_subset_all(self): 141 | supercat_to_img_ids = defaultdict(list) 142 | for img_id in self.img_ids_filtered: 143 | supercat_to_img_ids[self.img_id_to_supercat[img_id]].append(img_id) 144 | return supercat_to_img_ids 145 | 146 | 147 | if __name__ == '__main__': 148 | pass 149 | -------------------------------------------------------------------------------- /models/individual_landmark_resnet.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/robertdvdk/part_detection/blob/main/nets.py 2 | import torch 3 | from torch import Tensor 4 | from timm.models import create_model 5 | from torchvision.models import get_model 6 | from torch.nn import Parameter 7 | from typing import Any 8 | from layers.independent_mlp import IndependentMLPs 9 | 10 | 11 | # Baseline model, a modified ResNet with reduced downsampling for a spatially larger feature tensor in the last layer 12 | class IndividualLandmarkResNet(torch.nn.Module): 13 | def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8, 14 | num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048, 15 | use_torchvision_model: bool = False, part_dropout: float = 0.3, 16 | modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False, 17 | gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False, 18 | classifier_type: str = "linear", noise_variance: float = 0.0) -> None: 19 | super().__init__() 20 | 21 | self.num_landmarks = num_landmarks 22 | self.num_classes = num_classes 23 | self.noise_variance = noise_variance 24 | self.conv1 = init_model.conv1 25 | self.bn1 = init_model.bn1 26 | if use_torchvision_model: 27 | self.act1 = init_model.relu 28 | else: 29 | self.act1 = init_model.act1 30 | self.maxpool = init_model.maxpool 31 | self.layer1 = init_model.layer1 32 | self.layer2 = init_model.layer2 33 | self.layer3 = init_model.layer3 34 | self.layer4 = init_model.layer4 35 | self.feature_dim = sl_channels + fl_channels 36 | self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False) 37 | self.gumbel_softmax = gumbel_softmax 38 | self.gumbel_softmax_temperature = gumbel_softmax_temperature 39 | self.gumbel_softmax_hard = gumbel_softmax_hard 40 | self.modulation_type = modulation_type 41 | if modulation_type == "layer_norm": 42 | self.modulation = torch.nn.LayerNorm([self.feature_dim, 
self.num_landmarks + 1]) 43 | elif modulation_type == "original": 44 | self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1)) 45 | elif modulation_type == "parallel_mlp": 46 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 47 | num_lin_layers=1, act_layer=True, bias=True) 48 | elif modulation_type == "parallel_mlp_no_bias": 49 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 50 | num_lin_layers=1, act_layer=True, bias=False) 51 | elif modulation_type == "parallel_mlp_no_act": 52 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 53 | num_lin_layers=1, act_layer=False, bias=True) 54 | elif modulation_type == "parallel_mlp_no_act_no_bias": 55 | self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim, 56 | num_lin_layers=1, act_layer=False, bias=False) 57 | elif modulation_type == "none": 58 | self.modulation = torch.nn.Identity() 59 | else: 60 | raise ValueError("modulation_type not implemented") 61 | 62 | self.modulation_orth = modulation_orth 63 | 64 | self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout) 65 | self.classifier_type = classifier_type 66 | if classifier_type == "independent_mlp": 67 | self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim, 68 | num_lin_layers=1, act_layer=False, out_dim=num_classes, 69 | bias=False, stack_dim=1) 70 | elif classifier_type == "linear": 71 | self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes, 72 | bias=False) 73 | else: 74 | raise ValueError("classifier_type not implemented") 75 | 76 | def forward(self, x: Tensor) -> tuple[Any, Any, Any, Any, Parameter, int | Any]: 77 | # Pretrained ResNet part of the model 78 | x = self.conv1(x) 79 | x = self.bn1(x) 80 | x = self.act1(x) 81 | x = self.maxpool(x) 82 | x = self.layer1(x) 83 | x = self.layer2(x) 84 | l3 = self.layer3(x) 85 | x = self.layer4(l3) 86 | x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear', align_corners=False) 87 | x = torch.cat((x, l3), dim=1) 88 | 89 | # Compute per landmark attention maps 90 | # (b - a)^2 = b^2 - 2ab + a^2, b = feature maps resnet, a = convolution kernel 91 | batch_size = x.shape[0] 92 | 93 | ab = self.fc_landmarks(x) 94 | b_sq = x.pow(2).sum(1, keepdim=True) 95 | b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous() 96 | a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2], 97 | x.shape[-1]).contiguous() 98 | a_sq = a_sq.permute(1, 0, 2, 3).contiguous() 99 | 100 | dist = b_sq - 2 * ab + a_sq 101 | maps = -dist 102 | 103 | # Softmax so that the attention maps for each pixel add up to 1 104 | if self.gumbel_softmax: 105 | maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature, 106 | hard=self.gumbel_softmax_hard) # [B, num_landmarks + 1, H, W] 107 | else: 108 | maps = torch.nn.functional.softmax(maps, dim=1) # [B, num_landmarks + 1, H, W] 109 | 110 | # Use maps to get weighted average features per landmark 111 | all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous() 112 | if self.noise_variance > 0.0: 113 | all_features += torch.randn_like(all_features, 114 | device=all_features.device) * x.std().detach() * self.noise_variance 115 | 116 | # Modulate the features 117 | if self.modulation_type == 
"original": 118 | all_features_mod = all_features * self.modulation 119 | else: 120 | all_features_mod = self.modulation(all_features) 121 | 122 | # Classification based on the landmark features 123 | scores = self.fc_class_landmarks( 124 | self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())).permute(0, 2, 125 | 1).contiguous() 126 | if self.modulation_orth: 127 | return all_features_mod, maps, scores, dist 128 | else: 129 | return all_features, maps, scores, dist 130 | 131 | 132 | def pdisconet_resnet_torchvision_bb(backbone, num_cls=200, k=8, **kwargs): 133 | base_model = get_model(backbone) 134 | return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls, 135 | modulation_type="original") 136 | 137 | 138 | def pdisconet_resnet_timm_bb(backbone, num_cls=200, k=8, output_stride=32, **kwargs): 139 | base_model = create_model(backbone, pretrained=True, output_stride=output_stride) 140 | return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls, 141 | modulation_type="original") 142 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: pdiscoformer_venv 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - aom=3.6.0=h6a678d5_0 11 | - blas=1.0=mkl 12 | - blosc=1.21.3=h6a678d5_0 13 | - bottleneck=1.3.5=py311hbed6279_0 14 | - brotli=1.0.9=h5eee18b_7 15 | - brotli-bin=1.0.9=h5eee18b_7 16 | - brotli-python=1.0.9=py311h6a678d5_7 17 | - brunsli=0.1=h2531618_0 18 | - bzip2=1.0.8=h7b6447c_0 19 | - c-ares=1.19.1=h5eee18b_0 20 | - ca-certificates=2024.2.2=hbcca054_0 21 | - certifi=2023.11.17=pyhd8ed1ab_0 22 | - cffi=1.16.0=py311h5eee18b_0 23 | - cfitsio=3.470=h5893167_7 24 | - charls=2.2.0=h2531618_0 25 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 26 | - click=8.1.7=py311h06a4308_0 27 | - cloudpickle=2.2.1=py311h06a4308_0 28 | - colorcet=3.0.1=py311h06a4308_0 29 | - contourpy=1.2.0=py311hdb19cb5_0 30 | - cryptography=41.0.7=py311hdda0065_0 31 | - cuda-cudart=12.1.105=0 32 | - cuda-cupti=12.1.105=0 33 | - cuda-libraries=12.1.0=0 34 | - cuda-nvrtc=12.1.105=0 35 | - cuda-nvtx=12.1.105=0 36 | - cuda-opencl=12.3.101=0 37 | - cuda-runtime=12.1.0=0 38 | - cycler=0.11.0=pyhd3eb1b0_0 39 | - cyrus-sasl=2.1.28=h52b45da_1 40 | - cytoolz=0.12.2=py311h5eee18b_0 41 | - dask-core=2023.11.0=py311h06a4308_0 42 | - dav1d=1.2.1=h5eee18b_0 43 | - dbus=1.13.18=hb2f20db_0 44 | - expat=2.5.0=h6a678d5_0 45 | - ffmpeg=4.3=hf484d3e_0 46 | - filelock=3.13.1=py311h06a4308_0 47 | - fontconfig=2.14.1=h4c34cd2_2 48 | - fonttools=4.25.0=pyhd3eb1b0_0 49 | - freetype=2.12.1=h4a9f257_0 50 | - giflib=5.2.1=h5eee18b_3 51 | - glib=2.69.1=he621ea3_2 52 | - gmp=6.2.1=h295c915_3 53 | - gmpy2=2.1.2=py311hc9b5ff0_0 54 | - gnutls=3.6.15=he1e5248_0 55 | - gst-plugins-base=1.14.1=h6a678d5_1 56 | - gstreamer=1.14.1=h5eee18b_1 57 | - icu=73.1=h6a678d5_0 58 | - idna=3.4=py311h06a4308_0 59 | - imagecodecs=2023.1.23=py311h8105a5c_0 60 | - imageio=2.31.4=py311h06a4308_0 61 | - importlib-metadata=7.0.1=py311h06a4308_0 62 | - intel-openmp=2023.1.0=hdb19cb5_46306 63 | - jinja2=3.1.2=py311h06a4308_0 64 | - joblib=1.2.0=py311h06a4308_0 65 | - jpeg=9e=h5eee18b_1 66 | - jxrlib=1.1=h7b6447c_2 67 | - kiwisolver=1.4.4=py311h6a678d5_0 68 | - krb5=1.20.1=h143b758_1 69 | - lame=3.100=h7b6447c_0 70 | - lazy_loader=0.3=py311h06a4308_0 71 | - 
lcms2=2.12=h3be6417_0 72 | - ld_impl_linux-64=2.38=h1181459_1 73 | - lerc=3.0=h295c915_0 74 | - libaec=1.0.4=he6710b0_1 75 | - libavif=0.11.1=h5eee18b_0 76 | - libbrotlicommon=1.0.9=h5eee18b_7 77 | - libbrotlidec=1.0.9=h5eee18b_7 78 | - libbrotlienc=1.0.9=h5eee18b_7 79 | - libclang=14.0.6=default_hc6dbbc7_1 80 | - libclang13=14.0.6=default_he11475f_1 81 | - libcublas=12.1.0.26=0 82 | - libcufft=11.0.2.4=0 83 | - libcufile=1.8.1.2=0 84 | - libcups=2.4.2=h2d74bed_1 85 | - libcurand=10.3.4.107=0 86 | - libcurl=8.5.0=h251f7ec_0 87 | - libcusolver=11.4.4.55=0 88 | - libcusparse=12.0.2.55=0 89 | - libdeflate=1.17=h5eee18b_1 90 | - libedit=3.1.20230828=h5eee18b_0 91 | - libev=4.33=h7f8727e_1 92 | - libffi=3.4.4=h6a678d5_0 93 | - libgcc-ng=13.2.0=h807b86a_5 94 | - libgfortran-ng=11.2.0=h00389a5_1 95 | - libgfortran5=11.2.0=h1234567_1 96 | - libgomp=13.2.0=h807b86a_5 97 | - libiconv=1.16=h7f8727e_2 98 | - libidn2=2.3.4=h5eee18b_0 99 | - libjpeg-turbo=2.0.0=h9bf148f_0 100 | - libllvm14=14.0.6=hdb19cb5_3 101 | - libnghttp2=1.57.0=h2d74bed_0 102 | - libnpp=12.0.2.50=0 103 | - libnvjitlink=12.1.105=0 104 | - libnvjpeg=12.1.1.14=0 105 | - libpng=1.6.39=h5eee18b_0 106 | - libpq=12.17=hdbd6064_0 107 | - libssh2=1.10.0=hdbd6064_2 108 | - libstdcxx-ng=11.2.0=h1234567_1 109 | - libtasn1=4.19.0=h5eee18b_0 110 | - libtiff=4.5.1=h6a678d5_0 111 | - libunistring=0.9.10=h27cfd23_0 112 | - libuuid=1.41.5=h5eee18b_0 113 | - libwebp=1.3.2=h11a3e52_0 114 | - libwebp-base=1.3.2=h5eee18b_0 115 | - libxcb=1.15=h7f8727e_0 116 | - libxkbcommon=1.0.1=h5eee18b_1 117 | - libxml2=2.10.4=hf1b16e4_1 118 | - libzopfli=1.0.3=he6710b0_0 119 | - lightning-utilities=0.10.1=pyhd8ed1ab_0 120 | - llvm-openmp=14.0.6=h9e868ea_0 121 | - locket=1.0.0=py311h06a4308_0 122 | - lz4-c=1.9.4=h6a678d5_0 123 | - markupsafe=2.1.3=py311h5eee18b_0 124 | - matplotlib=3.8.0=py311h06a4308_0 125 | - matplotlib-base=3.8.0=py311ha02d727_0 126 | - mkl=2023.1.0=h213fc3f_46344 127 | - mkl-service=2.4.0=py311h5eee18b_1 128 | - mkl_fft=1.3.8=py311h5eee18b_0 129 | - mkl_random=1.2.4=py311hdb19cb5_0 130 | - mpc=1.1.0=h10f8cd9_1 131 | - mpfr=4.0.2=hb69a4c5_1 132 | - mpmath=1.3.0=py311h06a4308_0 133 | - munkres=1.1.4=py_0 134 | - mysql=5.7.24=h721c034_2 135 | - ncurses=6.4=h6a678d5_0 136 | - nettle=3.7.3=hbbd107a_1 137 | - networkx=3.1=py311h06a4308_0 138 | - numexpr=2.8.7=py311h65dcdc2_0 139 | - numpy=1.24.3=py311h08b1b3b_1 140 | - numpy-base=1.24.3=py311hf175353_1 141 | - openh264=2.1.1=h4ff587b_0 142 | - openjpeg=2.4.0=h3ad879b_0 143 | - openssl=3.2.1=hd590300_0 144 | - pandas=2.1.4=py311ha02d727_0 145 | - param=2.0.2=py311h06a4308_0 146 | - partd=1.4.1=py311h06a4308_0 147 | - pcre=8.45=h295c915_0 148 | - pillow=10.0.1=py311ha6cbd5a_0 149 | - pip=23.3.1=py311h06a4308_0 150 | - ply=3.11=py311h06a4308_0 151 | - pycparser=2.21=pyhd3eb1b0_0 152 | - pyct=0.5.0=py311h06a4308_0 153 | - pyopenssl=23.2.0=py311h06a4308_0 154 | - pyparsing=3.0.9=py311h06a4308_0 155 | - pyqt=5.15.10=py311h6a678d5_0 156 | - pyqt5-sip=12.13.0=py311h5eee18b_0 157 | - pysocks=1.7.1=py311h06a4308_0 158 | - python=3.11.7=h955ad1f_0 159 | - python-dateutil=2.8.2=pyhd3eb1b0_0 160 | - python-tzdata=2023.3=pyhd3eb1b0_0 161 | - pytorch=2.2.0=py3.11_cuda12.1_cudnn8.9.2_0 162 | - pytorch-cuda=12.1=ha16c6d3_5 163 | - pytorch-mutex=1.0=cuda 164 | - pytz=2023.3.post1=py311h06a4308_0 165 | - pywavelets=1.5.0=py311hf4808d0_0 166 | - pyyaml=6.0.1=py311h5eee18b_0 167 | - qt-main=5.15.2=h53bd1ea_10 168 | - readline=8.2=h5eee18b_0 169 | - requests=2.31.0=py311h06a4308_0 170 | - scikit-image=0.20.0=py311h6a678d5_0 
171 | - scikit-learn=1.3.0=py311ha02d727_1 172 | - scipy=1.11.4=py311h08b1b3b_0 173 | - setuptools=68.2.2=py311h06a4308_0 174 | - sip=6.7.12=py311h6a678d5_0 175 | - six=1.16.0=pyhd3eb1b0_1 176 | - snappy=1.1.10=h6a678d5_1 177 | - sqlite=3.41.2=h5eee18b_0 178 | - sympy=1.12=py311h06a4308_0 179 | - tbb=2021.8.0=hdb19cb5_0 180 | - threadpoolctl=2.2.0=pyh0d69192_0 181 | - tifffile=2023.4.12=py311h06a4308_0 182 | - tk=8.6.12=h1ccaba5_0 183 | - toolz=0.12.0=py311h06a4308_0 184 | - torchaudio=2.2.0=py311_cu121 185 | - torchmetrics=1.2.1=pyhd8ed1ab_0 186 | - torchtriton=2.2.0=py311 187 | - torchvision=0.17.0=py311_cu121 188 | - tornado=6.3.3=py311h5eee18b_0 189 | - typing_extensions=4.9.0=py311h06a4308_1 190 | - tzdata=2023d=h04d1e81_0 191 | - urllib3=1.26.18=py311h06a4308_0 192 | - wheel=0.41.2=py311h06a4308_0 193 | - xz=5.4.5=h5eee18b_0 194 | - yaml=0.2.5=h7b6447c_0 195 | - zfp=1.0.0=h6a678d5_0 196 | - zipp=3.17.0=py311h06a4308_0 197 | - zlib=1.2.13=h5eee18b_0 198 | - zstd=1.5.5=hc292b87_0 199 | - pip: 200 | - appdirs==1.4.4 201 | - asttokens==2.4.1 202 | - decorator==5.1.1 203 | - docker-pycreds==0.4.0 204 | - executing==2.0.1 205 | - fsspec==2023.12.2 206 | - gitdb==4.0.11 207 | - gitpython==3.1.41 208 | - huggingface-hub==0.20.3 209 | - ipython==8.20.0 210 | - jedi==0.19.1 211 | - matplotlib-inline==0.1.6 212 | - opencv-python-headless==4.9.0.80 213 | - packaging==23.2 214 | - parso==0.8.3 215 | - pexpect==4.9.0 216 | - prompt-toolkit==3.0.43 217 | - protobuf==4.25.2 218 | - psutil==5.9.8 219 | - ptyprocess==0.7.0 220 | - pure-eval==0.2.2 221 | - pycocotools==2.0.7 222 | - pygments==2.17.2 223 | - safetensors==0.4.2 224 | - sentry-sdk==1.39.2 225 | - setproctitle==1.3.3 226 | - smmap==5.0.1 227 | - stack-data==0.6.3 228 | - timm==0.9.12 229 | - tqdm==4.66.1 230 | - traitlets==5.14.1 231 | - wandb==0.16.3 232 | - wcwidth==0.2.13 233 | -------------------------------------------------------------------------------- /evaluate_parts.py: -------------------------------------------------------------------------------- 1 | """ 2 | From: https://github.com/zxhuang1698/interpretability-by-parts/blob/master/src/cub200/eval_interp.py 3 | """ 4 | # pytorch & misc 5 | import torch 6 | import torchvision.transforms as transforms 7 | from data_sets import FineGrainedBirdClassificationParts, PartImageNetDataset, Flowers102Seg 8 | from load_model import load_model_pdisco 9 | import argparse 10 | import copy 11 | from engine.eval_interpretability_nmi_ari_keypoint import eval_nmi_ari, eval_kpr 12 | from engine.eval_fg_bg import FgBgIoU 13 | from utils.training_utils.engine_utils import load_state_dict_pdisco 14 | 15 | torch.multiprocessing.set_sharing_strategy('file_system') 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='Evaluate model interpretability via part parsing quality' 21 | ) 22 | parser.add_argument('--model_arch', default='resnet50', type=str, 23 | help='pick model architecture') 24 | parser.add_argument('--use_torchvision_resnet_model', default=False, action='store_true') 25 | 26 | # Data 27 | parser.add_argument('--data_path', 28 | help='directory that contains cub files, must' 29 | 'contain folder "./images"', required=True) 30 | parser.add_argument('--image_sub_path', default='images', type=str, required=False) 31 | parser.add_argument('--dataset', default='cub', type=str) 32 | parser.add_argument('--anno_path_test', default='', type=str, required=False) 33 | parser.add_argument('--center_crop', default=False, action='store_true') 34 | 35 | # Eval mode 
36 | parser.add_argument('--eval_mode', default='keypoint', choices=['keypoint', 'nmi_ari', 'fg_bg'], type=str) 37 | 38 | # Model params 39 | parser.add_argument('--num_parts', help='number of parts to predict', 40 | default=8, type=int) 41 | parser.add_argument('--image_size', default=448, type=int) 42 | parser.add_argument('--output_stride', default=32, type=int) 43 | parser.add_argument('--batch_size', default=1, type=int) 44 | parser.add_argument('--num_workers', default=1, type=int) 45 | # Modulation 46 | parser.add_argument('--modulation_type', default="original", 47 | choices=["original", "layer_norm", "parallel_mlp", "parallel_mlp_no_bias", 48 | "parallel_mlp_no_act", "parallel_mlp_no_act_no_bias", "none"], 49 | type=str) 50 | parser.add_argument('--modulation_orth', default=False, action='store_true', 51 | help='use orthogonality loss on modulated features') 52 | # Part Dropout 53 | parser.add_argument('--part_dropout', default=0.0, type=float) 54 | 55 | # Add noise to vit output features 56 | parser.add_argument('--noise_variance', default=0.0, type=float) 57 | 58 | # Gumbel Softmax 59 | parser.add_argument('--gumbel_softmax', default=False, action='store_true') 60 | parser.add_argument('--gumbel_softmax_temperature', default=1.0, type=float) 61 | parser.add_argument('--gumbel_softmax_hard', default=False, action='store_true') 62 | 63 | # Model path 64 | parser.add_argument('--model_path', default=None, type=str) 65 | 66 | # Classifier type 67 | parser.add_argument('--classifier_type', default="linear", 68 | choices=["linear", "independent_mlp"], type=str) 69 | 70 | args = parser.parse_args() 71 | return args 72 | 73 | 74 | def main(args): 75 | mode = args.eval_mode 76 | nparts = args.num_parts 77 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 78 | resize_transform = transforms.Resize(size=args.image_size) 79 | resize_transform_mask = transforms.Resize(size=args.image_size, interpolation=transforms.InterpolationMode.NEAREST) 80 | center_crop_transform = transforms.CenterCrop(size=args.image_size) 81 | def_transform = transforms.ToTensor() 82 | if "vit" in args.model_arch: 83 | if not args.center_crop: 84 | raise ValueError('ViT models require center crop.') 85 | 86 | if args.center_crop and args.dataset != 'cub': 87 | data_transforms = transforms.Compose([resize_transform, center_crop_transform, def_transform]) 88 | mask_transform = transforms.Compose([resize_transform_mask, center_crop_transform]) 89 | 90 | else: 91 | data_transforms = transforms.Compose([resize_transform, def_transform]) 92 | mask_transform = resize_transform_mask 93 | 94 | # define dataset path 95 | if args.dataset == 'cub': 96 | cub_path = args.data_path 97 | # define dataset and loader 98 | eval_data = FineGrainedBirdClassificationParts(cub_path, 99 | train=False, transform=data_transforms, resize=args.image_size, center_crop=args.center_crop, 100 | image_sub_path=args.image_sub_path) 101 | elif args.dataset == 'part_imagenet': 102 | # define dataset and loader 103 | eval_data = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path, 104 | transform=data_transforms, 105 | annotation_file_path=args.anno_path_test, 106 | get_masks=True, 107 | mask_transform=mask_transform, 108 | ) 109 | 110 | elif args.dataset == 'flowers102seg': 111 | # define dataset and loader 112 | eval_data = Flowers102Seg(args.data_path, transform=data_transforms, mask_transform=mask_transform, split='test') 113 | 114 | else: 115 | raise ValueError('Dataset not supported.') 116 | 117 | 
eval_loader = torch.utils.data.DataLoader( 118 | eval_data, batch_size=args.batch_size, shuffle=False, 119 | num_workers=args.num_workers, pin_memory=True, drop_last=True) 120 | 121 | num_cls = eval_data.num_classes 122 | 123 | # Add arguments to args 124 | args.eval_only = True 125 | args.pretrained_start_weights = True 126 | 127 | # Load the model 128 | net = load_model_pdisco(args, num_cls) 129 | snapshot_data = torch.load(args.model_path, map_location=torch.device('cpu'), weights_only=True) 130 | if 'model_state' in snapshot_data: 131 | _, state_dict = load_state_dict_pdisco(snapshot_data) 132 | else: 133 | state_dict = copy.deepcopy(snapshot_data) 134 | net.load_state_dict(state_dict, strict=True) 135 | net.eval() 136 | net.to(device) 137 | 138 | if mode == 'keypoint': 139 | if args.dataset == 'cub': 140 | fit_data = FineGrainedBirdClassificationParts(args.data_path, 141 | train=True, transform=data_transforms, resize=args.image_size, 142 | center_crop=args.center_crop) 143 | fit_loader = torch.utils.data.DataLoader( 144 | fit_data, batch_size=args.batch_size, shuffle=True, 145 | num_workers=args.num_workers, pin_memory=True, drop_last=True) 146 | kpr = eval_kpr(net, fit_loader, eval_loader, nparts, num_landmarks=eval_data.num_kps, device=device) 147 | print('Mean keypoint regression error on the test set is %.2f%%.' % kpr) 148 | else: 149 | raise ValueError('Dataset not supported.') 150 | 151 | elif mode == 'nmi_ari': 152 | nmi, ari = eval_nmi_ari(net, eval_loader, dataset=args.dataset, device=device) 153 | print(nmi) 154 | print(ari) 155 | print('NMI between predicted and ground truth parts is %.2f' % nmi) 156 | print('ARI between predicted and ground truth parts is %.2f' % ari) 157 | print('Evaluation finished.') 158 | 159 | elif mode == 'fg_bg': 160 | if args.dataset != 'flowers102seg': 161 | raise ValueError('Dataset not supported.') 162 | iou_calculator = FgBgIoU(net, eval_loader, device=device) 163 | iou_calculator.calculate_iou(args.model_path) 164 | m_iou = iou_calculator.metric_fg.compute().item() * 100 165 | m_iou_bg = iou_calculator.metric_bg.compute().item() * 100 166 | print('Foreground mIoU is %.2f' % m_iou) 167 | print('Background mIoU is %.2f' % m_iou_bg) 168 | print('Evaluation finished.') 169 | 170 | else: 171 | print("Please run with either keypoint or nmi_ari or fg_bg mode.") 172 | 173 | 174 | if __name__ == '__main__': 175 | arguments = parse_args() 176 | main(arguments) 177 | -------------------------------------------------------------------------------- /models/vit_baseline.py: -------------------------------------------------------------------------------- 1 | # Compostion of the VisionTransformer class from timm with extra features: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py 2 | import torch 3 | import torch.nn as nn 4 | from typing import Tuple, Union, Sequence, Any 5 | from timm.layers import trunc_normal_ 6 | from timm.models.vision_transformer import Block, Attention 7 | from layers.transformer_layers import BlockWQKVReturn, AttentionWQKVReturn 8 | 9 | from utils.misc_utils import compute_attention 10 | 11 | 12 | class BaselineViT(torch.nn.Module): 13 | """ 14 | Modifications: 15 | - Use PDiscoBlock instead of Block 16 | - Use PDiscoAttention instead of Attention 17 | - Return the mean of k over heads from attention 18 | - Option to use only class tokens or only patch tokens or both (concat) for classification 19 | """ 20 | 21 | def __init__(self, init_model: torch.nn.Module, num_classes: int, 22 | 
class_tokens_only: bool = False, 23 | patch_tokens_only: bool = False, return_transformer_qkv: bool = False) -> None: 24 | super().__init__() 25 | self.num_classes = num_classes 26 | self.class_tokens_only = class_tokens_only 27 | self.patch_tokens_only = patch_tokens_only 28 | self.num_prefix_tokens = init_model.num_prefix_tokens 29 | self.num_reg_tokens = init_model.num_reg_tokens 30 | self.has_class_token = init_model.has_class_token 31 | self.no_embed_class = init_model.no_embed_class 32 | self.cls_token = init_model.cls_token 33 | self.reg_token = init_model.reg_token 34 | 35 | self.patch_embed = init_model.patch_embed 36 | 37 | self.pos_embed = init_model.pos_embed 38 | self.pos_drop = init_model.pos_drop 39 | self.part_embed = nn.Identity() 40 | self.patch_prune = nn.Identity() 41 | self.norm_pre = init_model.norm_pre 42 | self.blocks = init_model.blocks 43 | self.norm = init_model.norm 44 | 45 | self.fc_norm = init_model.fc_norm 46 | if class_tokens_only or patch_tokens_only: 47 | self.head = nn.Linear(init_model.embed_dim, num_classes) 48 | else: 49 | self.head = nn.Linear(init_model.embed_dim * 2, num_classes) 50 | 51 | self.h_fmap = int(self.patch_embed.img_size[0] // self.patch_embed.patch_size[0]) 52 | self.w_fmap = int(self.patch_embed.img_size[1] // self.patch_embed.patch_size[1]) 53 | 54 | self.return_transformer_qkv = return_transformer_qkv 55 | self.convert_blocks_and_attention() 56 | self._init_weights_head() 57 | 58 | def convert_blocks_and_attention(self): 59 | for module in self.modules(): 60 | if isinstance(module, Block): 61 | module.__class__ = BlockWQKVReturn 62 | elif isinstance(module, Attention): 63 | module.__class__ = AttentionWQKVReturn 64 | 65 | def _pos_embed(self, x: torch.Tensor) -> torch.Tensor: 66 | pos_embed = self.pos_embed 67 | to_cat = [] 68 | if self.cls_token is not None: 69 | to_cat.append(self.cls_token.expand(x.shape[0], -1, -1)) 70 | if self.reg_token is not None: 71 | to_cat.append(self.reg_token.expand(x.shape[0], -1, -1)) 72 | if self.no_embed_class: 73 | # deit-3, updated JAX (big vision) 74 | # position embedding does not overlap with class token, add then concat 75 | x = x + pos_embed 76 | if to_cat: 77 | x = torch.cat(to_cat + [x], dim=1) 78 | else: 79 | # original timm, JAX, and deit vit impl 80 | # pos_embed has entry for class token, concat then add 81 | if to_cat: 82 | x = torch.cat(to_cat + [x], dim=1) 83 | x = x + pos_embed 84 | return self.pos_drop(x) 85 | 86 | def _init_weights_head(self): 87 | trunc_normal_(self.head.weight, std=.02) 88 | if self.head.bias is not None: 89 | nn.init.constant_(self.head.bias, 0.) 
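    # Usage sketch (illustrative only; the backbone checkpoint name, image size and
    # number of classes below are example values, not requirements of this class):
    #
    #   import timm, torch
    #   backbone = timm.create_model("vit_base_patch14_reg4_dinov2.lvd142m",
    #                                pretrained=True, img_size=224)
    #   vit = BaselineViT(backbone, num_classes=200, class_tokens_only=True)
    #   logits = vit(torch.randn(1, 3, 224, 224))  # -> shape [1, 200]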
90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]: 92 | 93 | x = self.patch_embed(x) 94 | 95 | # Position Embedding 96 | x = self._pos_embed(x) 97 | 98 | x = self.part_embed(x) 99 | x = self.patch_prune(x) 100 | 101 | # Forward pass through transformer 102 | x = self.norm_pre(x) 103 | 104 | if self.return_transformer_qkv: 105 | # Return keys of last attention layer 106 | for i, blk in enumerate(self.blocks): 107 | x, qkv = blk(x, return_qkv=True) 108 | else: 109 | x = self.blocks(x) 110 | 111 | x = self.norm(x) 112 | 113 | # Classification head 114 | x = self.fc_norm(x) 115 | if self.class_tokens_only: # only use class token 116 | x = x[:, 0, :] 117 | elif self.patch_tokens_only: # only use patch tokens 118 | x = x[:, self.num_prefix_tokens:, :].mean(dim=1) 119 | else: 120 | x = torch.cat([x[:, 0, :], x[:, self.num_prefix_tokens:, :].mean(dim=1)], dim=1) 121 | x = self.head(x) 122 | if self.return_transformer_qkv: 123 | return x, qkv 124 | else: 125 | return x 126 | 127 | def get_specific_intermediate_layer( 128 | self, 129 | x: torch.Tensor, 130 | n: int = 1, 131 | return_qkv: bool = False, 132 | return_att_weights: bool = False, 133 | ): 134 | num_blocks = len(self.blocks) 135 | attn_weights = [] 136 | if n >= num_blocks: 137 | raise ValueError(f"n must be less than {num_blocks}") 138 | 139 | # forward pass 140 | x = self.patch_embed(x) 141 | x = self._pos_embed(x) 142 | x = self.norm_pre(x) 143 | 144 | if n == -1: 145 | if return_qkv: 146 | raise ValueError("take_indice cannot be -1 if return_transformer_qkv is True") 147 | else: 148 | return x 149 | 150 | for i, blk in enumerate(self.blocks): 151 | if self.return_transformer_qkv: 152 | x, qkv = blk(x, return_qkv=True) 153 | 154 | if return_att_weights: 155 | attn_weight, _ = compute_attention(qkv) 156 | attn_weights.append(attn_weight.detach()) 157 | else: 158 | x = blk(x) 159 | if i == n: 160 | output = x.clone() 161 | if self.return_transformer_qkv and return_qkv: 162 | qkv_output = qkv.clone() 163 | break 164 | if self.return_transformer_qkv and return_qkv and return_att_weights: 165 | return output, qkv_output, attn_weights 166 | elif self.return_transformer_qkv and return_qkv: 167 | return output, qkv_output 168 | elif self.return_transformer_qkv and return_att_weights: 169 | return output, attn_weights 170 | else: 171 | return output 172 | 173 | def _intermediate_layers( 174 | self, 175 | x: torch.Tensor, 176 | n: Union[int, Sequence] = 1, 177 | ): 178 | outputs, num_blocks = [], len(self.blocks) 179 | if self.return_transformer_qkv: 180 | qkv_outputs = [] 181 | take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n) 182 | 183 | # forward pass 184 | x = self.patch_embed(x) 185 | x = self._pos_embed(x) 186 | x = self.norm_pre(x) 187 | 188 | for i, blk in enumerate(self.blocks): 189 | if self.return_transformer_qkv: 190 | x, qkv = blk(x, return_qkv=True) 191 | else: 192 | x = blk(x) 193 | if i in take_indices: 194 | outputs.append(x) 195 | if self.return_transformer_qkv: 196 | qkv_outputs.append(qkv) 197 | if self.return_transformer_qkv: 198 | return outputs, qkv_outputs 199 | else: 200 | return outputs 201 | 202 | def get_intermediate_layers( 203 | self, 204 | x: torch.Tensor, 205 | n: Union[int, Sequence] = 1, 206 | reshape: bool = False, 207 | return_prefix_tokens: bool = False, 208 | norm: bool = False, 209 | ) -> tuple[tuple, Any]: 210 | """ Intermediate layer accessor (NOTE: This is a WIP experiment). 
211 | Inspired by DINO / DINOv2 interface 212 | """ 213 | # take last n blocks if n is an int, if in is a sequence, select by matching indices 214 | if self.return_transformer_qkv: 215 | outputs, qkv = self._intermediate_layers(x, n) 216 | else: 217 | outputs = self._intermediate_layers(x, n) 218 | 219 | if norm: 220 | outputs = [self.norm(out) for out in outputs] 221 | prefix_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs] 222 | outputs = [out[:, self.num_prefix_tokens:] for out in outputs] 223 | 224 | if reshape: 225 | grid_size = self.patch_embed.grid_size 226 | outputs = [ 227 | out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous() 228 | for out in outputs 229 | ] 230 | 231 | if return_prefix_tokens: 232 | return_out = tuple(zip(outputs, prefix_tokens)) 233 | else: 234 | return_out = tuple(outputs) 235 | 236 | if self.return_transformer_qkv: 237 | return return_out, qkv 238 | else: 239 | return return_out 240 | -------------------------------------------------------------------------------- /data_sets/celeba.py: -------------------------------------------------------------------------------- 1 | """Adapted from: https://github.com/zxhuang1698/interpretability-by-parts/""" 2 | 3 | import torch 4 | import torch.utils.data as data 5 | import os 6 | import os.path 7 | import pickle 8 | import numpy as np 9 | from collections import defaultdict 10 | from utils.data_utils.dataset_utils import pil_loader 11 | 12 | 13 | class CelebA(data.Dataset): 14 | """ 15 | CelebA dataset. 16 | Variables 17 | ---------- 18 | root, str: Root directory of the dataset. 19 | split, str: Current data split. 20 | "train": Training split without MAFL images. (For localization) 21 | "train_full": Training split with MAFL images. (For classification) 22 | "val": Validation split for classification accuracy. 23 | "test": Testing split for classification accuracy. 24 | "fit": Split for fitting the linear regressor. 25 | "eval": Split for evaluating the linear regressor. 26 | align, bool: Whether use aligned version or not. 27 | percentage, float: For unaligned version, the least percentage of (face area / image area) 28 | transform, callable: A function/transform that takes in a PIL.Image and transforms it. 
29 | resize, tuple: The size of image (h, w) after transformation (This version does not support cropping) 30 | """ 31 | 32 | def __init__(self, root, 33 | split='train', align=False, 34 | percentage=None, transform=None, resize=256, image_sub_path="unaligned"): 35 | 36 | self.root = root 37 | self.split = split 38 | self.align = align 39 | self.resize = resize 40 | self.image_sub_path = image_sub_path 41 | # load the dictionary for data 42 | align_name = '_aligned' if align else '_unaligned' 43 | percentage_name = '_0' if percentage is None else '_' + str(int(percentage * 100)) 44 | save_name = os.path.join(root, split + align_name + percentage_name + '.pickle') 45 | self.shuffle = np.arange(182637) 46 | np.random.shuffle(self.shuffle) 47 | if os.path.exists(save_name) is False: 48 | print('Preparing the data...') 49 | self.generate_dict(save_name) 50 | print('Data dictionary created and saved.') 51 | with open(save_name, 'rb') as handle: 52 | save_dict = pickle.load(handle) 53 | 54 | self.images = save_dict['images'] # image filenames 55 | self.landmarks = save_dict['landmarks'] # 5 face landmarks 56 | self.targets = save_dict['targets'] # binary labels 57 | self.bboxes = save_dict['bboxes'] # x y w h 58 | self.sizes = save_dict['sizes'] # height width 59 | self.identities = save_dict['identities'] 60 | self.transform = transform 61 | self.loader = pil_loader 62 | 63 | # select a subset of the current data split according the face area 64 | if percentage is not None: 65 | new_images = [] 66 | new_landmarks = [] 67 | new_targets = [] 68 | new_bboxes = [] 69 | new_sizes = [] 70 | new_identities = [] 71 | for i in range(len(self.images)): 72 | if float(self.bboxes[i][-1] * self.bboxes[i][-2]) >= float( 73 | self.sizes[i][-1] * self.sizes[i][-2]) * percentage: 74 | new_images.append(self.images[i]) 75 | new_landmarks.append(self.landmarks[i]) 76 | new_targets.append(self.targets[i]) 77 | new_bboxes.append(self.bboxes[i]) 78 | new_sizes.append(self.sizes[i]) 79 | new_identities.append(self.identities[i]) 80 | self.images = new_images 81 | self.landmarks = new_landmarks 82 | self.targets = new_targets 83 | self.bboxes = new_bboxes 84 | self.sizes = new_sizes 85 | self.identities = new_identities 86 | print('Number of classes in the ' + self.split + ' split: ' + str(max(self.identities))) 87 | print('Number of samples in the ' + self.split + ' split: ' + str(len(self.images))) 88 | self.num_classes = max(self.identities) 89 | self.per_class_count = defaultdict(int) 90 | for label in self.identities: 91 | self.per_class_count[label] += 1 92 | self.cls_num_list = [self.per_class_count[idx] for idx in range(self.num_classes)] 93 | 94 | # generate a dictionary for a certain data split 95 | def generate_dict(self, save_name): 96 | 97 | print('Start generating data dictionary as ' + save_name) 98 | 99 | full_img_list = [] 100 | ann_file = 'list_attr_celeba.txt' 101 | bbox_file = 'list_bbox_celeba.txt' 102 | size_file = 'list_imsize_celeba.txt' 103 | identity_file = 'identity_CelebA.txt' 104 | 105 | if self.align is True: 106 | landmark_file = 'list_landmarks_align_celeba.txt' 107 | else: 108 | landmark_file = 'list_landmarks_unalign_celeba.txt' 109 | 110 | # load all the images according to the current split 111 | if self.split == 'train': 112 | imgfile = 'celebA_training.txt' 113 | elif self.split == 'val': 114 | imgfile = 'celebA_validating.txt' 115 | elif self.split == 'test': 116 | imgfile = 'celebA_testing.txt' 117 | elif self.split == 'fit': 118 | imgfile = 'MAFL_training.txt' 119 | elif 
self.split == 'eval': 120 | imgfile = 'MAFL_testing.txt' 121 | elif self.split == 'train_full': 122 | imgfile = 'celebA_training_full.txt' 123 | for line in open(os.path.join(self.root, imgfile), 'r'): 124 | full_img_list.append(line.split()[0]) 125 | 126 | # prepare the indexes and convert annotation files to lists 127 | full_img_list_idx = [(int(s.rstrip(".jpg")) - 1) for s in full_img_list] 128 | ann_full_list = [line.split() for line in open(os.path.join(self.root, ann_file), 'r')] 129 | bbox_full_list = [line.split() for line in open(os.path.join(self.root, bbox_file), 'r')] 130 | size_full_list = [line.split() for line in open(os.path.join(self.root, size_file), 'r')] 131 | landmark_full_list = [line.split() for line in open(os.path.join(self.root, landmark_file), 'r')] 132 | identity_full_list = [line.split() for line in open(os.path.join(self.root, identity_file), 'r')] 133 | 134 | # assertion 135 | assert len(ann_full_list[0]) == 41 136 | assert len(bbox_full_list[0]) == 5 137 | assert len(size_full_list[0]) == 3 138 | assert len(landmark_full_list[0]) == 11 139 | 140 | # select samples and annotations for the current data split 141 | # init the lists 142 | filename_list = [] 143 | target_list = [] 144 | landmark_list = [] 145 | bbox_list = [] 146 | size_list = [] 147 | identity_list = [] 148 | 149 | # select samples and annotations 150 | for i in full_img_list_idx: 151 | idx = self.shuffle[i] 152 | 153 | # assertion 154 | assert (idx + 1) == int(ann_full_list[idx][0].rstrip(".jpg")) 155 | assert (idx + 1) == int(bbox_full_list[idx][0].rstrip(".jpg")) 156 | assert (idx + 1) == int(size_full_list[idx][0].rstrip(".jpg")) 157 | assert (idx + 1) == int(landmark_full_list[idx][0].rstrip(".jpg")) 158 | 159 | # append the filenames and annotations 160 | filename_list.append(ann_full_list[idx][0]) 161 | target_list.append([int(i) for i in ann_full_list[idx][1:]]) 162 | bbox_list.append([int(i) for i in bbox_full_list[idx][1:]]) 163 | size_list.append([int(i) for i in size_full_list[idx][1:]]) 164 | landmark_list_xy = [] 165 | for j in range(5): 166 | landmark_list_xy.append( 167 | [int(landmark_full_list[idx][1 + 2 * j]), int(landmark_full_list[idx][2 + 2 * j])]) 168 | landmark_list.append(landmark_list_xy) 169 | identity_list.append(int(identity_full_list[idx][1])) 170 | 171 | # expand the filename to the full path 172 | full_path_list = [os.path.join(self.root, self.image_sub_path, filename) for filename in filename_list] 173 | 174 | # create the dictionary and save it on the disk 175 | save_dict = dict() 176 | save_dict['images'] = full_path_list 177 | save_dict['landmarks'] = landmark_list 178 | save_dict['targets'] = target_list 179 | save_dict['bboxes'] = bbox_list 180 | save_dict['sizes'] = size_list 181 | save_dict['identities'] = identity_list 182 | with open(save_name, 'wb') as handle: 183 | pickle.dump(save_dict, handle) 184 | 185 | def __getitem__(self, index): 186 | """ 187 | Retrieve data samples. 188 | Args 189 | ---------- 190 | index: int 191 | Index of the sample. 192 | Returns 193 | ---------- 194 | sample: PIL.Image 195 | Image of the given index. 196 | identity: torch.LongTensor 197 | Corresponding identity labels for all images 198 | landmark_locs: torch.FloatTensor, [5, 2] 199 | Landmark annotations, column first. 
200 | """ 201 | # load images and targets 202 | path = self.images[index] 203 | sample = self.loader(path) 204 | identity = self.identities[index] - 1 205 | image = np.array(sample) 206 | if image.shape[-3] > image.shape[-2]: 207 | factor = self.resize / image.shape[-2] 208 | else: 209 | factor = self.resize / image.shape[-3] 210 | 211 | # transform the image and target 212 | if self.transform is not None: 213 | sample = self.transform(sample) 214 | 215 | # processing the landmarks 216 | landmark_locs = self.landmarks[index] 217 | landmark_locs = torch.LongTensor(landmark_locs).float() 218 | landmark_locs[:, 0] = landmark_locs[:, 0] * factor 219 | landmark_locs[:, 1] = landmark_locs[:, 1] * factor 220 | return sample, identity, landmark_locs 221 | 222 | def __len__(self): 223 | return len(self.images) 224 | -------------------------------------------------------------------------------- /load_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | from pathlib import Path 4 | 5 | import torch 6 | from timm.models import create_model 7 | from torchvision.models import get_model 8 | 9 | from models import pdiscoformer_vit_bb, pdisconet_vit_bb, pdisconet_resnet_torchvision_bb 10 | from models.individual_landmark_resnet import IndividualLandmarkResNet 11 | from models.individual_landmark_convnext import IndividualLandmarkConvNext 12 | from models.individual_landmark_vit import IndividualLandmarkViT 13 | from utils import load_state_dict_pdisco 14 | 15 | 16 | def load_model_arch(args, num_cls): 17 | """ 18 | Function to load the model 19 | :param args: Arguments from the command line 20 | :param num_cls: Number of classes in the dataset 21 | :return: 22 | """ 23 | if 'resnet' in args.model_arch: 24 | num_layers_split = [int(s) for s in args.model_arch if s.isdigit()] 25 | num_layers = int(''.join(map(str, num_layers_split))) 26 | if num_layers >= 100: 27 | timm_model_arch = args.model_arch + ".a1h_in1k" 28 | else: 29 | timm_model_arch = args.model_arch + ".a1_in1k" 30 | 31 | if "resnet" in args.model_arch and args.use_torchvision_resnet_model: 32 | weights = "DEFAULT" if args.pretrained_start_weights else None 33 | base_model = get_model(args.model_arch, weights=weights) 34 | elif "resnet" in args.model_arch and not args.use_torchvision_resnet_model: 35 | if args.eval_only: 36 | base_model = create_model( 37 | timm_model_arch, 38 | pretrained=args.pretrained_start_weights, 39 | num_classes=num_cls, 40 | output_stride=args.output_stride, 41 | ) 42 | else: 43 | base_model = create_model( 44 | timm_model_arch, 45 | pretrained=args.pretrained_start_weights, 46 | drop_path_rate=args.drop_path, 47 | num_classes=num_cls, 48 | output_stride=args.output_stride, 49 | ) 50 | 51 | elif "convnext" in args.model_arch: 52 | if args.eval_only: 53 | base_model = create_model( 54 | args.model_arch, 55 | pretrained=args.pretrained_start_weights, 56 | num_classes=num_cls, 57 | output_stride=args.output_stride, 58 | ) 59 | else: 60 | base_model = create_model( 61 | args.model_arch, 62 | pretrained=args.pretrained_start_weights, 63 | drop_path_rate=args.drop_path, 64 | num_classes=num_cls, 65 | output_stride=args.output_stride, 66 | ) 67 | elif "vit" in args.model_arch: 68 | if args.eval_only: 69 | base_model = create_model( 70 | args.model_arch, 71 | pretrained=args.pretrained_start_weights, 72 | img_size=args.image_size, 73 | ) 74 | else: 75 | base_model = create_model( 76 | args.model_arch, 77 | pretrained=args.pretrained_start_weights, 78 | 
drop_path_rate=args.drop_path, 79 | img_size=args.image_size, 80 | ) 81 | vit_patch_size = base_model.patch_embed.proj.kernel_size[0] 82 | if args.image_size % vit_patch_size != 0: 83 | raise ValueError(f"Image size {args.image_size} must be divisible by patch size {vit_patch_size}") 84 | else: 85 | raise ValueError('Model not supported.') 86 | 87 | return base_model 88 | 89 | 90 | def init_pdisco_model(base_model, args, num_cls): 91 | """ 92 | Function to initialize the model 93 | :param base_model: Base model 94 | :param args: Arguments from the command line 95 | :param num_cls: Number of classes in the dataset 96 | :return: 97 | """ 98 | # Initialize the network 99 | if 'convnext' in args.model_arch: 100 | sl_channels = base_model.stages[-1].downsample[-1].in_channels 101 | fl_channels = base_model.head.in_features 102 | model = IndividualLandmarkConvNext(base_model, args.num_parts, num_classes=num_cls, 103 | sl_channels=sl_channels, fl_channels=fl_channels, 104 | part_dropout=args.part_dropout, modulation_type=args.modulation_type, 105 | gumbel_softmax=args.gumbel_softmax, 106 | gumbel_softmax_temperature=args.gumbel_softmax_temperature, 107 | gumbel_softmax_hard=args.gumbel_softmax_hard, 108 | modulation_orth=args.modulation_orth, classifier_type=args.classifier_type, 109 | noise_variance=args.noise_variance) 110 | elif 'resnet' in args.model_arch: 111 | sl_channels = base_model.layer4[0].conv1.in_channels 112 | fl_channels = base_model.fc.in_features 113 | model = IndividualLandmarkResNet(base_model, args.num_parts, num_classes=num_cls, 114 | sl_channels=sl_channels, fl_channels=fl_channels, 115 | use_torchvision_model=args.use_torchvision_resnet_model, 116 | part_dropout=args.part_dropout, modulation_type=args.modulation_type, 117 | gumbel_softmax=args.gumbel_softmax, 118 | gumbel_softmax_temperature=args.gumbel_softmax_temperature, 119 | gumbel_softmax_hard=args.gumbel_softmax_hard, 120 | modulation_orth=args.modulation_orth, classifier_type=args.classifier_type, 121 | noise_variance=args.noise_variance) 122 | elif 'vit' in args.model_arch: 123 | model = IndividualLandmarkViT(base_model, num_landmarks=args.num_parts, num_classes=num_cls, 124 | part_dropout=args.part_dropout, 125 | modulation_type=args.modulation_type, gumbel_softmax=args.gumbel_softmax, 126 | gumbel_softmax_temperature=args.gumbel_softmax_temperature, 127 | gumbel_softmax_hard=args.gumbel_softmax_hard, 128 | modulation_orth=args.modulation_orth, classifier_type=args.classifier_type, 129 | noise_variance=args.noise_variance) 130 | else: 131 | raise ValueError('Model not supported.') 132 | 133 | return model 134 | 135 | 136 | def load_model_pdisco(args, num_cls): 137 | """ 138 | Function to load the model 139 | :param args: Arguments from the command line 140 | :param num_cls: Number of classes in the dataset 141 | :return: 142 | """ 143 | base_model = load_model_arch(args, num_cls) 144 | model = init_pdisco_model(base_model, args, num_cls) 145 | 146 | return model 147 | 148 | 149 | def pdiscoformer_vit(pretrained=True, model_dataset="cub", k=8, model_url="", img_size=224, num_cls=200): 150 | """ 151 | Function to load the PDiscoFormer model with ViT backbone 152 | :param pretrained: Boolean flag to load the pretrained weights 153 | :param model_dataset: Dataset for which the model is trained 154 | :param k: Number of unsupervised landmarks the model is trained on 155 | :param model_url: URL to load the model weights from 156 | :param img_size: Image size 157 | :param num_cls: Number of classes in the dataset 158 | 
:return: PDiscoFormer model with ViT backbone 159 | """ 160 | model = pdiscoformer_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size) 161 | if pretrained: 162 | hub_dir = torch.hub.get_dir() 163 | model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdiscoformer_{model_dataset}") 164 | 165 | Path(model_dir).mkdir(parents=True, exist_ok=True) 166 | url_path = model_url + str(k) + "_parts_snapshot_best.pt" 167 | snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu') 168 | if 'model_state' in snapshot_data: 169 | _, state_dict = load_state_dict_pdisco(snapshot_data) 170 | else: 171 | state_dict = copy.deepcopy(snapshot_data) 172 | model.load_state_dict(state_dict, strict=True) 173 | return model 174 | 175 | 176 | def pdisconet_vit(pretrained=True, model_dataset="nabirds", k=8, model_url="", img_size=224, num_cls=555): 177 | """ 178 | Function to load the PDiscoNet model with ViT backbone 179 | :param pretrained: Boolean flag to load the pretrained weights 180 | :param model_dataset: Dataset for which the model is trained 181 | :param k: Number of unsupervised landmarks the model is trained on 182 | :param model_url: URL to load the model weights from 183 | :param img_size: Image size 184 | :param num_cls: Number of classes in the dataset 185 | :return: PDiscoNet model with ViT backbone 186 | """ 187 | model = pdisconet_vit_bb("vit_base_patch14_reg4_dinov2.lvd142m", num_cls=num_cls, k=k, img_size=img_size) 188 | if pretrained: 189 | hub_dir = torch.hub.get_dir() 190 | model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}") 191 | 192 | Path(model_dir).mkdir(parents=True, exist_ok=True) 193 | url_path = model_url + str(k) + "_parts_snapshot_best.pt" 194 | snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu') 195 | if 'model_state' in snapshot_data: 196 | _, state_dict = load_state_dict_pdisco(snapshot_data) 197 | else: 198 | state_dict = copy.deepcopy(snapshot_data) 199 | model.load_state_dict(state_dict, strict=True) 200 | return model 201 | 202 | 203 | def pdisconet_resnet101(pretrained=True, model_dataset="nabirds", k=8, model_url="", num_cls=555): 204 | """ 205 | Function to load the PDiscoNet model with ResNet-101 backbone 206 | :param pretrained: Boolean flag to load the pretrained weights 207 | :param model_dataset: Dataset for which the model is trained 208 | :param k: Number of unsupervised landmarks the model is trained on 209 | :param model_url: URL to load the model weights from 210 | :param num_cls: Number of classes in the dataset 211 | :return: PDiscoNet model with ResNet-101 backbone 212 | """ 213 | model = pdisconet_resnet_torchvision_bb("resnet101", num_cls=num_cls, k=k) 214 | if pretrained: 215 | hub_dir = torch.hub.get_dir() 216 | model_dir = os.path.join(hub_dir, "pdiscoformer_checkpoints", f"pdisconet_{model_dataset}") 217 | 218 | Path(model_dir).mkdir(parents=True, exist_ok=True) 219 | url_path = model_url + str(k) + "_parts_snapshot_best.pt" 220 | snapshot_data = torch.hub.load_state_dict_from_url(url_path, model_dir=model_dir, map_location='cpu') 221 | if 'model_state' in snapshot_data: 222 | _, state_dict = load_state_dict_pdisco(snapshot_data) 223 | else: 224 | state_dict = copy.deepcopy(snapshot_data) 225 | model.load_state_dict(state_dict, strict=True) 226 | return model 227 | -------------------------------------------------------------------------------- /load_dataset.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from data_sets import FineGrainedBirdClassificationDataset, CelebA, ImageNetWithOODEval, PartImageNetDataset, PlantNet 4 | from torchvision import datasets 5 | from collections import Counter 6 | 7 | 8 | def get_dataset(args, train_transforms, test_transforms): 9 | if args.dataset == 'cub' or args.dataset == 'nabirds': 10 | dataset_train = FineGrainedBirdClassificationDataset(args.data_path, split=args.train_split, mode='train', 11 | transform=train_transforms, 12 | image_sub_path=args.image_sub_path_train) 13 | dataset_test = FineGrainedBirdClassificationDataset(args.data_path, mode=args.eval_mode, 14 | transform=test_transforms, 15 | image_sub_path=args.image_sub_path_test) 16 | num_cls = dataset_train.num_classes 17 | elif args.dataset == 'celeba': 18 | dataset_train = CelebA(args.data_path, split='train', align=False, percentage=0.3, 19 | transform=train_transforms, resize=args.image_size, 20 | image_sub_path=args.image_sub_path_train) 21 | dataset_test = CelebA(args.data_path, split=args.eval_mode, align=False, percentage=0.3, 22 | transform=test_transforms, resize=args.image_size, 23 | image_sub_path=args.image_sub_path_test) 24 | num_cls = dataset_train.num_classes 25 | elif args.dataset == 'pug': 26 | if args.eval_mode == 'val' or args.eval_mode == 'train': 27 | dataset_train = datasets.ImageFolder(os.path.join(args.data_path, args.image_sub_path_train), 28 | train_transforms) 29 | train_class_to_num_instances = Counter(dataset_train.targets) 30 | dataset_train.cls_num_list = [train_class_to_num_instances[idx] for idx in 31 | range(len(dataset_train.classes))] 32 | dataset_test = datasets.ImageFolder(os.path.join(args.data_path, args.image_sub_path_test), test_transforms) 33 | else: 34 | dataset_train_p1 = datasets.ImageFolder(os.path.join(args.data_path, args.image_sub_path_train), 35 | train_transforms) 36 | p1_class_to_num_instances = Counter(dataset_train_p1.targets) 37 | dataset_train_p2 = datasets.ImageFolder(os.path.join(args.data_path, 'val'), train_transforms) 38 | p2_class_to_num_instances = Counter(dataset_train_p2.targets) 39 | train_class_to_num_instances = p1_class_to_num_instances + p2_class_to_num_instances 40 | dataset_train = torch.utils.data.ConcatDataset([dataset_train_p1, dataset_train_p2]) 41 | dataset_train.cls_num_list = [train_class_to_num_instances[idx] for idx in 42 | range(len(dataset_train.classes))] 43 | dataset_test = datasets.ImageFolder(os.path.join(args.data_path, args.image_sub_path_test), test_transforms) 44 | test_class_to_num_instances = Counter(dataset_test.targets) 45 | dataset_test.cls_num_list = [test_class_to_num_instances[idx] for idx in range(len(dataset_test.classes))] 46 | num_cls = len(dataset_test.classes) 47 | elif args.dataset == 'imagenet': 48 | if args.eval_mode == 'val' or args.eval_mode == 'train': 49 | dataset_train = ImageNetWithOODEval(args.data_path, args.image_sub_path_train, 50 | transform=train_transforms) 51 | dataset_test = ImageNetWithOODEval(args.data_path, args.image_sub_path_test, 52 | transform=test_transforms) 53 | else: 54 | dataset_train_p1 = ImageNetWithOODEval(args.data_path, args.image_sub_path_train, 55 | transform=train_transforms) 56 | p1_cls_num_list = dataset_train_p1.cls_num_list 57 | dataset_train_p2 = ImageNetWithOODEval(args.data_path, 'val', 58 | transform=train_transforms) 59 | p2_cls_num_list = dataset_train_p2.cls_num_list 60 | train_cls_num_list = [p1_cls_num_list[idx] + 
p2_cls_num_list[idx] for idx in 61 | range(len(p1_cls_num_list))] 62 | dataset_train = torch.utils.data.ConcatDataset([dataset_train_p1, dataset_train_p2]) 63 | dataset_train.cls_num_list = train_cls_num_list 64 | dataset_test = ImageNetWithOODEval(args.data_path, args.image_sub_path_test, 65 | transform=test_transforms) 66 | num_cls = dataset_train.num_classes 67 | elif args.dataset == 'part_imagenet': 68 | if args.eval_mode == 'val' or args.eval_mode == 'train': 69 | dataset_train = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path_train, 70 | transform=train_transforms, 71 | annotation_file_path=args.anno_path_train) 72 | dataset_test = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path_test, 73 | transform=test_transforms, annotation_file_path=args.anno_path_test, 74 | class_names=dataset_train.class_names, 75 | class_names_to_idx=dataset_train.class_names_to_idx, 76 | class_idx_to_names=dataset_train.class_idx_to_names) 77 | else: 78 | dataset_train_p1 = PartImageNetDataset(data_path=args.data_path, 79 | image_sub_path=args.image_sub_path_train, 80 | transform=train_transforms, 81 | annotation_file_path=args.anno_path_train) 82 | p1_cls_num_list = dataset_train_p1.cls_num_list 83 | dataset_train_p2 = PartImageNetDataset(data_path=args.data_path, image_sub_path='val', 84 | transform=train_transforms, 85 | annotation_file_path=args.anno_path_train.replace('train', 'val'), 86 | class_names=dataset_train_p1.class_names, 87 | class_idx_to_names=dataset_train_p1.class_idx_to_names, 88 | class_names_to_idx=dataset_train_p1.class_names_to_idx) 89 | p2_cls_num_list = dataset_train_p2.cls_num_list 90 | train_cls_num_list = [p1_cls_num_list[idx] + p2_cls_num_list[idx] for idx in 91 | range(len(p1_cls_num_list))] 92 | dataset_train = torch.utils.data.ConcatDataset([dataset_train_p1, dataset_train_p2]) 93 | dataset_train.cls_num_list = train_cls_num_list 94 | dataset_test = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path_test, 95 | transform=test_transforms, annotation_file_path=args.anno_path_test, 96 | class_names=dataset_train_p1.class_names, 97 | class_idx_to_names=dataset_train_p1.class_idx_to_names, 98 | class_names_to_idx=dataset_train_p1.class_names_to_idx) 99 | 100 | num_cls = dataset_test.num_classes 101 | 102 | elif args.dataset == 'part_imagenet_ood': 103 | 104 | dataset_train = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path_train, 105 | transform=train_transforms, 106 | annotation_file_path=args.anno_path_train) 107 | dataset_test = PartImageNetDataset(data_path=args.data_path, image_sub_path=args.image_sub_path_test, 108 | transform=test_transforms, annotation_file_path=args.anno_path_test, 109 | class_names=dataset_train.class_names, 110 | class_names_to_idx=dataset_train.class_names_to_idx, 111 | class_idx_to_names=dataset_train.class_idx_to_names) 112 | num_cls = dataset_test.num_classes 113 | 114 | elif args.dataset == 'fgvc_aircraft': 115 | if args.eval_mode == 'val' or args.eval_mode == 'train': 116 | dataset_train = datasets.FGVCAircraft(root=args.data_path, split='train', transform=train_transforms, 117 | target_transform=None, download=True) 118 | dataset_test = datasets.FGVCAircraft(root=args.data_path, split='val', transform=test_transforms, 119 | target_transform=None, download=True) 120 | else: 121 | dataset_train = datasets.FGVCAircraft(root=args.data_path, split='trainval', transform=train_transforms, 122 | target_transform=None, 
download=True) 123 | dataset_test = datasets.FGVCAircraft(root=args.data_path, split='test', transform=test_transforms, 124 | target_transform=None, download=True) 125 | train_class_to_num_instances = Counter(dataset_train.targets) 126 | dataset_train.cls_num_list = [train_class_to_num_instances[idx] for idx in range(len(dataset_train.classes))] 127 | test_class_to_num_instances = Counter(dataset_test.targets) 128 | dataset_test.cls_num_list = [test_class_to_num_instances[idx] for idx in range(len(dataset_test.classes))] 129 | num_cls = len(dataset_test.classes) 130 | elif args.dataset == 'flowers102': 131 | if args.eval_mode == 'val' or args.eval_mode == 'train': 132 | dataset_train = datasets.Flowers102(root=args.data_path, split='train', transform=train_transforms, 133 | target_transform=None, download=True) 134 | dataset_test = datasets.Flowers102(root=args.data_path, split='val', transform=test_transforms, 135 | target_transform=None, download=True) 136 | else: 137 | dataset_train = datasets.Flowers102(root=args.data_path, split='train', transform=train_transforms, 138 | target_transform=None, download=True) 139 | dataset_test = datasets.Flowers102(root=args.data_path, split='test', transform=test_transforms, 140 | target_transform=None, download=True) 141 | num_cls = len(set(dataset_test._labels)) 142 | train_class_to_num_instances = Counter(dataset_train._labels) 143 | dataset_train.cls_num_list = [train_class_to_num_instances[idx] for idx in range(num_cls)] 144 | test_class_to_num_instances = Counter(dataset_test._labels) 145 | dataset_test.cls_num_list = [test_class_to_num_instances[idx] for idx in range(num_cls)] 146 | elif args.dataset == 'plantnet': 147 | dataset_train = PlantNet(args.data_path, args.image_sub_path_train, transform=train_transforms, 148 | metadata_path=args.metadata_path, 149 | species_id_to_name_file=args.species_id_to_name_file) 150 | dataset_test = PlantNet(args.data_path, args.image_sub_path_test, transform=test_transforms, 151 | metadata_path=args.metadata_path, 152 | species_id_to_name_file=args.species_id_to_name_file) 153 | num_cls = dataset_test.num_classes 154 | else: 155 | raise ValueError('Dataset not supported.') 156 | return dataset_train, dataset_test, num_cls 157 | -------------------------------------------------------------------------------- /model_zoo.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | We provide the pre-trained models for the following datasets: 4 | - CUB-200-2011 5 | - PartImageNet OOD 6 | - Oxford Flowers 7 | - PartImageNet Seg 8 | - NABirds 9 | 10 | The models can be downloaded from the links provided below. They can also be loaded using torch hub using the code snippets provided below. 
## How to Get Started with the Model with Hugging Face 🤗
```python
from models import IndividualLandmarkViT

# CUB Models
pdiscoformer_cub_k_4 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_4")
pdiscoformer_cub_k_8 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_8")
pdiscoformer_cub_k_16 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_cub_k_16")

# PartImageNet OOD Models
pdiscoformer_partimagenet_ood_k_8 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_8", input_size=224)
pdiscoformer_partimagenet_ood_k_25 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_25", input_size=224)
pdiscoformer_partimagenet_ood_k_50 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_part_imagenet_ood_k_50", input_size=224)

# Oxford Flowers Models
pdiscoformer_flowers_k_2 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_flowers_k_2", input_size=224)
pdiscoformer_flowers_k_4 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_flowers_k_4", input_size=224)
pdiscoformer_flowers_k_8 = IndividualLandmarkViT.from_pretrained("ananthu-aniraj/pdiscoformer_flowers_k_8", input_size=224)
```

## How to Get Started with the Model with Torch Hub

```python
import torch

# CUB Models
pdiscoformer_cub_k_4 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_cub_k_4', pretrained=True, trust_repo=True)
pdiscoformer_cub_k_8 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_cub_k_8', pretrained=True, trust_repo=True)
pdiscoformer_cub_k_16 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_cub_k_16', pretrained=True, trust_repo=True)

# PartImageNet OOD Models
pdiscoformer_partimagenet_ood_k_8 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_pimagenet_k_8', pretrained=True, trust_repo=True)
pdiscoformer_partimagenet_ood_k_25 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_pimagenet_k_25', pretrained=True, trust_repo=True)
pdiscoformer_partimagenet_ood_k_50 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_pimagenet_k_50', pretrained=True, trust_repo=True)

# Oxford Flowers Models
pdiscoformer_flowers_k_2 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_flowers_k_2', pretrained=True, trust_repo=True)
pdiscoformer_flowers_k_4 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_flowers_k_4', pretrained=True, trust_repo=True)
pdiscoformer_flowers_k_8 = torch.hub.load("ananthu-aniraj/pdiscoformer:main", 'pdiscoformer_flowers_k_8', pretrained=True, trust_repo=True)
```

The full list of available model keys can be retrieved with the following code snippet:

```python
import torch
torch.hub.list("ananthu-aniraj/pdiscoformer:main")
```

# Pre-trained Models
Please note that these models were retrained recently and may show slight deviations in performance compared to the results reported in the paper.

## CUB-200-2011
| Model | Backbone | K | URL |
|---|---|---|---|
| PDiscoFormer | ViT-B | 4 | Download |
| PDiscoFormer | ViT-B | 8 | Download |
| PDiscoFormer | ViT-B | 16 | Download |
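As a minimal, illustrative sketch of getting one of the CUB checkpoints ready for evaluation: the hub entry name (`pdiscoformer_cub_k_8`) is taken from the Torch Hub snippet above, while the device handling and parameter count below are generic PyTorch and not specific to this repository.

```python
import torch

# Load the K=8 CUB checkpoint via Torch Hub (entry name as listed above)
model = torch.hub.load("ananthu-aniraj/pdiscoformer:main", "pdiscoformer_cub_k_8",
                       pretrained=True, trust_repo=True)

# Switch to evaluation mode and move to GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

# Quick sanity check: report the total number of parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"Loaded PDiscoFormer (CUB, K=8) with {num_params / 1e6:.1f}M parameters")
```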
## PartImageNet OOD

| Model | Backbone | K | URL |
|---|---|---|---|
| PDiscoFormer | ViT-B | 8 | Download |
| PDiscoFormer | ViT-B | 25 | Download |
| PDiscoFormer | ViT-B | 50 | Download |
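The sketch below shows a rough inference pass with one of these checkpoints. The PartImageNet OOD models are loaded at a 224×224 input resolution in the Hugging Face snippet above, so the dummy input uses that size; the exact structure of the forward output (class logits, part maps, etc.) depends on the model class, so the code only inspects whatever the model returns instead of assuming a specific format.

```python
import torch

model = torch.hub.load("ananthu-aniraj/pdiscoformer:main", "pdiscoformer_pimagenet_k_8",
                       pretrained=True, trust_repo=True)
model.eval()

# Dummy batch at the 224x224 resolution used for the PartImageNet OOD checkpoints
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    out = model(x)

# The forward pass may return a single tensor or a tuple of tensors;
# print shapes generically rather than assuming a particular output layout.
outputs = out if isinstance(out, (tuple, list)) else (out,)
for i, o in enumerate(outputs):
    if torch.is_tensor(o):
        print(f"output[{i}]: shape={tuple(o.shape)}")
    else:
        print(f"output[{i}]: type={type(o).__name__}")
```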
## Oxford Flowers

| Model | Backbone | K | URL |
|---|---|---|---|
| PDiscoFormer | ViT-B | 2 | Download |
| PDiscoFormer | ViT-B | 4 | Download |
| PDiscoFormer | ViT-B | 8 | Download |
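For feeding real images to the 224×224 Flowers checkpoints, a generic preprocessing pipeline is sketched below. The resize/crop policy and the ImageNet normalization statistics are assumptions for illustration only; the transforms actually used for training and evaluation are defined in the repository's data utilities and may differ.

```python
import torch
from PIL import Image
from torchvision import transforms

# Assumed evaluation-style preprocessing for a 224x224 checkpoint
# (ImageNet mean/std; the repository's own transforms may differ).
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

image = Image.open("path/to/flower.jpg").convert("RGB")  # placeholder path
x = preprocess(image).unsqueeze(0)  # shape: (1, 3, 224, 224)
```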
## PartImageNet Seg

| Model | Backbone | K | URL |
|---|---|---|---|
| PDiscoFormer | ViT-B | 8 | Download |
| PDiscoFormer | ViT-B | 16 | Download |
| PDiscoFormer | ViT-B | 25 | Download |
| PDiscoFormer | ViT-B | 41 | Download |
| PDiscoFormer | ViT-B | 50 | Download |
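Since the PartImageNet Seg models are evaluated on part segmentation, one quick way to inspect discovered parts is to turn per-part probability maps into a hard per-pixel assignment. The sketch below operates on a synthetic tensor and assumes a map layout of (batch, K+1, H, W) with one extra background channel; it is not the repository's visualization code, and the real tensor would come from the model's forward pass.

```python
import torch
import torch.nn.functional as F

# Synthetic stand-in for per-part probability maps: (batch, K+1, H, W),
# where the extra channel is assumed to be background.
k = 8
part_maps = torch.softmax(torch.randn(1, k + 1, 37, 37), dim=1)

# Upsample to an (assumed) input resolution and take the per-pixel argmax
# to obtain a hard part-segmentation map.
upsampled = F.interpolate(part_maps, size=(224, 224), mode="bilinear", align_corners=False)
segmentation = upsampled.argmax(dim=1)  # (1, 224, 224), values in [0, K]

# Count how many pixels each part occupies as a quick sanity check.
for part_id in range(k + 1):
    pixels = (segmentation == part_id).sum().item()
    print(f"part {part_id}: {pixels} pixels")
```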
## NABirds

| Method | K | Kp | NMI | ARI | Top-1 Accuracy | URL |
|---|---|---|---|---|---|---|
| Dino | 4 | - | 26.50 | 11.30 | - | - |
| Dino | 8 | - | 39.45 | 23.20 | - | - |
| Dino | 11 | - | 39.23 | 23.37 | - | - |
| Huang | 4 | 14.54 | 31.99 | 19.31 | 85.46 | - |
| Huang | 8 | 13.47 | 42.06 | 27.32 | 85.17 | - |
| Huang | 11 | 12.82 | 44.08 | 29.28 | 85.14 | - |
| PDiscoNet | 4 | 11.5 | 31.93 | 13.32 | 83.56 | Download |
| PDiscoNet | 8 | 11.19 | 37.60 | 19.47 | 84.31 | Download |
| PDiscoNet | 11 | 9.59 | 43.57 | 29.63 | 84.51 | Download |
| PDiscoNet + ViT-B | 4 | 9.76 | 43.02 | 22.77 | 87.74 | Download |
| PDiscoNet + ViT-B | 8 | 9.17 | 56.50 | 34.10 | 85.60 | Download |
| PDiscoNet + ViT-B | 11 | 9.34 | 68.92 | 54.65 | 83.37 | Download |
| PDiscoFormer | 4 | 11.22 | 48.24 | 27.73 | 89.29 | Download |
| PDiscoFormer | 8 | 8.84 | 60.39 | 46.74 | 88.72 | Download |
| PDiscoFormer | 11 | 8.36 | 72.04 | 63.35 | 88.69 | Download |
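The NMI and ARI columns above measure how consistently the discovered parts fire on annotated keypoints: in the usual protocol, each ground-truth keypoint is assigned the part index the model predicts at its location, and that assignment is compared against the keypoint identities. The sketch below illustrates the comparison with scikit-learn on synthetic labels; it is not the repository's evaluation code, and the way part assignments are extracted at keypoint locations is assumed for illustration.

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score

rng = np.random.default_rng(0)

# Synthetic example: for 1,000 annotated keypoints, the ground-truth keypoint
# identity (e.g. beak, crown, ...) and the part index assigned by the model
# at that keypoint's location.
keypoint_ids = rng.integers(0, 11, size=1000)      # 11 keypoint classes
predicted_parts = rng.integers(0, 8, size=1000)    # K = 8 discovered parts

nmi = normalized_mutual_info_score(keypoint_ids, predicted_parts)
ari = adjusted_rand_score(keypoint_ids, predicted_parts)
print(f"NMI: {nmi * 100:.2f}  ARI: {ari * 100:.2f}")  # reported as percentages, as in the table
```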
208 | --------------------------------------------------------------------------------