├── setup_commands.sh ├── imgs ├── method.png ├── results_final.png └── results_continual.png ├── kaizen ├── __init__.py ├── args │ ├── __init__.py │ ├── continual.py │ ├── dataset.py │ └── utils.py ├── losses │ ├── wmse.py │ ├── byol.py │ ├── simsiam.py │ ├── nnclr.py │ ├── __init__.py │ ├── moco.py │ ├── swav.py │ ├── deepclusterv2.py │ ├── ressl.py │ ├── barlow.py │ ├── vicreg.py │ ├── dino.py │ └── simclr.py ├── utils │ ├── __init__.py │ ├── gather_layer.py │ ├── whitening.py │ ├── metrics.py │ ├── trunc_normal.py │ ├── datasets.py │ ├── sinkhorn_knopp.py │ ├── momentum.py │ ├── lars.py │ ├── checkpointer.py │ ├── kmeans.py │ ├── auto_umap.py │ └── knn.py ├── methods │ ├── multi_layer_classifier.py │ ├── __init__.py │ ├── barlow_twins.py │ ├── vicreg.py │ ├── wmse.py │ ├── simsiam.py │ ├── byol.py │ ├── ressl.py │ ├── simclr.py │ └── mocov2plus.py ├── distillers │ ├── __init__.py │ ├── base.py │ ├── predictive_mse.py │ ├── predictive.py │ ├── contrastive.py │ ├── decorrelative.py │ └── knowledge.py └── distiller_factories │ ├── __init__.py │ ├── base.py │ ├── predictive_mse.py │ ├── predictive.py │ ├── soft_label.py │ ├── contrastive.py │ ├── decorrelative.py │ └── knowledge.py ├── requirements.txt ├── LICENSE.md ├── bash_files ├── byol_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh ├── mocov2plus_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh ├── simclr_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh └── vicreg_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh ├── job_launcher.py ├── .gitignore ├── main_continual.py ├── main_linear.py ├── README.md ├── evaluate_folder.py └── main_eval.py /setup_commands.sh: -------------------------------------------------------------------------------- 1 | export LC_ALL=C.UTF-8 2 | export LANG=C.UTF-8 3 | -------------------------------------------------------------------------------- /imgs/method.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nokia-Bell-Labs/kaizen/HEAD/imgs/method.png -------------------------------------------------------------------------------- /imgs/results_final.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nokia-Bell-Labs/kaizen/HEAD/imgs/results_final.png -------------------------------------------------------------------------------- /imgs/results_continual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Nokia-Bell-Labs/kaizen/HEAD/imgs/results_continual.png -------------------------------------------------------------------------------- /kaizen/__init__.py: -------------------------------------------------------------------------------- 1 | from kaizen import args, losses, methods, utils 2 | 3 | __all__ = ["args", "losses", "methods", "utils"] 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.10 2 | torchvision 3 | pytorch-lightning 4 | lightning-bolts 5 | wandb 6 | scikit-learn 7 | einops 8 | torchaudio 9 | -------------------------------------------------------------------------------- /kaizen/args/__init__.py: -------------------------------------------------------------------------------- 1 | from kaizen.args import dataset, setup, utils, continual 2 | 3 | __all__ = ["dataset", "setup", "utils", "continual"] 4 | -------------------------------------------------------------------------------- /kaizen/losses/wmse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def wmse_loss_func(z1: torch.Tensor, z2: torch.Tensor, simplified: bool = True) -> torch.Tensor: 6 | """Computes 
def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Computes the BYOL loss between predictions and momentum projections.

    The value is ``2 - 2 * mean(cosine_similarity(p, z))``. ``z`` is always
    detached, so no gradient flows back through it.

    Args:
        p (torch.Tensor): predicted features from view 1.
        z (torch.Tensor): projected momentum features from view 2, same shape as ``p``.
        simplified (bool): when True the similarity comes straight from
            ``F.cosine_similarity``; otherwise both inputs are L2-normalised and
            their element-wise product is summed over dim 1. Defaults to True.

    Returns:
        torch.Tensor: scalar BYOL loss.
    """

    if not simplified:
        # Explicit form: normalise first, then take the dot product per sample.
        p_unit = F.normalize(p, dim=-1)
        z_unit = F.normalize(z, dim=-1)
        similarity = (p_unit * z_unit.detach()).sum(dim=1)
        return 2 - 2 * similarity.mean()

    similarity = F.cosine_similarity(p, z.detach(), dim=-1)
    return 2 - 2 * similarity.mean()
class MultiLayerClassifier(nn.Module):
    """Fully connected classifier with optional ReLU hidden layers.

    Hidden layers are registered as attributes ``fc_0``, ``fc_1``, ... and the
    final projection as ``fc_output``; these names are the state-dict keys.

    Args:
        input_dim: size of the input feature vector.
        num_classes: number of output logits.
        layer_units: sizes of the hidden layers, in order. ``None`` (default)
            or an empty sequence yields a single linear layer.
    """

    def __init__(self, input_dim, num_classes, layer_units=None):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        # Copy into a fresh list: avoids the shared mutable default and keeps
        # later changes to the caller's sequence from leaking into the module.
        self.layer_units = [] if layer_units is None else list(layer_units)

        in_dim = self.input_dim
        # Plain list for ordered iteration in forward(). Parameter registration
        # happens through setattr below, which keeps the "fc_{i}" key names.
        self.all_layers = []
        for i, num_units in enumerate(self.layer_units):
            layer = nn.Linear(in_dim, num_units)
            setattr(self, f"fc_{i}", layer)
            self.all_layers.append(layer)
            in_dim = num_units
        self.fc_output = nn.Linear(in_dim, self.num_classes)

    def forward(self, x):
        """Applies every hidden layer followed by ReLU, then returns the output logits."""
        for layer in self.all_layers:
            x = F.relu(layer(x))
        return self.fc_output(x)
from kaizen.distillers.base import base_distill_wrapper
from kaizen.distillers.contrastive import contrastive_distill_wrapper
from kaizen.distillers.decorrelative import decorrelative_distill_wrapper
from kaizen.distillers.knowledge import knowledge_distill_wrapper
from kaizen.distillers.predictive import predictive_distill_wrapper
from kaizen.distillers.predictive_mse import predictive_mse_distill_wrapper


# Public names of this package. Every entry must be a name bound above:
# a star-import raises AttributeError for any listed name that is missing.
__all__ = [
    "base_distill_wrapper",
    "contrastive_distill_wrapper",
    "decorrelative_distill_wrapper",
    "knowledge_distill_wrapper",
    "predictive_distill_wrapper",
    "predictive_mse_distill_wrapper",
]

# Maps a distiller name to the wrapper factory that implements it.
DISTILLERS = {
    "base": base_distill_wrapper,
    "contrastive": contrastive_distill_wrapper,
    "decorrelative": decorrelative_distill_wrapper,
    "knowledge": knowledge_distill_wrapper,
    "predictive": predictive_distill_wrapper,
    "predictive_mse": predictive_mse_distill_wrapper,
}
def _distributed_ready() -> bool:
    """Returns True when torch.distributed is available and a process group is initialized."""
    return dist.is_available() and dist.is_initialized()


class GatherLayer(torch.autograd.Function):
    """Gathers tensors from all processes, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        """Returns a tuple holding one tensor per process (just ``input`` when not distributed)."""
        ctx.save_for_backward(input)
        if not _distributed_ready():
            return (input,)

        buffers = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(buffers, input)
        return tuple(buffers)

    @staticmethod
    def backward(ctx, *grads):
        """Returns the incoming gradient that corresponds to this process's own slot."""
        (saved,) = ctx.saved_tensors
        if not _distributed_ready():
            return grads[0]

        # Only the gradient at this rank's position is passed back.
        local_grad = torch.zeros_like(saved)
        local_grad[:] = grads[dist.get_rank()]
        return local_grad


def gather(X, dim=0):
    """Gathers ``X`` from all processes and concatenates the pieces along ``dim``."""
    pieces = GatherLayer.apply(X)
    return torch.cat(pieces, dim=dim)
def moco_loss_func(
    query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1
) -> torch.Tensor:
    """Computes MoCo's contrastive loss from queries, keys and a queue of negatives.

    Each query is scored against its own key (the positive, placed in column 0)
    and against every queue entry (the negatives). The scores are divided by the
    temperature and fed to a cross-entropy whose target is column 0.

    Args:
        query (torch.Tensor): NxC Tensor containing the queries.
        key (torch.Tensor): NxC Tensor containing the keys, row-aligned with ``query``.
        queue (torch.Tensor): CxK Tensor of negatives; the feature dimension
            comes first, as the einsum equation requires.
        temperature (float, optional): temperature of the softmax in the
            contrastive loss. Defaults to 0.1.

    Returns:
        torch.Tensor: MoCo loss.
    """

    positive_scores = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
    negative_scores = torch.einsum("nc,ck->nk", [query, queue])

    scaled_logits = torch.cat([positive_scores, negative_scores], dim=1) / temperature

    # The positive always sits in column 0.
    positive_index = torch.zeros(query.size(0), device=query.device, dtype=torch.long)
    return F.cross_entropy(scaled_logits, positive_index)
def deepclusterv2_loss_func(
    outputs: torch.Tensor, assignments: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """Computes DeepClusterV2's loss, averaged over the prototype layers.

    For every prototype layer the logits of all views are flattened into rows,
    divided by the temperature and compared by cross-entropy with that layer's
    assignments, which are tiled once per view. Targets equal to -1 are ignored.

    Args:
        outputs (torch.Tensor): tensor of size PxVxNxC where P is the number of
            prototype layers and V is the number of views.
        assignments (torch.Tensor): cluster assignments indexed by prototype
            layer; ``assignments[h]`` is repeated V times to line up with the
            V*N flattened rows, so it is expected to hold N class indices.
        temperature (float, optional): softmax temperature for the loss. Defaults to 0.1.

    Returns:
        torch.Tensor: DeepClusterV2 loss.
    """
    num_layers = outputs.size(0)
    num_views = outputs.size(1)
    num_prototypes = outputs.size(-1)

    def layer_loss(h):
        # (V, N, C) -> (V * N, C) logits for prototype layer h.
        scores = outputs[h].view(-1, num_prototypes) / temperature
        targets = assignments[h].repeat(num_views).to(outputs.device, non_blocking=True)
        return F.cross_entropy(scores, targets, ignore_index=-1)

    total = sum(layer_loss(h) for h in range(num_layers))
    return total / num_layers
def ressl_loss_func(
    q: torch.Tensor,
    k: torch.Tensor,
    queue: torch.Tensor,
    temperature_q: float = 0.1,
    temperature_k: float = 0.04,
) -> torch.Tensor:
    """Computes ReSSL's loss from queries, keys and a queue of past elements.

    Queries and keys are each scored against every queue entry. The key
    scores (detached) form the target distribution and the query scores are
    trained to match it through a cross-entropy between the two softmaxes.

    Args:
        q (torch.Tensor): NxC Tensor containing the queries.
        k (torch.Tensor): NxC Tensor containing the keys.
        queue (torch.Tensor): KxC Tensor of queued elements (one entry per row).
        temperature_q (float, optional): temperature of the softmax for the query.
            Defaults to 0.1.
        temperature_k (float, optional): temperature of the softmax for the key.
            Defaults to 0.04.

    Returns:
        torch.Tensor: ReSSL loss.
    """

    query_scores = torch.einsum("nc,kc->nk", [q, queue])
    key_scores = torch.einsum("nc,kc->nk", [k, queue])

    # Target distribution: no gradient flows through the keys.
    target_probs = F.softmax(key_scores.detach() / temperature_k, dim=1)
    query_log_probs = F.log_softmax(query_scores / temperature_q, dim=1)

    per_sample = torch.sum(target_probs * query_log_probs, dim=1)
    return -per_sample.mean()
def barlow_loss_func(
    z1: torch.Tensor, z2: torch.Tensor, lamb: float = 5e-3, scale_loss: float = 0.025
) -> torch.Tensor:
    """Computes Barlow Twins' loss from two batches of projected features.

    Both batches are standardised with a freshly created, non-affine
    ``BatchNorm1d``. Their cross-correlation matrix is then pushed towards the
    identity: squared errors off the diagonal are weighted by ``lamb`` and the
    summed error is multiplied by ``scale_loss``. When a distributed process
    group is initialized, the correlation matrix is averaged across processes.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        lamb (float, optional): off-diagonal scaling factor for the cross-correlation
            matrix. Defaults to 5e-3.
        scale_loss (float, optional): final scaling factor of the loss. Defaults to 0.025.

    Returns:
        torch.Tensor: Barlow Twins' loss.
    """

    batch_size, feat_dim = z1.size()

    # to match the original code
    normalizer = torch.nn.BatchNorm1d(feat_dim, affine=False).to(z1.device)
    z1_std = normalizer(z1)
    z2_std = normalizer(z2)

    cross_corr = torch.einsum("bi, bj -> ij", z1_std, z2_std) / batch_size

    if dist.is_available() and dist.is_initialized():
        # Sum over processes, then divide by their number to get the average.
        dist.all_reduce(cross_corr)
        cross_corr /= dist.get_world_size()

    identity = torch.eye(feat_dim, device=cross_corr.device)
    squared_error = (cross_corr - identity).pow(2)
    off_diagonal = ~identity.bool()
    squared_error[off_diagonal] *= lamb
    return scale_loss * squared_error.sum()
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /kaizen/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from kaizen.methods.barlow_twins import BarlowTwins 2 | from kaizen.methods.base import BaseModel 3 | from kaizen.methods.byol import BYOL 4 | from kaizen.methods.deepclusterv2 import DeepClusterV2 5 | from kaizen.methods.dino import DINO 6 | from kaizen.methods.linear import LinearModel 7 | from kaizen.methods.mocov2plus import MoCoV2Plus 8 | from kaizen.methods.nnclr import NNCLR 9 | from kaizen.methods.ressl import ReSSL 10 | from kaizen.methods.simclr import SimCLR 11 | from kaizen.methods.simsiam import SimSiam 12 | from kaizen.methods.swav import SwAV 13 | from kaizen.methods.vicreg import VICReg 14 | from kaizen.methods.wmse import WMSE 15 | from kaizen.methods.full_model import FullModel 16 | 17 | METHODS = { 18 | # base classes 19 | "base": BaseModel, 20 | "linear": LinearModel, 21 | "full_model": FullModel, 22 | # methods 23 | "barlow_twins": BarlowTwins, 24 | "byol": BYOL, 25 | "deepclusterv2": DeepClusterV2, 26 | "dino": DINO, 27 | "mocov2plus": MoCoV2Plus, 28 | "nnclr": NNCLR, 29 | "ressl": ReSSL, 30 | "simclr": SimCLR, 31 | "simsiam": SimSiam, 32 | "swav": SwAV, 33 | "vicreg": VICReg, 34 | "wmse": WMSE, 35 | } 36 | __all__ = [ 37 | "BarlowTwins", 38 | "BYOL", 39 | "BaseModel", 40 | "DeepClusterV2", 41 | "DINO", 42 | "LinearModel", 43 | "FullModel", 44 | 
def base_distill_wrapper(Method=object):
    """Builds a subclass of ``Method`` that keeps frozen copies of its encoder and projector.

    The returned class expects ``Method`` to provide ``encoder``, ``projector``
    and ``current_task_idx`` attributes plus ``on_train_start`` and
    ``training_step`` methods, since all of them are used below.

    Args:
        Method: base class to extend.

    Returns:
        The ``BaseDistillWrapper`` class derived from ``Method``.
    """

    class BaseDistillWrapper(Method):
        def __init__(self, **kwargs) -> None:
            super().__init__(**kwargs)

            self.output_dim = kwargs["output_dim"]

            self.frozen_encoder = deepcopy(self.encoder)
            self.frozen_projector = deepcopy(self.projector)

        def on_train_start(self):
            """Snapshots and freezes the current encoder/projector for every task after the first."""
            super().on_train_start()

            if not self.current_task_idx > 0:
                return

            self.frozen_encoder = deepcopy(self.encoder)
            self.frozen_projector = deepcopy(self.projector)

            for frozen_module in (self.frozen_encoder, self.frozen_projector):
                for param in frozen_module.parameters():
                    param.requires_grad = False

        @torch.no_grad()
        def frozen_forward(self, X):
            """Returns ``(features, projections)`` computed by the frozen copies without gradients."""
            frozen_feats = self.frozen_encoder(X)
            frozen_z = self.frozen_projector(frozen_feats)
            return frozen_feats, frozen_z

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent step, then adds the frozen outputs for both views to its result."""
            _, (X1, X2), _ = batch[f"task{self.current_task_idx}"]

            out = super().training_step(batch, batch_idx)

            view1 = self.frozen_forward(X1)
            view2 = self.frozen_forward(X2)

            out.update(
                {"frozen_feats": [view1[0], view2[0]], "frozen_z": [view1[1], view2[1]]}
            )
            return out

    return BaseDistillWrapper
/bash_files/byol_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 5 \ 8 | --task_idx 0 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --exclude_bias_n_norm \ 16 | --scheduler warmup_cosine \ 17 | --lr 1.0 \ 18 | --classifier_lr 0.1 \ 19 | --weight_decay 1e-5 \ 20 | --batch_size 256 \ 21 | --num_workers 2 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --name cifar100-byol-predictive-distill-classifier-l1000-soft-label-replay-0.01-b32 \ 29 | --project ever-learn-2 \ 30 | --entity your_entity \ 31 | --offline \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --output_dim 256 \ 35 | --proj_hidden_dim 4096 \ 36 | --pred_hidden_dim 4096 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 1.0 \ 39 | --momentum_classifier \ 40 | --disable_knn_eval \ 41 | --online_eval_classifier_lr 0.1 \ 42 | --classifier_training True \ 43 | --classifier_stop_gradient True \ 44 | --classifier_layers 1000 \ 45 | --distiller_library factory \ 46 | --method byol \ 47 | --distiller predictive \ 48 | --distiller_classifier soft_label \ 49 | --classifier_distill_lamb 2.0 \ 50 | --classifier_distill_no_predictior True \ 51 | --replay True \ 52 | --replay_proportion 0.01 \ 53 | --replay_batch_size 32 \ 54 | --online_evaluation True \ 55 | --online_evaluation_training_data_source seen_tasks 56 | -------------------------------------------------------------------------------- /kaizen/utils/whitening.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from 
class Whitening2d(nn.Module):
    def __init__(self, output_dim: int, eps: float = 0.0):
        """Layer that computes hard whitening for W-MSE using the Cholesky decomposition.

        Args:
            output_dim (int): number of dimensions of the projected features.
            eps (float, optional): shrinkage weight that blends the covariance with
                the identity before the Cholesky decomposition. Defaults to 0.0.
        """

        super().__init__()
        self.output_dim = output_dim
        self.eps = eps

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs whitening using the Cholesky decomposition.

        Args:
            x (torch.Tensor): a batch or slice of projected features.

        Returns:
            torch.Tensor: a batch or slice of whitened features.
        """

        dim = self.output_dim

        # Work on an (N, D, 1, 1) view so the whitening matrix can be applied as a 1x1 conv.
        x4d = x.unsqueeze(2).unsqueeze(3)
        mean = x4d.mean(0).view(dim, -1).mean(-1).view(1, -1, 1, 1)
        centered = x4d - mean

        # Covariance of the centered features, computed from a (D, N) matrix.
        flat = centered.permute(1, 0, 2, 3).contiguous().view(dim, -1)
        cov = torch.mm(flat, flat.permute(1, 0)) / (flat.shape[-1] - 1)

        identity = torch.eye(dim).type(cov.type())
        shrunk_cov = (1 - self.eps) * cov + self.eps * identity

        # Solving L @ W = I with the lower Cholesky factor L yields W = L^-1.
        chol = torch.cholesky(shrunk_cov)
        whitening = torch.triangular_solve(identity, chol, upper=False)[0]
        kernel = whitening.contiguous().view(dim, dim, 1, 1)

        whitened = conv2d(centered, kernel)

        return whitened.squeeze(3).squeeze(2)
def accuracy_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
) -> List[torch.Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    Values of k larger than the number of classes are handled by ranking all the
    available classes, instead of failing inside ``topk``.

    Args:
        outputs (torch.Tensor): output of a classifier (logits or probabilities), with the
            classes along dimension 1.
        targets (torch.Tensor): ground truth labels.
        top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
            Defaults to (1, 5).

    Returns:
        List[torch.Tensor]: one single-element tensor per requested k, holding the accuracy
            as a percentage.

    Raises:
        ZeroDivisionError: if ``targets`` is empty.
    """

    with torch.no_grad():
        batch_size = targets.size(0)
        # topk cannot return more entries than there are classes
        maxk = min(max(top_k), outputs.size(1))

        _, pred = outputs.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            # slicing past the last row simply keeps every ranked class
            correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fills ``tensor`` in place with values drawn from a normal distribution truncated to [a, b].

    Copy & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    def standard_normal_cdf(value):
        """Cumulative distribution function of the standard normal."""

        return (1.0 + math.erf(value / math.sqrt(2.0))) / 2.0

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Sample through the inverse CDF: draw uniformly between the CDF values of the
        # two bounds and map the draws back through the normal quantile function.
        lower_cdf = standard_normal_cdf((a - mean) / std)
        upper_cdf = standard_normal_cdf((b - mean) / std)

        # erfinv works on [-1, 1], hence the shift from [lower, upper] to
        # [2 * lower - 1, 2 * upper - 1]
        tensor.uniform_(2 * lower_cdf - 1, 2 * upper_cdf - 1)

        # erfinv gives a truncated standard normal (up to a sqrt(2) factor), which is
        # then rescaled to the requested std and shifted to the requested mean
        tensor.erfinv_().mul_(std * math.sqrt(2.0)).add_(mean)

        # guard against values falling just outside the bounds
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Public entry point: fills ``tensor`` in place from a truncated normal and returns it.

    Copy & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
40 | --disable_knn_eval \ 41 | --online_eval_classifier_lr 0.1 \ 42 | --classifier_training True \ 43 | --classifier_stop_gradient True \ 44 | --classifier_layers 1000 \ 45 | --distiller_library factory \ 46 | --method simclr \ 47 | --distiller contrastive \ 48 | --classifier_distill_lamb 2.0 \ 49 | --distiller_classifier soft_label \ 50 | --classifier_distill_no_predictior True \ 51 | --replay True \ 52 | --replay_proportion 0.01 \ 53 | --replay_batch_size 32 \ 54 | --online_evaluation True \ 55 | --online_evaluation_training_data_source seen_tasks \ 56 | --pretrained_model experiments/2022_08_31_14_49_21-cifar100-simclr-contrastive-distill-classifier-l1000-lamb5-soft-label-replay-0.1/task0-1o6mzete/cifar100-simclr-contrastive-distill-classifier-l1000-lamb5-soft-label-replay-0.1-task0-ep=499-1o6mzete.ckpt 57 | -------------------------------------------------------------------------------- /bash_files/vicreg_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 5 \ 8 | --task_idx 4 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --exclude_bias_n_norm \ 16 | --scheduler warmup_cosine \ 17 | --lr 0.3 \ 18 | --classifier_lr 0.05 \ 19 | --weight_decay 1e-4 \ 20 | --batch_size 256 \ 21 | --num_workers 2 \ 22 | --min_scale 0.2 \ 23 | --brightness 0.4 \ 24 | --contrast 0.4 \ 25 | --saturation 0.2 \ 26 | --hue 0.1 \ 27 | --solarization_prob 0.1 \ 28 | --gaussian_prob 0.0 0.0 \ 29 | --name cifar100-vicreg-predictive_mse-distill-classifier-l1000-soft-label-replay-0.01-b32 \ 30 | --project ever-learn-2 \ 31 | --entity your_entity \ 32 | --offline \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --proj_hidden_dim 2048 \ 36 | --output_dim 
class DomainNetDataset(Dataset):
    """Dataset built from ``<domain>_<split>.txt`` image-list files.

    Every non-empty line of a list file is read as a relative image path followed by an
    integer label, separated by whitespace.
    """

    def __init__(
        self,
        data_root,
        image_list_root,
        domain_names,
        split="train",
        transform=None,
        return_domain=False,
    ):
        """Reads the image lists of the requested domains.

        Args:
            data_root: directory that the image paths in the list files are relative to.
            image_list_root: directory containing the ``<domain>_<split>.txt`` files.
            domain_names: a single domain name, a sequence of domain names, or None to
                use all six default domains.
            split: split suffix used in the list file names. Defaults to "train".
            transform: optional callable applied to each loaded image.
            return_domain: if True, ``__getitem__`` returns the domain name instead of
                the sample index as first element.
        """

        self.data_root = data_root
        self.transform = transform
        self.return_domain = return_domain

        if domain_names is None:
            domain_names = [
                "clipart",
                "infograph",
                "painting",
                "quickdraw",
                "real",
                "sketch",
            ]
        elif isinstance(domain_names, str):
            # a bare string would otherwise be iterated character by character below
            domain_names = [domain_names]
        elif not isinstance(domain_names, list):
            domain_names = list(domain_names)
        self.domain_names = domain_names

        image_list_paths = [
            os.path.join(image_list_root, d + "_" + split + ".txt") for d in self.domain_names
        ]
        self.imgs = self._make_dataset(image_list_paths)

    def _make_dataset(self, image_list_paths):
        """Parses the list files into ``(relative_path, label)`` tuples."""

        images = []
        for image_list_path in image_list_paths:
            # the context manager closes the file even if a line fails to parse
            with open(image_list_path) as image_list:
                for line in image_list:
                    fields = line.split()
                    if not fields:
                        # tolerate blank lines (e.g. a trailing newline)
                        continue
                    images.append((fields[0], int(fields[1])))
        return images

    def _rgb_loader(self, path):
        """Opens the image at ``path`` and returns it converted to RGB."""

        with open(path, "rb") as f:
            with Image.open(f) as img:
                return img.convert("RGB")

    def __getitem__(self, index):
        """Returns ``(domain or index, image, target)`` for the sample at ``index``."""

        path, target = self.imgs[index]
        img = self._rgb_loader(os.path.join(self.data_root, path))

        if self.transform is not None:
            img = self.transform(img)

        if not self.return_domain:
            return index, img, target

        # the domain is recovered by looking for its name inside the relative path
        matches = [d for d in self.domain_names if d in path]
        assert len(matches) == 1, f"expected exactly one domain in {path!r}, found {matches}"
        return matches[0], img, target

    def __len__(self):
        """Number of images gathered from all the list files."""

        return len(self.imgs)
def base_frozen_model_factory(MethodClass=object):
    """Creates a subclass of ``MethodClass`` that carries frozen copies of its networks.

    The returned class deep-copies ``encoder``, ``projector`` and (when
    ``classifier_training`` is truthy) ``classifier`` with gradients disabled, and adds the
    outputs of those copies to the dictionary returned by the parent ``training_step``.
    """

    class BaseFrozenModel(MethodClass):
        def __init__(self, **kwargs) -> None:
            super().__init__(**kwargs)

            self.output_dim = kwargs["output_dim"]
            self.store_model_frozen_copy()

        def store_model_frozen_copy(self):
            """Snapshots the current networks into gradient-free copies."""

            def frozen_clone(module):
                clone = deepcopy(module)
                for param in clone.parameters():
                    param.requires_grad = False
                return clone

            self.frozen_encoder = frozen_clone(self.encoder)
            self.frozen_projector = frozen_clone(self.projector)
            # the classifier is only snapshotted when it is being trained
            self.frozen_classifier = (
                frozen_clone(self.classifier) if self.classifier_training else None
            )

        def on_train_start(self):
            """Refreshes the frozen copies once the parent hook has run."""

            super().on_train_start()
            self.store_model_frozen_copy()

        @torch.no_grad()
        def frozen_forward(self, X):
            """Runs ``X`` through the frozen copies.

            Returns:
                tuple: encoder features, projector output and classifier logits (None when
                    there is no frozen classifier).
            """

            encoder_feats = self.frozen_encoder(X)
            projected = self.frozen_projector(encoder_feats)
            logits = None
            if self.frozen_classifier is not None:
                logits = self.frozen_classifier(encoder_feats)
            return encoder_feats, projected, logits

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Extends the parent outputs with the frozen outputs for both views."""

            _, views, _ = batch[f"task{self.current_task_idx}"]
            X1, X2 = views
            if "replay" in batch:
                # replayed views are appended after the current-task views
                *_, replay_views, _ = batch["replay"]
                X1R, X2R = replay_views
                X1 = torch.cat([X1, X1R])
                X2 = torch.cat([X2, X2R])

            out = super().training_step(batch, batch_idx)

            frozen_view1 = self.frozen_forward(X1)
            frozen_view2 = self.frozen_forward(X2)

            keys = ("frozen_feats", "frozen_z", "frozen_logits")
            for key, first, second in zip(keys, frozen_view1, frozen_view2):
                out[key] = [first, second]
            return out

    return BaseFrozenModel
parser.add_argument("--script", type=str, required=True) 14 | parser.add_argument("--mode", type=str, default="normal") 15 | parser.add_argument("--experiment_dir", type=str, default=None) 16 | parser.add_argument("--base_experiment_dir", type=str, default="./experiments") 17 | parser.add_argument("--gpu", type=str, default="v100-16g") 18 | parser.add_argument("--num_gpus", type=int, default=2) 19 | parser.add_argument("--hours", type=int, default=20) 20 | parser.add_argument("--requeue", type=int, default=0) 21 | 22 | args = parser.parse_args() 23 | 24 | # load file 25 | if os.path.exists(args.script): 26 | with open(args.script) as f: 27 | command = [line.strip().strip("\\").strip() for line in f.readlines()] 28 | else: 29 | print(f"{args.script} does not exist.") 30 | exit() 31 | 32 | assert ( 33 | "--checkpoint_dir" not in command 34 | ), "Please remove the --checkpoint_dir argument, it will be added automatically" 35 | 36 | # collect args 37 | command_args = str_to_dict(" ".join(command).split(" ")[2:]) 38 | 39 | # create experiment directory 40 | if args.experiment_dir is None: 41 | args.experiment_dir = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 42 | args.experiment_dir += f"-{command_args['--name']}" 43 | full_experiment_dir = os.path.join(args.base_experiment_dir, args.experiment_dir) 44 | os.makedirs(full_experiment_dir, exist_ok=True) # Moved to main_continual.py 45 | print(f"Experiment directory: {full_experiment_dir}") 46 | shutil.copy(args.script, full_experiment_dir) 47 | # add experiment directory to the command 48 | command.extend(["--checkpoint_dir", full_experiment_dir]) 49 | command = " ".join(command) 50 | 51 | print(command) 52 | 53 | # run command 54 | if args.mode == "normal": 55 | p = subprocess.Popen(command, shell=True, stdout=sys.stdout, stderr=sys.stdout) 56 | p.wait() 57 | 58 | elif args.mode == "slurm": 59 | # infer qos 60 | if 0 <= args.hours <= 2: 61 | qos = "qos_gpu-dev" 62 | elif args.hours <= 20: 63 | qos = "qos_gpu-t3" 64 | 
elif args.hours <= 100: 65 | qos = "qos_gpu-t4" 66 | 67 | # write command 68 | command_path = os.path.join(full_experiment_dir, "command.sh") 69 | with open(command_path, "w") as f: 70 | f.write(command) 71 | 72 | # run command 73 | p = subprocess.Popen(f"sbatch {command_path}", shell=True, stdout=sys.stdout, stderr=sys.stdout) 74 | p.wait() 75 | -------------------------------------------------------------------------------- /kaizen/utils/sinkhorn_knopp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class SinkhornKnopp(torch.nn.Module): 6 | def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1): 7 | """Approximates optimal transport using the Sinkhorn-Knopp algorithm. 8 | 9 | A simple iterative method to approach the double stochastic matrix is to alternately rescale 10 | rows and columns of the matrix to sum to 1. 11 | 12 | Args: 13 | num_iters (int, optional): number of times to perform row and column normalization. 14 | Defaults to 3. 15 | epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05. 16 | world_size (int, optional): number of nodes for distributed training. Defaults to 1. 17 | """ 18 | 19 | super().__init__() 20 | self.num_iters = num_iters 21 | self.epsilon = epsilon 22 | self.world_size = world_size 23 | 24 | @torch.no_grad() 25 | def forward(self, Q: torch.Tensor) -> torch.Tensor: 26 | """Produces assignments using Sinkhorn-Knopp algorithm. 27 | 28 | Applies the entropy regularization, normalizes the Q matrix and then normalizes rows and 29 | columns in an alternating fashion for num_iter times. Before returning it normalizes again 30 | the columns in order for the output to be an assignment of samples to prototypes. 31 | 32 | Args: 33 | Q (torch.Tensor): cosine similarities between the features of the 34 | samples and the prototypes. 
@torch.no_grad()
def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module):
    """Initialises the momentum network as a gradient-free copy of the online network.

    Args:
        online_net (nn.Module): online network (e.g. online encoder, online projection, etc...).
        momentum_net (nn.Module): momentum network (e.g. momentum encoder,
            momentum projection, etc...).
    """

    # parameters are paired positionally
    pairs = zip(online_net.parameters(), momentum_net.parameters())
    for online_param, momentum_param in pairs:
        momentum_param.data.copy_(online_param.data)
        momentum_param.requires_grad = False


class MomentumUpdater:
    def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0):
        """Keeps the state needed to update momentum parameters by exponential moving average.

        Args:
            base_tau (float, optional): base value of the weight decrease coefficient
                (should be in [0,1]). Defaults to 0.996.
            final_tau (float, optional): final value of the weight decrease coefficient
                (should be in [0,1]). Defaults to 1.0.
        """

        super().__init__()

        # both coefficients must lie in [0, 1] and tau may only grow
        assert 0 <= base_tau <= final_tau <= 1

        self.base_tau = base_tau
        self.final_tau = final_tau
        self.cur_tau = base_tau

    @torch.no_grad()
    def update(self, online_net: nn.Module, momentum_net: nn.Module):
        """Moves every momentum parameter towards its online counterpart.

        Args:
            online_net (nn.Module): online network (e.g. online encoder, online projection, etc...).
            momentum_net (nn.Module): momentum network (e.g. momentum encoder,
                momentum projection, etc...).
        """

        tau = self.cur_tau
        pairs = zip(online_net.parameters(), momentum_net.parameters())
        for online_param, momentum_param in pairs:
            momentum_param.data = tau * momentum_param.data + (1 - tau) * online_param.data

    def update_tau(self, cur_step: int, max_steps: int):
        """Sets the current tau from a cosine schedule going from base_tau to final_tau.

        Args:
            cur_step (int): number of gradient steps so far.
            max_steps (int): overall number of gradient steps in the whole training.
        """

        tau_span = self.final_tau - self.base_tau
        cosine_term = math.cos(math.pi * cur_step / max_steps) + 1
        self.cur_tau = self.final_tau - tau_span * cosine_term / 2
def dataset_args(parser: ArgumentParser):
    """Registers the dataset-related command line options on ``parser``.

    Args:
        parser (ArgumentParser): parser to add dataset args to.
    """

    add = parser.add_argument

    dataset_choices = [
        "cifar10",
        "cifar100",
        "stl10",
        "imagenet",
        "imagenet100",
        "domainnet",
        "custom",
    ]
    add("--dataset", choices=dataset_choices, type=str, required=True)

    # dataset path
    add("--data_dir", type=Path, required=True)
    for flag in ("--train_dir", "--val_dir"):
        add(flag, type=Path, default=None)

    # dali (imagenet-100/imagenet/custom only)
    add("--dali", action="store_true")
    add("--dali_device", type=str, default="gpu")

    # custom dataset only
    add("--no_labels", action="store_true")
    add("--semi_supervised", default=None, type=float)

    # seeds
    for flag in ("--split_seed", "--global_seed"):
        add(flag, type=int, default=5)


def augmentations_args(parser: ArgumentParser):
    """Registers the augmentation-related command line options on ``parser``.

    Args:
        parser (ArgumentParser): parser to add augmentation args to.
    """

    add = parser.add_argument

    # cropping
    add("--multicrop", action="store_true")
    add("--num_crops", type=int, default=2)
    add("--num_small_crops", type=int, default=0)

    # augmentations: the color jitter values are mandatory, the rest have defaults;
    # every option accepts one or more values
    for flag in ("--brightness", "--contrast", "--saturation", "--hue"):
        add(flag, type=float, required=True, nargs="+")

    optional_floats = (
        ("--gaussian_prob", [0.5]),
        ("--solarization_prob", [0.0]),
        ("--min_scale", [0.08]),
    )
    for flag, default in optional_floats:
        add(flag, type=float, default=default, nargs="+")

    # for imagenet or custom dataset
    add("--size", type=int, default=[224], nargs="+")

    # for custom dataset
    # NOTE(review): the commonly quoted ImageNet std is (0.229, 0.224, 0.225); the 0.228
    # below is kept unchanged - confirm whether it is intentional.
    add("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+")
    add("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+")

    # debug
    add("--debug_augmentations", action="store_true")
self.distill_temperature = distill_temperature 23 | output_dim = kwargs["output_dim"] 24 | 25 | self.distill_predictor = nn.Sequential( 26 | nn.Linear(output_dim, distill_proj_hidden_dim), 27 | nn.BatchNorm1d(distill_proj_hidden_dim), 28 | nn.ReLU(), 29 | nn.Linear(distill_proj_hidden_dim, output_dim), 30 | ) 31 | 32 | @staticmethod 33 | def add_model_specific_args( 34 | parent_parser: argparse.ArgumentParser, 35 | ) -> argparse.ArgumentParser: 36 | parser = parent_parser.add_argument_group("contrastive_distiller") 37 | 38 | parser.add_argument("--distill_lamb", type=float, default=1) 39 | parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048) 40 | parser.add_argument("--distill_temperature", type=float, default=0.2) 41 | 42 | return parent_parser 43 | 44 | @property 45 | def learnable_params(self) -> List[dict]: 46 | """Adds distill predictor parameters to the parent's learnable parameters. 47 | 48 | Returns: 49 | List[dict]: list of learnable parameters. 50 | """ 51 | 52 | extra_learnable_params = [ 53 | {"params": self.distill_predictor.parameters()}, 54 | ] 55 | return super().learnable_params + extra_learnable_params 56 | 57 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 58 | out = super().training_step(batch, batch_idx) 59 | z1, z2 = out["z"] 60 | frozen_z1, frozen_z2 = out["frozen_z"] 61 | 62 | p1 = self.distill_predictor(z1) 63 | p2 = self.distill_predictor(z2) 64 | 65 | distill_loss = ( 66 | simclr_distill_loss_func(p1, p2, frozen_z1, frozen_z2, self.distill_temperature) 67 | + simclr_distill_loss_func(frozen_z1, frozen_z2, p1, p2, self.distill_temperature) 68 | ) / 2 69 | 70 | self.log("train_contrastive_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 71 | 72 | return out["loss"] + self.distill_lamb * distill_loss 73 | 74 | return ContrastiveDistillWrapper 75 | -------------------------------------------------------------------------------- /.gitignore: 
-------------------------------------------------------------------------------- 1 | experiments/ 2 | trained_models/ 3 | 4 | # checkpoint 5 | last_checkpoint.txt 6 | 7 | # tensorboard dir 8 | runs/ 9 | 10 | # wandb dir 11 | wandb/ 12 | wandb*/ 13 | 14 | # umap dir 15 | auto_umap/ 16 | 17 | # datasets dir 18 | datasets/ 19 | # saved models 20 | *.pt 21 | *.pth 22 | *.ckpt 23 | *.tar 24 | 25 | *.png 26 | !imgs/*.png 27 | *.jpg 28 | *.jpeg 29 | 30 | saved_models/ 31 | model_storage/ 32 | model_storage*/ 33 | lightning_logs/ 34 | 35 | *.json 36 | 37 | *logs*/ 38 | 39 | # Created by https://www.gitignore.io/api/python,visualstudiocode 40 | # Edit at https://www.gitignore.io/?templates=python,visualstudiocode 41 | 42 | ### Python ### 43 | # Byte-compiled / optimized / DLL files 44 | __pycache__/ 45 | *.py[cod] 46 | *$py.class 47 | 48 | # C extensions 49 | *.so 50 | 51 | # Distribution / packaging 52 | .Python 53 | build/ 54 | develop-eggs/ 55 | dist/ 56 | downloads/ 57 | eggs/ 58 | .eggs/ 59 | lib/ 60 | lib64/ 61 | parts/ 62 | sdist/ 63 | var/ 64 | wheels/ 65 | pip-wheel-metadata/ 66 | share/python-wheels/ 67 | *.egg-info/ 68 | .installed.cfg 69 | *.egg 70 | MANIFEST 71 | 72 | # PyInstaller 73 | # Usually these files are written by a python script from a template 74 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
import torch
import torch.nn.functional as F


def invariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Computes mse loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: invariance loss (mean squared error).
    """

    return F.mse_loss(z1, z2)


def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Computes variance loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    Every feature dimension whose standard deviation over the batch is below 1
    is penalised by ``1 - std``; the penalties are averaged per view and the
    two views are summed.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: variance regularization loss.
    """

    # eps keeps sqrt differentiable when a dimension has zero variance
    eps = 1e-4
    std_z1 = torch.sqrt(z1.var(dim=0) + eps)
    std_z2 = torch.sqrt(z2.var(dim=0) + eps)
    std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))
    return std_loss


def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Computes covariance loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    For each view the squared off-diagonal entries of the DxD covariance matrix
    are summed and divided by D; the two views are added together.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: covariance regularization loss.
    """

    N, D = z1.size()

    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)
    cov_z1 = (z1.T @ z1) / (N - 1)
    cov_z2 = (z2.T @ z2) / (N - 1)

    # one boolean mask shared by both views (no float eye, no repeated cast)
    off_diag = ~torch.eye(D, dtype=torch.bool, device=z1.device)
    cov_loss = cov_z1[off_diag].pow(2).sum() / D + cov_z2[off_diag].pow(2).sum() / D
    return cov_loss


def vicreg_loss_func(
    z1: torch.Tensor,
    z2: torch.Tensor,
    sim_loss_weight: float = 25.0,
    var_loss_weight: float = 25.0,
    cov_loss_weight: float = 1.0,
) -> torch.Tensor:
    """Computes VICReg's loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        sim_loss_weight (float): invariance loss weight.
        var_loss_weight (float): variance loss weight.
        cov_loss_weight (float): covariance loss weight.

    Returns:
        torch.Tensor: VICReg loss.
    """

    sim_loss = invariance_loss(z1, z2)
    var_loss = variance_loss(z1, z2)
    cov_loss = covariance_loss(z1, z2)

    loss = sim_loss_weight * sim_loss + var_loss_weight * var_loss + cov_loss_weight * cov_loss
    return loss
import copy
import itertools
import subprocess
import sys
import os
import json


def str_to_dict(command):
    """Parses a flat list of command line tokens into an ordered flag -> value dict.

    Every token starting with "--" opens a new entry. The tokens that follow it
    (up to the next flag) become its value:

    * no value   -> the flag maps to itself (marks a store_true style switch),
    * one value  -> the flag maps to that string,
    * many values -> the flag maps to a list of strings.

    Tokens that appear before the first flag are ignored. When a flag is given
    more than once the last occurrence wins.

    Args:
        command (list): tokens such as ``sys.argv[1:]``.

    Returns:
        dict: mapping from flag to ``str`` or ``list`` of ``str``.
    """

    d = {}
    current = None  # flag currently collecting values
    n_values = 0  # number of values collected for `current`
    for token in command:
        if token[:2] == "--":
            current = token
            n_values = 0
            d[current] = current
        elif current is not None:
            if n_values == 0:
                d[current] = token
            elif n_values == 1:
                d[current] = [d[current], token]
            else:
                d[current].append(token)
            n_values += 1
    return d


def dict_to_list(command):
    """Turns a flag -> value dict back into a list of command line tokens.

    A flag mapped to itself (or to another "--" string) is emitted alone.
    List values are appended as a single nested list element; they are
    flattened later by ``run_bash_command``.

    Args:
        command (dict): mapping as produced by ``str_to_dict``.

    Returns:
        list: flags followed by their values.
    """

    s = []
    for k, v in command.items():
        s.append(k)
        if v == k:
            continue
        if isinstance(v, str) and v[:2] == "--":
            continue
        s.append(v)
    return s


def run_bash_command(args):
    """Runs ``python3 main_pretrain.py`` with the given arguments and waits for it.

    The process is started from an argument vector instead of a shell string, so
    values containing spaces or shell metacharacters reach the child unchanged.
    The ``args`` list is not modified.

    Args:
        args (list): tokens as produced by ``dict_to_list``; nested lists are flattened.

    Returns:
        int: exit status of the child process.
    """

    command = ["python3", "main_pretrain.py"]
    for a in args:
        if isinstance(a, (list, tuple)):
            command.extend(str(item) for item in a)
        else:
            command.append(str(a))
    return subprocess.run(command).returncode


if __name__ == "__main__":
    args = sys.argv[1:]
    args = str_to_dict(args)
    os.makedirs(args["--checkpoint_dir"], exist_ok=True)

    # parse args from the script
    num_tasks = int(args["--num_tasks"])
    start_task_idx = int(args.get("--task_idx", 0))
    distill_args = {k: v for k, v in args.items() if "distill" in k}

    # delete things that shouldn't be used for task_idx 0
    args.pop("--task_idx", None)
    for k in distill_args:
        args.pop(k, None)

    # check if this experiment is being resumed
    # look for the file last_checkpoint.txt
    last_checkpoint_file = os.path.join(args["--checkpoint_dir"], "last_checkpoint.txt")
    if os.path.exists(last_checkpoint_file):
        with open(last_checkpoint_file) as f:
            ckpt_path, args_path = [line.rstrip() for line in f.readlines()]
        with open(args_path) as f:
            start_task_idx = json.load(f)["task_idx"]
        args["--resume_from_checkpoint"] = ckpt_path

    # main task loop
    for task_idx in range(start_task_idx, num_tasks):
        print(f"\n#### Starting Task {task_idx} ####")

        task_args = copy.deepcopy(args)

        print(task_idx, start_task_idx, task_args)

        # add pretrained model arg
        if task_args.get("--no_continual_learning", "False").startswith("True"):
            task_args.pop("--resume_from_checkpoint", None)
            task_args.pop("--pretrained_model", None)

        elif task_idx != 0 and task_idx != start_task_idx:
            task_args.pop("--resume_from_checkpoint", None)
            task_args.pop("--pretrained_model", None)
            assert os.path.exists(last_checkpoint_file)
            with open(last_checkpoint_file) as f:
                ckpt_path = f.readlines()[0].rstrip()
            task_args["--pretrained_model"] = ckpt_path

        # NOTE(review): "--no_distill" itself contains "distill", so it was moved
        # into distill_args above and this lookup always sees the default.
        # Left unchanged; confirm the intended behaviour with main_pretrain.py.
        if not task_args.get("--no_distill", "False").startswith("True"):
            if task_idx != 0 and distill_args:
                task_args.update(distill_args)

        task_args["--task_idx"] = str(task_idx)
        task_args = dict_to_list(task_args)

        returncode = run_bash_command(task_args)
        if returncode != 0:
            # a later task would otherwise start from a stale checkpoint
            print(f"Task {task_idx} exited with status {returncode}; aborting.", file=sys.stderr)
            sys.exit(returncode)
distill_proj_hidden_dim), 33 | nn.BatchNorm1d(distill_proj_hidden_dim), 34 | nn.ReLU(), 35 | nn.Linear(distill_proj_hidden_dim, output_dim), 36 | )) 37 | 38 | @staticmethod 39 | def add_model_specific_args( 40 | parent_parser: argparse.ArgumentParser, 41 | ) -> argparse.ArgumentParser: 42 | parser = parent_parser.add_argument_group(f"predictive_mse_{class_tag}_distiller") 43 | 44 | parser.add_argument(f"--{distill_lamb_name}", type=float, default=25) 45 | parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048) 46 | parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False) 47 | 48 | return parent_parser 49 | 50 | @property 51 | def learnable_params(self) -> List[dict]: 52 | """Adds distill predictor parameters to the parent's learnable parameters. 53 | 54 | Returns: 55 | List[dict]: list of learnable parameters. 56 | """ 57 | 58 | extra_learnable_params = [ 59 | { 60 | "params": getattr(self, distill_predictor_name).parameters() 61 | }, 62 | ] 63 | return super().learnable_params + extra_learnable_params 64 | 65 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 66 | out = super().training_step(batch, batch_idx) 67 | z1, z2 = out[distill_current_key] 68 | frozen_z1, frozen_z2 = out[distill_frozen_key] 69 | 70 | p1 = getattr(self, distill_predictor_name)(z1) 71 | p2 = getattr(self, distill_predictor_name)(z2) 72 | 73 | distill_loss = (invariance_loss(p1, frozen_z1) + invariance_loss(p2, frozen_z2)) / 2 74 | 75 | self.log(f"train_{class_tag}_predictive_mse_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 76 | 77 | out["loss"] += getattr(self, distill_lamb_name) * distill_loss 78 | return out 79 | 80 | return PredictiveMSEDistillWrapper 81 | -------------------------------------------------------------------------------- /kaizen/methods/barlow_twins.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import 
Any, List, Sequence 3 | 4 | import torch 5 | import torch.nn as nn 6 | from kaizen.losses.barlow import barlow_loss_func 7 | from kaizen.methods.base import BaseModel 8 | 9 | 10 | class BarlowTwins(BaseModel): 11 | def __init__( 12 | self, proj_hidden_dim: int, output_dim: int, lamb: float, scale_loss: float, **kwargs 13 | ): 14 | """Implements Barlow Twins (https://arxiv.org/abs/2103.03230) 15 | 16 | Args: 17 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 18 | output_dim (int): number of dimensions of projected features. 19 | lamb (float): off-diagonal scaling factor for the cross-covariance matrix. 20 | scale_loss (float): scaling factor of the loss. 21 | """ 22 | 23 | super().__init__(**kwargs) 24 | 25 | self.lamb = lamb 26 | self.scale_loss = scale_loss 27 | 28 | # projector 29 | self.projector = nn.Sequential( 30 | nn.Linear(self.features_dim, proj_hidden_dim), 31 | nn.BatchNorm1d(proj_hidden_dim), 32 | nn.ReLU(), 33 | nn.Linear(proj_hidden_dim, proj_hidden_dim), 34 | nn.BatchNorm1d(proj_hidden_dim), 35 | nn.ReLU(), 36 | nn.Linear(proj_hidden_dim, output_dim), 37 | ) 38 | 39 | @staticmethod 40 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 41 | parent_parser = super(BarlowTwins, BarlowTwins).add_model_specific_args(parent_parser) 42 | parser = parent_parser.add_argument_group("barlow_twins") 43 | 44 | # projector 45 | parser.add_argument("--output_dim", type=int, default=2048) 46 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 47 | 48 | # parameters 49 | parser.add_argument("--lamb", type=float, default=5e-3) 50 | parser.add_argument("--scale_loss", type=float, default=0.025) 51 | return parent_parser 52 | 53 | @property 54 | def learnable_params(self) -> List[dict]: 55 | """Adds projector parameters to parent's learnable parameters. 56 | 57 | Returns: 58 | List[dict]: list of learnable parameters. 
def predictive_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Creates a subclass of ``Method`` that adds a predictive distillation loss.

    All hyper-parameters and the predictor module are stored under attribute
    names prefixed with ``class_tag`` so that several distillers can be stacked
    on the same base class without clashing.

    Args:
        Method: base class to extend. Its ``training_step`` must return a dict
            holding "loss" plus a pair of tensors under ``distill_current_key``
            and ``distill_frozen_key``; it must also expose ``lr``,
            ``weight_decay`` and ``learnable_params``.
        class_tag (str): prefix for argument, attribute and log names.
        distill_current_key (str): key of the current model's outputs.
        distill_frozen_key (str): key of the frozen model's outputs.
        output_dim (int): input/output width of the predictor MLP.

    Returns:
        type: the ``PredictiveDistillWrapper`` class.
    """

    distill_lamb_name = f"{class_tag}_distill_lamb"
    distill_proj_hidden_dim_name = f"{class_tag}_distill_proj_hidden_dim"
    distill_no_predictior_name = f"{class_tag}_distill_no_predictior"

    distill_predictor_name = f"{class_tag}_distill_predictor"

    class PredictiveDistillWrapper(Method):
        def __init__(self, **kwargs):
            """Pops the tagged distillation options from ``kwargs`` and builds the predictor.

            Raises:
                KeyError: if one of the tagged options is missing from ``kwargs``.
            """

            # the tagged options are removed before the parent sees the kwargs
            distill_lamb = kwargs.pop(distill_lamb_name)
            distill_proj_hidden_dim = kwargs.pop(distill_proj_hidden_dim_name)
            distill_no_predictior = kwargs.pop(distill_no_predictior_name)
            super().__init__(**kwargs)

            setattr(self, distill_lamb_name, distill_lamb)
            setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim)
            setattr(self, distill_no_predictior_name, distill_no_predictior)
            if distill_no_predictior:
                setattr(self, distill_predictor_name, nn.Identity())
            else:
                setattr(self, distill_predictor_name, nn.Sequential(
                    nn.Linear(output_dim, distill_proj_hidden_dim),
                    nn.BatchNorm1d(distill_proj_hidden_dim),
                    nn.ReLU(),
                    nn.Linear(distill_proj_hidden_dim, output_dim),
                ))

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the tagged distiller options on ``parent_parser``.

            Args:
                parent_parser (argparse.ArgumentParser): parser to extend.

            Returns:
                argparse.ArgumentParser: the same parser that was passed in.
            """

            parser = parent_parser.add_argument_group(f"predictive_{class_tag}_distiller")

            parser.add_argument(f"--{distill_lamb_name}", type=float, default=1)
            parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048)
            parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            The predictor's learning rate is ``lr / distill_lamb`` when
            ``0 < distill_lamb < 1`` and ``lr`` otherwise. A weight of zero (or
            a negative one) cannot be compensated by dividing, so the base
            learning rate is used instead of raising ``ZeroDivisionError`` or
            producing a negative learning rate.

            Returns:
                List[dict]: list of learnable parameters.
            """

            distill_lamb = getattr(self, distill_lamb_name)
            predictor_lr = self.lr / distill_lamb if 0 < distill_lamb < 1 else self.lr

            extra_learnable_params = [
                {
                    "name": f"{class_tag}_distill_predictor",
                    "params": getattr(self, distill_predictor_name).parameters(),
                    "lr": predictor_lr,
                    "weight_decay": self.weight_decay,
                },
            ]
            return super().learnable_params + extra_learnable_params

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> dict:
            """Runs the parent training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch forwarded unchanged to the parent step.
                batch_idx (int): index of the batch.

            Returns:
                dict: the parent's output dict with "loss" increased by
                ``distill_lamb`` times the distillation loss.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            predictor = getattr(self, distill_predictor_name)
            p1 = predictor(z1)
            p2 = predictor(z2)

            distill_loss = (byol_loss_func(p1, frozen_z1) + byol_loss_func(p2, frozen_z2)) / 2

            self.log(f"train_{class_tag}_predictive_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            out["loss"] += getattr(self, distill_lamb_name) * distill_loss
            return out

    return PredictiveDistillWrapper
kwargs.pop(distill_lamb_name) 21 | distill_proj_hidden_dim = kwargs.pop(distill_proj_hidden_dim_name) 22 | distill_no_predictior = kwargs.pop(distill_no_predictior_name) 23 | super().__init__(**kwargs) 24 | 25 | setattr(self, distill_lamb_name, distill_lamb) 26 | setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim) 27 | setattr(self, distill_no_predictior_name, distill_no_predictior) 28 | if distill_no_predictior: 29 | setattr(self, distill_predictor_name, nn.Identity()) 30 | else: 31 | setattr(self, distill_predictor_name, nn.Sequential( 32 | nn.Linear(output_dim, distill_proj_hidden_dim), 33 | nn.BatchNorm1d(distill_proj_hidden_dim), 34 | nn.ReLU(), 35 | nn.Linear(distill_proj_hidden_dim, output_dim), 36 | )) 37 | 38 | 39 | @staticmethod 40 | def add_model_specific_args( 41 | parent_parser: argparse.ArgumentParser, 42 | ) -> argparse.ArgumentParser: 43 | parser = parent_parser.add_argument_group(f"predictive_{class_tag}_distiller") 44 | 45 | parser.add_argument(f"--{distill_lamb_name}", type=float, default=1) 46 | parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048) 47 | parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False) 48 | 49 | return parent_parser 50 | 51 | @property 52 | def learnable_params(self) -> List[dict]: 53 | """Adds distill predictor parameters to the parent's learnable parameters. 54 | 55 | Returns: 56 | List[dict]: list of learnable parameters. 
"""
References:
    - https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
    - https://arxiv.org/pdf/1708.03888.pdf
    - https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
"""
import torch
from torch.optim import Optimizer


class LARSWrapper:
    def __init__(
        self,
        optimizer: Optimizer,
        eta: float = 1e-3,
        clip: bool = False,
        eps: float = 1e-8,
        exclude_bias_n_norm: bool = False,
    ):
        """Wrapper that adds LARS scheduling to any optimizer.
        This helps stability with huge batch sizes.

        Args:
            optimizer (Optimizer): torch optimizer.
            eta (float, optional): trust coefficient. Defaults to 1e-3.
            clip (bool, optional): clip gradient values. Defaults to False.
            eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
            exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
                Defaults to False.
        """

        self.optim = optimizer
        self.eta = eta
        self.eps = eps
        self.clip = clip
        self.exclude_bias_n_norm = exclude_bias_n_norm

        # transfer optim methods
        self.state_dict = self.optim.state_dict
        self.load_state_dict = self.optim.load_state_dict
        self.zero_grad = self.optim.zero_grad
        self.add_param_group = self.optim.add_param_group

        self.__setstate__ = self.optim.__setstate__  # type: ignore
        self.__getstate__ = self.optim.__getstate__  # type: ignore
        self.__repr__ = self.optim.__repr__  # type: ignore

    @property
    def defaults(self):
        return self.optim.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optim.defaults = defaults

    @property  # type: ignore
    def __class__(self):
        # makes isinstance(wrapper, Optimizer) succeed
        return Optimizer

    @property
    def state(self):
        return self.optim.state

    @state.setter
    def state(self, state):
        self.optim.state = state

    @property
    def param_groups(self):
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    @torch.no_grad()
    def step(self, closure=None):
        """Rescales the gradients with LARS and then steps the wrapped optimizer.

        Weight decay is folded into the gradients by ``update_p``, so every
        group's ``weight_decay`` is set to 0 while the wrapped optimizer runs
        and is restored afterwards, even if an exception is raised.

        Args:
            closure (callable, optional): forwarded to the wrapped optimizer's ``step``.

        Returns:
            whatever the wrapped optimizer's ``step`` returns.
        """

        weight_decays = []

        try:
            for group in self.optim.param_groups:
                weight_decay = group.get("weight_decay", 0)
                weight_decays.append(weight_decay)

                # reset weight decay
                group["weight_decay"] = 0

                # update the parameters
                for p in group["params"]:
                    if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
                        self.update_p(p, group, weight_decay)

            # update the optimizer
            loss = self.optim.step(closure=closure)
        finally:
            # return weight decay control to optimizer (only the groups touched above)
            for group, weight_decay in zip(self.optim.param_groups, weight_decays):
                group["weight_decay"] = weight_decay

        return loss

    def update_p(self, p, group, weight_decay):
        """Adds weight decay to ``p.grad`` and scales it by the layer-wise rate, in place.

        Nothing is changed when either the parameter norm or the gradient norm is zero.

        Args:
            p (torch.Tensor): parameter whose ``grad`` is rescaled.
            group (dict): param group containing ``p``; its "lr" is read when clipping.
            weight_decay (float): weight decay coefficient of the group.
        """

        # calculate new norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        if p_norm != 0 and g_norm != 0:
            # calculate new lr
            new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)

            # clip lr
            if self.clip:
                new_lr = min(new_lr / group["lr"], 1)

            # update params with clipped lr
            p.grad.data += weight_decay * p.data
            p.grad.data *= new_lr
super().__init__(**kwargs) 26 | 27 | setattr(self, distill_lamb_name, distill_lamb) 28 | setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim) 29 | setattr(self, distill_temperature_name, distill_temperature) 30 | setattr(self, distill_no_predictior_name, distill_no_predictior) 31 | if distill_no_predictior: 32 | setattr(self, distill_predictor_name, nn.Identity()) 33 | else: 34 | setattr(self, distill_predictor_name, nn.Sequential( 35 | nn.Linear(output_dim, distill_proj_hidden_dim), 36 | nn.BatchNorm1d(distill_proj_hidden_dim), 37 | nn.ReLU(), 38 | nn.Linear(distill_proj_hidden_dim, output_dim), 39 | )) 40 | 41 | @staticmethod 42 | def add_model_specific_args( 43 | parent_parser: argparse.ArgumentParser, 44 | ) -> argparse.ArgumentParser: 45 | parser = parent_parser.add_argument_group(f"contrastive_{class_tag}_distiller") 46 | 47 | parser.add_argument(f"--{distill_lamb_name}", type=float, default=1) 48 | parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048) 49 | parser.add_argument(f"--{distill_temperature_name}", type=float, default=0.2) 50 | parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False) 51 | 52 | return parent_parser 53 | 54 | @property 55 | def learnable_params(self) -> List[dict]: 56 | """Adds distill predictor parameters to the parent's learnable parameters. 57 | 58 | Returns: 59 | List[dict]: list of learnable parameters. 
60 | """ 61 | 62 | extra_learnable_params = [ 63 | {"params": getattr(self, distill_predictor_name).parameters()}, 64 | ] 65 | return super().learnable_params + extra_learnable_params 66 | 67 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 68 | out = super().training_step(batch, batch_idx) 69 | z1, z2 = out[distill_current_key] 70 | frozen_z1, frozen_z2 = out[distill_frozen_key] 71 | 72 | p1 = getattr(self, distill_predictor_name)(z1) 73 | p2 = getattr(self, distill_predictor_name)(z2) 74 | 75 | distill_loss = ( 76 | simclr_distill_loss_func(p1, p2, frozen_z1, frozen_z2, getattr(self, distill_temperature_name)) 77 | + simclr_distill_loss_func(frozen_z1, frozen_z2, p1, p2, getattr(self, distill_temperature_name)) 78 | ) / 2 79 | 80 | self.log(f"train_{class_tag}_contrastive_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 81 | 82 | out["loss"] += getattr(self, distill_lamb_name) * distill_loss 83 | return out 84 | 85 | return ContrastiveDistillWrapper 86 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning import Trainer, seed_everything 7 | from pytorch_lightning.callbacks import LearningRateMonitor 8 | from pytorch_lightning.loggers import WandbLogger 9 | from pytorch_lightning.plugins import DDPPlugin 10 | from torchvision.models import resnet18, resnet50 11 | 12 | from kaizen.args.setup import parse_args_linear 13 | 14 | try: 15 | from kaizen.methods.dali import ClassificationABC 16 | except ImportError: 17 | _dali_avaliable = False 18 | else: 19 | _dali_avaliable = True 20 | from kaizen.methods.linear import LinearModel 21 | from kaizen.utils.classification_dataloader import prepare_data 22 | from kaizen.utils.checkpointer import Checkpointer 23 | 24 | 25 | def main(): 26 | args = 
parse_args_linear() 27 | 28 | # split classes into tasks 29 | tasks = None 30 | if args.split_strategy == "class": 31 | assert args.num_classes % args.num_tasks == 0 32 | torch.manual_seed(args.split_seed) 33 | tasks = torch.randperm(args.num_classes).chunk(args.num_tasks) 34 | 35 | seed_everything(args.global_seed) 36 | 37 | if args.encoder == "resnet18": 38 | backbone = resnet18() 39 | elif args.encoder == "resnet50": 40 | backbone = resnet50() 41 | else: 42 | raise ValueError("Only [resnet18, resnet50] are currently supported.") 43 | 44 | if args.cifar: 45 | backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 46 | backbone.maxpool = nn.Identity() 47 | backbone.fc = nn.Identity() 48 | 49 | assert ( 50 | args.pretrained_feature_extractor.endswith(".ckpt") 51 | or args.pretrained_feature_extractor.endswith(".pth") 52 | or args.pretrained_feature_extractor.endswith(".pt") 53 | ) 54 | ckpt_path = args.pretrained_feature_extractor 55 | 56 | state = torch.load(ckpt_path)["state_dict"] 57 | for k in list(state.keys()): 58 | if "encoder" in k: 59 | state[k.replace("encoder.", "")] = state[k] 60 | del state[k] 61 | backbone.load_state_dict(state, strict=False) 62 | 63 | print(f"Loaded {ckpt_path}") 64 | 65 | if args.dali: 66 | assert _dali_avaliable, "Dali is not currently avaiable, please install it first." 
class DINOLoss(nn.Module):
    """Cross-entropy between sharpened, centered teacher outputs and student outputs.

    The module keeps an exponential moving average of the teacher logits
    (``center``) and a per-epoch teacher temperature schedule. The owner is
    expected to update ``self.epoch`` as training progresses.
    """

    def __init__(
        self,
        num_prototypes: int,
        warmup_teacher_temp: float,
        teacher_temp: float,
        warmup_teacher_temp_epochs: float,
        num_epochs: int,
        student_temp: float = 0.1,
        num_crops: int = 2,
        center_momentum: float = 0.9,
    ):
        """Auxiliary module to compute DINO's loss.

        Args:
            num_prototypes (int): number of prototypes.
            warmup_teacher_temp (float): base temperature for the temperature schedule
                of the teacher.
            teacher_temp (float): final temperature for the teacher.
            warmup_teacher_temp_epochs (float): number of epochs over which the teacher
                temperature is ramped linearly from ``warmup_teacher_temp`` to
                ``teacher_temp``. Truncated to an integer.
            num_epochs (int): total number of epochs.
            student_temp (float, optional): temperature for the student. Defaults to 0.1.
            num_crops (int, optional): number of crops/views. Defaults to 2.
            center_momentum (float, optional): momentum for the EMA update of the center of
                mass of the teacher. Defaults to 0.9.
        """

        super().__init__()
        self.epoch = 0
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.num_crops = num_crops
        self.register_buffer("center", torch.zeros(1, num_prototypes))

        # np.linspace / np.ones only accept integer lengths, while the warm-up
        # length is documented (and annotated) as a float: coerce it here.
        warmup_epochs = int(warmup_teacher_temp_epochs)
        # A warm-up longer than the whole training leaves no constant phase;
        # clamp at zero instead of asking numpy for a negative-length array.
        constant_epochs = max(int(num_epochs) - warmup_epochs, 0)

        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate(
            (
                np.linspace(warmup_teacher_temp, teacher_temp, warmup_epochs),
                np.ones(constant_epochs) * teacher_temp,
            )
        )

    def forward(self, student_output: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor:
        """Computes DINO's loss given a batch of logits of the student and a batch of logits of the
        teacher.

        Args:
            student_output (torch.Tensor): NxP Tensor containing student logits for all views.
            teacher_output (torch.Tensor): NxP Tensor containing teacher logits for all views.

        Returns:
            torch.Tensor: DINO loss.
        """

        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.num_crops)

        # teacher centering and sharpening
        # Epochs beyond the end of the schedule reuse the final temperature
        # rather than raising IndexError.
        last_index = len(self.teacher_temp_schedule) - 1
        temp = self.teacher_temp_schedule[min(self.epoch, last_index)]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        # The teacher batch is always split into 2 views, independent of num_crops.
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v, student_view in enumerate(student_out):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_view, dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor):
        """Updates the center for DINO's loss using exponential moving average.

        Args:
            teacher_output (torch.Tensor): NxP Tensor containing teacher logits of all views.
        """

        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_available() and dist.is_initialized():
            # Sum across processes, then average over the number of processes.
            dist.all_reduce(batch_center)
            batch_center = batch_center / dist.get_world_size()
        batch_center = batch_center / len(teacher_output)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
28 | """ 29 | 30 | super().__init__(**kwargs) 31 | 32 | self.sim_loss_weight = sim_loss_weight 33 | self.var_loss_weight = var_loss_weight 34 | self.cov_loss_weight = cov_loss_weight 35 | 36 | # projector 37 | self.projector = nn.Sequential( 38 | nn.Linear(self.features_dim, proj_hidden_dim), 39 | nn.BatchNorm1d(proj_hidden_dim), 40 | nn.ReLU(), 41 | nn.Linear(proj_hidden_dim, proj_hidden_dim), 42 | nn.BatchNorm1d(proj_hidden_dim), 43 | nn.ReLU(), 44 | nn.Linear(proj_hidden_dim, output_dim), 45 | ) 46 | 47 | @staticmethod 48 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 49 | parent_parser = super(VICReg, VICReg).add_model_specific_args(parent_parser) 50 | parser = parent_parser.add_argument_group("vicreg") 51 | 52 | # projector 53 | parser.add_argument("--output_dim", type=int, default=2048) 54 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 55 | 56 | # parameters 57 | parser.add_argument("--sim_loss_weight", default=25, type=float) 58 | parser.add_argument("--var_loss_weight", default=25, type=float) 59 | parser.add_argument("--cov_loss_weight", default=1.0, type=float) 60 | return parent_parser 61 | 62 | @property 63 | def learnable_params(self) -> List[dict]: 64 | """Adds projector parameters to the parent's learnable parameters. 65 | 66 | Returns: 67 | List[dict]: list of learnable parameters. 68 | """ 69 | 70 | extra_learnable_params = [{"params": self.projector.parameters()}] 71 | return super().learnable_params + extra_learnable_params 72 | 73 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 74 | """Performs the forward pass of the encoder and the projector. 75 | 76 | Args: 77 | X (torch.Tensor): a batch of images in the tensor format. 78 | 79 | Returns: 80 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features. 
def cross_entropy(preds, targets):
    """Mean soft-target cross-entropy between two batches of logits.

    ``targets`` are turned into probabilities with a softmax and ``preds`` into
    log-probabilities with a log-softmax, both over the last dimension.
    """
    target_probs = F.softmax(targets, dim=-1)
    pred_log_probs = torch.log_softmax(preds, dim=-1)
    per_sample = torch.sum(target_probs * pred_log_probs, dim=-1)
    return -torch.mean(per_sample)


def knowledge_distill_wrapper(Method=object):
    """Create a distillation subclass of ``base_distill_wrapper(Method)``.

    The generated class matches prototype assignments of the current embeddings
    (after a predictor and a dedicated prototype layer) against assignments of
    the frozen embeddings under a frozen copy of the method's prototypes.
    """

    def _copy_and_freeze(frozen_layer, source_layer):
        # Copy parameters pairwise and exclude the copies from gradient updates.
        pairs = zip(frozen_layer.parameters(), source_layer.parameters())
        for frozen_param, source_param in pairs:
            frozen_param.data.copy_(source_param.data)
            frozen_param.requires_grad = False

    class KnowledgeDistillWrapper(base_distill_wrapper(Method)):
        def __init__(
            self,
            distill_lamb: float,
            distill_proj_hidden_dim: int,
            distill_temperature: float,
            **kwargs
        ):
            super().__init__(**kwargs)

            self.distill_lamb = distill_lamb
            self.distill_temperature = distill_temperature
            embed_dim = kwargs["output_dim"]
            n_prototypes = kwargs["num_prototypes"]

            # Frozen snapshot of the method's own prototype layer.
            self.frozen_prototypes = nn.utils.weight_norm(
                nn.Linear(embed_dim, n_prototypes, bias=False)
            )
            _copy_and_freeze(self.frozen_prototypes, self.prototypes)

            self.distill_predictor = nn.Sequential(
                nn.Linear(embed_dim, distill_proj_hidden_dim),
                nn.BatchNorm1d(distill_proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(distill_proj_hidden_dim, embed_dim),
            )

            self.distill_prototypes = nn.utils.weight_norm(
                nn.Linear(embed_dim, n_prototypes, bias=False)
            )

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Register the knowledge-distiller options and return ``parent_parser``."""
            group = parent_parser.add_argument_group("knowledge_distiller")

            group.add_argument("--distill_lamb", type=float, default=1)
            group.add_argument("--distill_proj_hidden_dim", type=int, default=2048)
            group.add_argument("--distill_temperature", type=float, default=0.1)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Parent's learnable parameters plus the distill predictor and prototypes.

            Returns:
                List[dict]: list of learnable parameter groups.
            """
            distill_groups = [
                {"params": module.parameters()}
                for module in (self.distill_predictor, self.distill_prototypes)
            ]
            return super().learnable_params + distill_groups

        def on_train_start(self):
            """After the parent hook, refresh the frozen prototypes for tasks after the first."""
            super().on_train_start()

            if self.current_task_idx > 0:
                _copy_and_freeze(self.frozen_prototypes, self.prototypes)

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Return the parent's loss plus ``distill_lamb`` times the distillation loss."""
            out = super().training_step(batch, batch_idx)
            z1, z2 = out["z"]
            frozen_z1, frozen_z2 = out["frozen_z"]
            temperature = self.distill_temperature

            # Targets: frozen embeddings scored by the frozen prototypes, no gradients.
            with torch.no_grad():
                frozen_z1 = F.normalize(frozen_z1)
                frozen_z2 = F.normalize(frozen_z2)
                target_logits_1 = self.frozen_prototypes(frozen_z1) / temperature
                target_logits_2 = self.frozen_prototypes(frozen_z2) / temperature

            # Predictions: current embeddings through the predictor and distill prototypes.
            predicted_z1 = F.normalize(self.distill_predictor(z1))
            predicted_z2 = F.normalize(self.distill_predictor(z2))
            pred_logits_1 = self.distill_prototypes(predicted_z1) / temperature
            pred_logits_2 = self.distill_prototypes(predicted_z2) / temperature

            loss_view_1 = cross_entropy(pred_logits_1, target_logits_1)
            loss_view_2 = cross_entropy(pred_logits_2, target_logits_2)
            distill_loss = (loss_view_1 + loss_view_2) / 2

            self.log("train_knowledge_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            return out["loss"] + self.distill_lamb * distill_loss

    return KnowledgeDistillWrapper
Continual Fine-Tuning](https://openaccess.thecvf.com/content/WACV2024/html/Tang_Kaizen_Practical_Self-Supervised_Continual_Learning_With_Continual_Fine-Tuning_WACV_2024_paper.html)**
5 | > Chi Ian Tang, Lorena Qendro, Dimitris Spathis, Fahim Kawsar, Cecilia Mascolo, Akhil Mathur
6 | > **WACV 2024** 7 | 8 | > **Abstract:** *Self-supervised learning (SSL) has shown remarkable performance in computer vision tasks when trained offline. However, in a Continual Learning (CL) scenario where new data is introduced progressively, models still suffer from catastrophic forgetting. Retraining a model from scratch to adapt to newly generated data is time-consuming and inefficient. Previous approaches suggested re-purposing self-supervised objectives with knowledge distillation to mitigate forgetting across tasks, assuming that labels from all tasks are available during fine-tuning. In this paper, we generalize self-supervised continual learning in a practical setting where available labels can be leveraged in any step of the SSL process. With an increasing number of continual tasks, this offers more flexibility in the pre-training and fine-tuning phases. With Kaizen, we introduce a training architecture that is able to mitigate catastrophic forgetting for both the feature extractor and classifier with a carefully designed loss function. By using a set of comprehensive evaluation metrics reflecting different aspects of continual learning, we demonstrated that Kaizen significantly outperforms previous SSL models in competitive vision benchmarks, with up to 16.5% accuracy improvement on split CIFAR-100. 
Kaizen is able to balance the trade-off between knowledge retention and learning from new data with an end-to-end model, paving the way for practical deployment of continual learning systems.* 9 | 10 | ![method](./imgs/method.png) 11 | ![results_continual](./imgs/results_continual.png) 12 | ![results_final](./imgs/results_final.png) 13 | 14 | # Installation 15 | Please run the following command to install the required packages: 16 | ``` 17 | pip install -r requirements.txt 18 | ``` 19 | In order to work with the [WandB](https://wandb.ai/site) logging library, you may need to run the following: 20 | ``` 21 | source setup_commands.sh 22 | ``` 23 | 24 | # Commands 25 | 26 | Bash files for launching the experiments are provided in the `bash_files` folder. Running `job_launcher.py` with one of these files launches an experiment that automatically trains a model continually for the number of tasks specified in that bash file. 27 | 28 | For example, to launch the MoCoV2+ experiment, run: 29 | 30 | ``` 31 | DATA_DIR=/YOUR/DATA/DIR/ CUDA_VISIBLE_DEVICES=0 python job_launcher.py --script bash_files/mocov2plus_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh 32 | ``` 33 | 34 | Note that each script uses gpu:0 by default, so set the `CUDA_VISIBLE_DEVICES` environment variable (e.g. `CUDA_VISIBLE_DEVICES=0`) to control which GPU the script uses. 35 | 36 | # Copyright, Acknowledgment & Warranty 37 | This repository is provided for research reproducibility only. The authors reserve all rights. (Copyright (c) 2023 Chi Ian Tang, Lorena Qendro, Dimitris Spathis, Fahim Kawsar, Cecilia Mascolo, Akhil Mathur). 38 | 39 | Part of this repository is based on the implementation of [cassle](https://github.com/DonkeyShot21/cassle) ([https://github.com/DonkeyShot21/cassle](https://github.com/DonkeyShot21/cassle)), used under the MIT License, Copyright (c) 2021 Enrico Fini, Victor Turrisi, Xavier Alameda-Pineda, Elisa Ricci, Karteek Alahari, Julien Mairal. 
40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 47 | SOFTWARE. 48 | 49 | # Citation 50 | 51 | Please cite our paper if any part of this repository is used in any research work: 52 | ``` 53 | @inproceedings{tang2024kaizen, 54 | title={Kaizen: Practical Self-Supervised Continual Learning With Continual Fine-Tuning}, 55 | author={Tang, Chi Ian and Qendro, Lorena and Spathis, Dimitris and Kawsar, Fahim and Mascolo, Cecilia and Mathur, Akhil}, 56 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision}, 57 | pages={2841--2850}, 58 | year={2024} 59 | } 60 | ``` 61 | 62 | # License 63 | 64 | This project is licensed under the [BSD 3-Clause Clear License](LICENSE.md). 
65 | -------------------------------------------------------------------------------- /kaizen/distiller_factories/decorrelative.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | from kaizen.losses.barlow import barlow_loss_func 7 | from kaizen.args.utils import strtobool 8 | 9 | 10 | def decorrelative_distill_factory( 11 | Method=object, class_tag="", 12 | distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256 13 | ): 14 | distill_lamb_name = f"{class_tag}_distill_lamb" 15 | distill_proj_hidden_dim_name = f"{class_tag}_distill_proj_hidden_dim" 16 | distill_barlow_lamb_name = f"{class_tag}_distill_barlow_lamb" 17 | distill_scale_loss_name = f"{class_tag}_distill_scale_loss" 18 | distill_no_predictior_name = f"{class_tag}_distill_no_predictior" 19 | 20 | distill_predictor_name = f"{class_tag}_distill_predictor" 21 | class DecorrelativeDistillWrapper(Method): 22 | def __init__(self, **kwargs): 23 | distill_lamb: float = kwargs.pop(distill_lamb_name) 24 | distill_proj_hidden_dim: int = kwargs.pop(distill_proj_hidden_dim_name) 25 | distill_barlow_lamb: float = kwargs.pop(distill_barlow_lamb_name) 26 | distill_scale_loss: float = kwargs.pop(distill_scale_loss_name) 27 | distill_no_predictior = kwargs.pop(distill_no_predictior_name) 28 | super().__init__(**kwargs) 29 | 30 | setattr(self, distill_lamb_name, distill_lamb) 31 | setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim) 32 | setattr(self, distill_barlow_lamb_name, distill_barlow_lamb) 33 | setattr(self, distill_scale_loss_name, distill_scale_loss) 34 | setattr(self, distill_no_predictior_name, distill_no_predictior) 35 | if distill_no_predictior: 36 | setattr(self, distill_predictor_name, nn.Identity()) 37 | else: 38 | setattr(self, distill_predictor_name, nn.Sequential( 39 | nn.Linear(output_dim, distill_proj_hidden_dim), 40 | 
nn.BatchNorm1d(distill_proj_hidden_dim), 41 | nn.ReLU(), 42 | nn.Linear(distill_proj_hidden_dim, output_dim), 43 | )) 44 | 45 | 46 | @staticmethod 47 | def add_model_specific_args( 48 | parent_parser: argparse.ArgumentParser, 49 | ) -> argparse.ArgumentParser: 50 | parser = parent_parser.add_argument_group(f"decorrelative_{class_tag}_distiller") 51 | 52 | parser.add_argument(f"--{distill_lamb_name}", type=float, default=1) 53 | parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048) 54 | parser.add_argument(f"--{distill_barlow_lamb_name}", type=float, default=5e-3) 55 | parser.add_argument(f"--{distill_scale_loss_name}", type=float, default=0.1) 56 | parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False) 57 | 58 | return parent_parser 59 | 60 | @property 61 | def learnable_params(self) -> List[dict]: 62 | """Adds distill predictor parameters to the parent's learnable parameters. 63 | 64 | Returns: 65 | List[dict]: list of learnable parameters. 
class WMSE(BaseModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        whitening_iters: int,
        whitening_size: int,
        whitening_eps: float,
        **kwargs
    ):
        """Implements W-MSE (https://arxiv.org/abs/2007.06346)

        Args:
            output_dim (int): number of dimensions of the projected features.
            proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
            whitening_iters (int): number of times to perform whitening.
            whitening_size (int): size of the batch slice for whitening. Must be
                positive and divide the batch size evenly.
            whitening_eps (float): epsilon for numerical stability in whitening.

        Raises:
            ValueError: if ``whitening_size`` is not a positive divisor of the batch size.
        """

        super().__init__(**kwargs)

        self.whitening_iters = whitening_iters
        self.whitening_size = whitening_size

        assert self.whitening_size <= self.batch_size
        # training_step reshapes a permutation of ``batch_size`` indexes into rows
        # of ``whitening_size``; report an incompatible pair here instead of
        # failing inside ``view`` on the first training batch.
        if self.whitening_size < 1 or self.batch_size % self.whitening_size != 0:
            raise ValueError(
                f"whitening_size ({self.whitening_size}) must be a positive divisor "
                f"of batch_size ({self.batch_size})"
            )

        # projector
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

        self.whitening = Whitening2d(output_dim, eps=whitening_eps)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Adds the W-MSE options to ``parent_parser`` after the parent's own options.

        Returns:
            the same ``parent_parser`` that was passed in.
        """
        parent_parser = super(WMSE, WMSE).add_model_specific_args(parent_parser)
        # Group title only affects ``--help`` output; it used to read "simclr".
        parser = parent_parser.add_argument_group("wmse")

        # projector
        parser.add_argument("--output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=1024)

        # wmse
        parser.add_argument("--whitening_iters", type=int, default=1)
        parser.add_argument("--whitening_size", type=int, default=256)
        parser.add_argument("--whitening_eps", type=float, default=0)

        return parent_parser

    @property
    def learnable_params(self) -> List[Dict]:
        """Adds projector parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs the forward pass of the encoder and the projector.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing the outputs of the parent and the projected
                features under the key ``"v"``.
        """

        out = super().forward(X, *args, **kwargs)
        v = self.projector(out["feats"])
        return {**out, "v": v}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for W-MSE reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images
            batch_idx (int): index of the batch

        Returns:
            torch.Tensor: total loss composed of W-MSE loss and classification loss
        """

        out = super().training_step(batch, batch_idx)
        class_loss = out["loss"]
        feats = out["feats"]

        # Projections of all crops stacked along the batch dimension.
        v = torch.cat([self.projector(f) for f in feats])

        # ------- wmse loss -------
        bs = self.batch_size
        num_losses, wmse_loss = 0, 0
        for _ in range(self.whitening_iters):
            z = torch.empty_like(v)
            # Random partition of the batch indexes into slices of whitening_size.
            perm = torch.randperm(bs).view(-1, self.whitening_size)
            for idx in perm:
                # Whiten the same slice of samples separately for every crop.
                for i in range(self.num_crops):
                    z[idx + i * bs] = self.whitening(v[idx + i * bs]).type_as(z)
            # Accumulate the loss over every unordered pair of crops.
            for i in range(self.num_crops - 1):
                for j in range(i + 1, self.num_crops):
                    x0 = z[i * bs : (i + 1) * bs]
                    x1 = z[j * bs : (j + 1) * bs]
                    wmse_loss += wmse_loss_func(x0, x1)
                    num_losses += 1
        wmse_loss /= num_losses

        # NOTE(review): the metric key does not mention W-MSE; it is left unchanged
        # so that existing logs and dashboards keep matching.
        self.log("train_neg_cos_sim", wmse_loss, on_epoch=True, sync_dist=True)

        return wmse_loss + class_loss
kaizen.losses.simsiam import simsiam_loss_func 8 | from kaizen.methods.base import BaseModel 9 | 10 | 11 | class SimSiam(BaseModel): 12 | def __init__( 13 | self, 14 | output_dim: int, 15 | proj_hidden_dim: int, 16 | pred_hidden_dim: int, 17 | **kwargs, 18 | ): 19 | """Implements SimSiam (https://arxiv.org/abs/2011.10566). 20 | 21 | Args: 22 | output_dim (int): number of dimensions of projected features. 23 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 24 | pred_hidden_dim (int): number of neurons of the hidden layers of the predictor. 25 | """ 26 | 27 | super().__init__(**kwargs) 28 | 29 | # projector 30 | self.projector = nn.Sequential( 31 | nn.Linear(self.features_dim, proj_hidden_dim, bias=False), 32 | nn.BatchNorm1d(proj_hidden_dim), 33 | nn.ReLU(), 34 | nn.Linear(proj_hidden_dim, proj_hidden_dim, bias=False), 35 | nn.BatchNorm1d(proj_hidden_dim), 36 | nn.ReLU(), 37 | nn.Linear(proj_hidden_dim, output_dim), 38 | nn.BatchNorm1d(output_dim, affine=False), 39 | ) 40 | self.projector[6].bias.requires_grad = False # hack: not use bias as it is followed by BN 41 | 42 | # predictor 43 | self.predictor = nn.Sequential( 44 | nn.Linear(output_dim, pred_hidden_dim, bias=False), 45 | nn.BatchNorm1d(pred_hidden_dim), 46 | nn.ReLU(), 47 | nn.Linear(pred_hidden_dim, output_dim), 48 | ) 49 | 50 | @staticmethod 51 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 52 | parent_parser = super(SimSiam, SimSiam).add_model_specific_args(parent_parser) 53 | parser = parent_parser.add_argument_group("simsiam") 54 | 55 | # projector 56 | parser.add_argument("--output_dim", type=int, default=128) 57 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 58 | 59 | # predictor 60 | parser.add_argument("--pred_hidden_dim", type=int, default=512) 61 | return parent_parser 62 | 63 | @property 64 | def learnable_params(self) -> List[dict]: 65 | """Adds projector and predictor parameters to the 
parent's learnable parameters. 66 | 67 | Returns: 68 | List[dict]: list of learnable parameters. 69 | """ 70 | 71 | extra_learnable_params: List[dict] = [ 72 | {"params": self.projector.parameters()}, 73 | {"params": self.predictor.parameters(), "static_lr": True}, 74 | ] 75 | return super().learnable_params + extra_learnable_params 76 | 77 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 78 | """Performs the forward pass of the encoder, the projector and the predictor. 79 | 80 | Args: 81 | X (torch.Tensor): a batch of images in the tensor format. 82 | 83 | Returns: 84 | Dict[str, Any]: 85 | a dict containing the outputs of the parent 86 | and the projected and predicted features. 87 | """ 88 | 89 | out = super().forward(X, *args, **kwargs) 90 | z = self.projector(out["feats"]) 91 | p = self.predictor(z) 92 | return {**out, "z": z, "p": p} 93 | 94 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 95 | """Training step for SimSiam reusing BaseModel training step. 
96 | 97 | Args: 98 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 99 | [X] is a list of size self.num_crops containing batches of images 100 | batch_idx (int): index of the batch 101 | 102 | Returns: 103 | torch.Tensor: total loss composed of SimSiam loss and classification loss 104 | """ 105 | 106 | out = super().training_step(batch, batch_idx) 107 | feats1, feats2 = out["feats"] 108 | 109 | z1 = self.projector(feats1) 110 | z2 = self.projector(feats2) 111 | 112 | p1 = self.predictor(z1) 113 | p2 = self.predictor(z2) 114 | 115 | # ------- contrastive loss ------- 116 | neg_cos_sim = simsiam_loss_func(p1, z2) / 2 + simsiam_loss_func(p2, z1) / 2 117 | 118 | # calculate std of features 119 | z1_std = F.normalize(z1, dim=-1).std(dim=0).mean() 120 | z2_std = F.normalize(z2, dim=-1).std(dim=0).mean() 121 | z_std = (z1_std + z2_std) / 2 122 | 123 | metrics = { 124 | "train_neg_cos_sim": neg_cos_sim, 125 | "train_z_std": z_std, 126 | } 127 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 128 | 129 | out.update({"loss": out["loss"] + neg_cos_sim, "z": [z1, z2]}) 130 | return out 131 | -------------------------------------------------------------------------------- /kaizen/losses/simclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | 5 | 6 | def simclr_distill_loss_func( 7 | p1: torch.Tensor, 8 | p2: torch.Tensor, 9 | z1: torch.Tensor, 10 | z2: torch.Tensor, 11 | temperature: float = 0.1, 12 | ) -> torch.Tensor: 13 | 14 | device = z1.device 15 | 16 | b = z1.size(0) 17 | 18 | p = F.normalize(torch.cat([p1, p2]), dim=-1) 19 | z = F.normalize(torch.cat([z1, z2]), dim=-1) 20 | 21 | logits = torch.einsum("if, jf -> ij", p, z) / temperature 22 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 23 | logits = logits - logits_max.detach() 24 | 25 | # positive mask are matches i, j (i from aug1, j from 
aug2), where i == j and matches j, i 26 | pos_mask = torch.zeros((2 * b, 2 * b), dtype=torch.bool, device=device) 27 | pos_mask.fill_diagonal_(True) 28 | 29 | # all matches excluding the main diagonal 30 | logit_mask = torch.ones_like(pos_mask, device=device) 31 | logit_mask.fill_diagonal_(True) 32 | logit_mask[:, b:].fill_diagonal_(True) 33 | logit_mask[b:, :].fill_diagonal_(True) 34 | 35 | exp_logits = torch.exp(logits) * logit_mask 36 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 37 | 38 | # compute mean of log-likelihood over positives 39 | mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1) 40 | # loss 41 | loss = -mean_log_prob_pos.mean() 42 | return loss 43 | 44 | 45 | def simclr_loss_func( 46 | z1: torch.Tensor, 47 | z2: torch.Tensor, 48 | temperature: float = 0.1, 49 | extra_pos_mask: Optional[torch.Tensor] = None, 50 | ) -> torch.Tensor: 51 | """Computes SimCLR's loss given batch of projected features z1 from view 1 and 52 | projected features z2 from view 2. 53 | 54 | Args: 55 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 56 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 57 | temperature (float): temperature factor for the loss. Defaults to 0.1. 58 | extra_pos_mask (Optional[torch.Tensor]): boolean mask containing extra positives other 59 | than normal across-view positives. Defaults to None. 60 | 61 | Returns: 62 | torch.Tensor: SimCLR loss. 
63 | """ 64 | 65 | device = z1.device 66 | 67 | b = z1.size(0) 68 | z = torch.cat((z1, z2), dim=0) 69 | z = F.normalize(z, dim=-1) 70 | 71 | logits = torch.einsum("if, jf -> ij", z, z) / temperature 72 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 73 | logits = logits - logits_max.detach() 74 | 75 | # positive mask are matches i, j (i from aug1, j from aug2), where i == j and matches j, i 76 | pos_mask = torch.zeros((2 * b, 2 * b), dtype=torch.bool, device=device) 77 | pos_mask[:, b:].fill_diagonal_(True) 78 | pos_mask[b:, :].fill_diagonal_(True) 79 | 80 | # if we have extra "positives" 81 | if extra_pos_mask is not None: 82 | pos_mask = torch.bitwise_or(pos_mask, extra_pos_mask) 83 | 84 | # all matches excluding the main diagonal 85 | logit_mask = torch.ones_like(pos_mask, device=device).fill_diagonal_(0) 86 | 87 | exp_logits = torch.exp(logits) * logit_mask 88 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 89 | 90 | # compute mean of log-likelihood over positives 91 | mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1) 92 | # loss 93 | loss = -mean_log_prob_pos.mean() 94 | return loss 95 | 96 | 97 | def manual_simclr_loss_func( 98 | z: torch.Tensor, pos_mask: torch.Tensor, neg_mask: torch.Tensor, temperature: float = 0.1 99 | ) -> torch.Tensor: 100 | """Manually computes SimCLR's loss given batch of projected features z 101 | from different views, a positive boolean mask of all positives and 102 | a negative boolean mask of all negatives. 103 | 104 | Args: 105 | z (torch.Tensor): NxViewsxD Tensor containing projected features from the views. 106 | pos_mask (torch.Tensor): boolean mask containing all positives for z * z.T. 107 | neg_mask (torch.Tensor): boolean mask containing all negatives for z * z.T. 108 | temperature (float): temperature factor for the loss. 109 | 110 | Return: 111 | torch.Tensor: manual SimCLR loss. 
112 | """ 113 | 114 | z = F.normalize(z, dim=-1) 115 | 116 | logits = torch.einsum("if, jf -> ij", z, z) / temperature 117 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 118 | logits = logits - logits_max.detach() 119 | 120 | negatives = torch.sum(torch.exp(logits) * neg_mask, dim=1, keepdim=True) 121 | exp_logits = torch.exp(logits) 122 | log_prob = torch.log(exp_logits / (exp_logits + negatives)) 123 | 124 | # compute mean of log-likelihood over positive 125 | mean_log_prob_pos = (pos_mask * log_prob).sum(1) 126 | 127 | indexes = pos_mask.sum(1) > 0 128 | pos_mask = pos_mask[indexes] 129 | mean_log_prob_pos = mean_log_prob_pos[indexes] / pos_mask.sum(1) 130 | 131 | # loss 132 | loss = -mean_log_prob_pos.mean() 133 | return loss 134 | -------------------------------------------------------------------------------- /kaizen/utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import string 5 | import time 6 | from argparse import ArgumentParser, Namespace 7 | from pathlib import Path 8 | from typing import Optional, Union 9 | 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.callbacks import Callback 12 | 13 | 14 | def random_string(letter_count=4, digit_count=4): 15 | tmp_random = random.Random(time.time()) 16 | rand_str = "".join((tmp_random.choice(string.ascii_lowercase) for x in range(letter_count))) 17 | rand_str += "".join((tmp_random.choice(string.digits) for x in range(digit_count))) 18 | rand_str = list(rand_str) 19 | tmp_random.shuffle(rand_str) 20 | return "".join(rand_str) 21 | 22 | 23 | class Checkpointer(Callback): 24 | def __init__( 25 | self, 26 | args: Namespace, 27 | logdir: Union[str, Path] = Path("trained_models"), 28 | frequency: int = 1, 29 | keep_previous_checkpoints: bool = False, 30 | ): 31 | """Custom checkpointer callback that stores checkpoints in an easier to access way. 
32 | 33 | Args: 34 | args (Namespace): namespace object containing at least an attribute name. 35 | logdir (Union[str, Path], optional): base directory to store checkpoints. 36 | Defaults to "trained_models". 37 | frequency (int, optional): number of epochs between each checkpoint. Defaults to 1. 38 | keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not. 39 | Defaults to False. 40 | """ 41 | 42 | super().__init__() 43 | 44 | assert "task" not in args.name 45 | 46 | self.args = args 47 | self.logdir = Path(logdir) 48 | self.frequency = frequency 49 | self.keep_previous_checkpoints = keep_previous_checkpoints 50 | 51 | @staticmethod 52 | def add_checkpointer_args(parent_parser: ArgumentParser): 53 | """Adds user-required arguments to a parser. 54 | 55 | Args: 56 | parent_parser (ArgumentParser): parser to add new args to. 57 | """ 58 | 59 | parser = parent_parser.add_argument_group("checkpointer") 60 | parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path) 61 | parser.add_argument("--checkpoint_frequency", default=1, type=int) 62 | return parent_parser 63 | 64 | def initial_setup(self, trainer: pl.Trainer): 65 | """Creates the directories and does the initial setup needed. 66 | 67 | Args: 68 | trainer (pl.Trainer): pytorch lightning trainer object. 
69 | """ 70 | 71 | if trainer.logger is None: 72 | if os.path.exists(self.logdir): 73 | existing_versions = set(os.listdir(self.logdir)) 74 | else: 75 | existing_versions = set() 76 | version = "offline-" + random_string() 77 | while version in existing_versions: 78 | version = "offline-" + random_string() 79 | else: 80 | version = str(trainer.logger.version) 81 | if version is not None: 82 | task_idx = getattr(self.args, "task_idx", "_all") 83 | self.path = self.logdir / f"task{task_idx}-{version}" 84 | self.ckpt_placeholder = f"{self.args.name}" + "-task{}-ep={}" + f"-{version}.ckpt" 85 | self.last_ckpt: Optional[str] = None 86 | 87 | # create logging dirs 88 | if trainer.is_global_zero: 89 | os.makedirs(self.path, exist_ok=True) 90 | 91 | def save_args(self, trainer: pl.Trainer): 92 | """Stores arguments into a json file. 93 | 94 | Args: 95 | trainer (pl.Trainer): pytorch lightning trainer object. 96 | """ 97 | 98 | if trainer.is_global_zero: 99 | args = vars(self.args) 100 | self.json_path = self.path / "args.json" 101 | json.dump(args, open(self.json_path, "w"), default=lambda o: "") 102 | 103 | def save(self, trainer: pl.Trainer): 104 | """Saves current checkpoint. 105 | 106 | Args: 107 | trainer (pl.Trainer): pytorch lightning trainer object. 
108 | """ 109 | 110 | if trainer.is_global_zero and not trainer.sanity_checking: 111 | epoch = trainer.current_epoch # type: ignore 112 | task_idx = getattr(self.args, "task_idx", "_all") 113 | ckpt = self.path / self.ckpt_placeholder.format(task_idx, epoch) 114 | trainer.save_checkpoint(ckpt) 115 | 116 | if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints: 117 | if os.path.exists(self.last_ckpt): 118 | os.remove(self.last_ckpt) 119 | 120 | with open(self.logdir / "last_checkpoint.txt", "w") as f: 121 | f.write(str(ckpt) + "\n" + str(self.json_path)) 122 | 123 | self.last_ckpt = ckpt 124 | 125 | def on_train_start(self, trainer: pl.Trainer, _): 126 | """Executes initial setup and saves arguments. 127 | 128 | Args: 129 | trainer (pl.Trainer): pytorch lightning trainer object. 130 | """ 131 | 132 | self.initial_setup(trainer) 133 | self.save_args(trainer) 134 | 135 | def on_train_epoch_end(self, trainer: pl.Trainer, _): 136 | """Tries to save current checkpoint at the end of each validation epoch. 137 | 138 | Args: 139 | trainer (pl.Trainer): pytorch lightning trainer object. 
140 | """ 141 | 142 | epoch = trainer.current_epoch # type: ignore 143 | if epoch % self.frequency == 0: 144 | self.save(trainer) 145 | -------------------------------------------------------------------------------- /kaizen/methods/byol.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, Dict, List, Sequence, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from kaizen.losses.byol import byol_loss_func 8 | from kaizen.methods.base import BaseMomentumModel 9 | from kaizen.utils.momentum import initialize_momentum_params 10 | 11 | 12 | class BYOL(BaseMomentumModel): 13 | def __init__( 14 | self, 15 | output_dim: int, 16 | proj_hidden_dim: int, 17 | pred_hidden_dim: int, 18 | **kwargs, 19 | ): 20 | """Implements BYOL (https://arxiv.org/abs/2006.07733). 21 | 22 | Args: 23 | output_dim (int): number of dimensions of projected features. 24 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 25 | pred_hidden_dim (int): number of neurons of the hidden layers of the predictor. 
26 | """ 27 | 28 | super().__init__(**kwargs) 29 | 30 | # projector 31 | self.projector = nn.Sequential( 32 | nn.Linear(self.features_dim, proj_hidden_dim), 33 | nn.BatchNorm1d(proj_hidden_dim), 34 | nn.ReLU(), 35 | nn.Linear(proj_hidden_dim, output_dim), 36 | ) 37 | 38 | # momentum projector 39 | self.momentum_projector = nn.Sequential( 40 | nn.Linear(self.features_dim, proj_hidden_dim), 41 | nn.BatchNorm1d(proj_hidden_dim), 42 | nn.ReLU(), 43 | nn.Linear(proj_hidden_dim, output_dim), 44 | ) 45 | initialize_momentum_params(self.projector, self.momentum_projector) 46 | 47 | # predictor 48 | self.predictor = nn.Sequential( 49 | nn.Linear(output_dim, pred_hidden_dim), 50 | nn.BatchNorm1d(pred_hidden_dim), 51 | nn.ReLU(), 52 | nn.Linear(pred_hidden_dim, output_dim), 53 | ) 54 | 55 | @staticmethod 56 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 57 | parent_parser = super(BYOL, BYOL).add_model_specific_args(parent_parser) 58 | parser = parent_parser.add_argument_group("byol") 59 | 60 | # projector 61 | parser.add_argument("--output_dim", type=int, default=256) 62 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 63 | 64 | # predictor 65 | parser.add_argument("--pred_hidden_dim", type=int, default=512) 66 | 67 | return parent_parser 68 | 69 | @property 70 | def learnable_params(self) -> List[dict]: 71 | """Adds projector and predictor parameters to the parent's learnable parameters. 72 | 73 | Returns: 74 | List[dict]: list of learnable parameters. 75 | """ 76 | 77 | extra_learnable_params = [ 78 | {"params": self.projector.parameters()}, 79 | {"params": self.predictor.parameters()}, 80 | ] 81 | return super().learnable_params + extra_learnable_params 82 | 83 | @property 84 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 85 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 86 | 87 | Returns: 88 | List[Tuple[Any, Any]]: list of momentum pairs. 
89 | """ 90 | 91 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 92 | return super().momentum_pairs + extra_momentum_pairs 93 | 94 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 95 | """Performs forward pass of the online encoder (encoder, projector and predictor). 96 | 97 | Args: 98 | X (torch.Tensor): batch of images in tensor format. 99 | 100 | Returns: 101 | Dict[str, Any]: a dict containing the outputs of the parent and the logits of the head. 102 | """ 103 | 104 | out = super().forward(X, *args, **kwargs) 105 | z = self.projector(out["feats"]) 106 | p = self.predictor(z) 107 | return {**out, "z": z, "p": p} 108 | 109 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 110 | """Training step for BYOL reusing BaseModel training step. 111 | 112 | Args: 113 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 114 | [X] is a list of size self.num_crops containing batches of images. 115 | batch_idx (int): index of the batch. 116 | 117 | Returns: 118 | torch.Tensor: total loss composed of BYOL and classification loss. 
119 | """ 120 | 121 | out = super().training_step(batch, batch_idx) 122 | feats1, feats2 = out["feats"] 123 | momentum_feats1, momentum_feats2 = out["momentum_feats"] 124 | 125 | z1 = self.projector(feats1) 126 | z2 = self.projector(feats2) 127 | p1 = self.predictor(z1) 128 | p2 = self.predictor(z2) 129 | 130 | # forward momentum encoder 131 | with torch.no_grad(): 132 | z1_momentum = self.momentum_projector(momentum_feats1) 133 | z2_momentum = self.momentum_projector(momentum_feats2) 134 | 135 | # ------- contrastive loss ------- 136 | neg_cos_sim = byol_loss_func(p1, z2_momentum) + byol_loss_func(p2, z1_momentum) 137 | 138 | # calculate std of features 139 | z1_std = F.normalize(z1, dim=-1).std(dim=0).mean() 140 | z2_std = F.normalize(z2, dim=-1).std(dim=0).mean() 141 | z_std = (z1_std + z2_std) / 2 142 | 143 | metrics = { 144 | "train_neg_cos_sim": neg_cos_sim, 145 | "train_z_std": z_std, 146 | } 147 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 148 | 149 | out.update({"loss": out["loss"] + neg_cos_sim, "z": [z1, z2]}) 150 | return out 151 | -------------------------------------------------------------------------------- /evaluate_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | import glob 4 | import os 5 | import re 6 | import multiprocessing as mp 7 | import signal 8 | import time 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--folder", type=str, required=True) 12 | parser.add_argument("--script", type=str, required=True) 13 | parser.add_argument("--start_task", type=int, default=0) 14 | parser.add_argument("--num_tasks", type=int, default=5) 15 | parser.add_argument("--gpus", type=str, default="0,1", help="Comma-separated list of gpu indices") 16 | parser.add_argument("--tag", type=str, default="") 17 | 18 | 19 | def consumer(consumer_id, job_queue, results_queue, setup_args=None): 20 | job = None 21 | try: 22 | print(f'{consumer_id} 
{os.getpid()} Starting consumer') 23 | results_queue.put({ 24 | 'type': 'init', 25 | 'data': os.getpid() 26 | }) 27 | # consumer_vars = consumer_setup(consumer_id, setup_args) 28 | while(True): 29 | job = job_queue.get(timeout=60) 30 | print(f"{consumer_id} {os.getpid()} get job") # {job} 31 | if job is None: 32 | break 33 | else: 34 | return_value = process_job(setup_args, job) 35 | results_queue.put({ 36 | 'type': 'job_finished', 37 | 'data': return_value 38 | }) 39 | print(f"{consumer_id} {os.getpid()} exitting loop") 40 | except Exception as e: 41 | print("DEBUG: There is some issue with the below arg settings. \n Copy the args and recreate the error by running contrastive_training.py for further debugging!") 42 | print(e) 43 | print(job) 44 | finally: 45 | if job is not None: 46 | job_queue.put(job) 47 | results_queue.put(None) 48 | print(f'Stopping consumer {consumer_id} {os.getpid()}') 49 | 50 | def process_job(consumer_vars, job): 51 | new_job = consumer_vars + " " + job 52 | run_command(new_job) 53 | 54 | def run_command(command): 55 | p = subprocess.Popen(command, shell=True) 56 | p.wait() 57 | 58 | def main(args): 59 | gpus = args.gpus.split(",") 60 | 61 | task_folders = glob.glob(os.path.join(args.folder, "*")) 62 | print(task_folders) 63 | task_folder_lookup = {} 64 | for folder in task_folders: 65 | re_match = re.search("task(?P<task_idx>\d*)-.*", os.path.basename(folder)) 66 | if re_match is not None: 67 | re_match_dict = re_match.groupdict() 68 | if "task_idx" in re_match_dict: 69 | task_idx = int(re_match_dict["task_idx"]) 70 | models = sorted(glob.glob(os.path.join(folder, '*.ckpt'))) 71 | if len(models) > 0: 72 | task_folder_lookup[task_idx] = models[-1] 73 | 74 | all_jobs = [] 75 | for task_idx in range(args.start_task, args.num_tasks): 76 | if task_idx in task_folder_lookup: 77 | job = f'TAG="T{task_idx}-{args.tag}" TASK_IDX={task_idx} NUM_TASKS={args.num_tasks} \ 78 | PRETRAINED_PATH="{task_folder_lookup[task_idx]}" \ 79 | bash {args.script}' 80 | 
all_jobs.append(job) 81 | # run_command(job) 82 | else: 83 | print("==============") 84 | print(f"[ERR] Cannot find model for task {task_idx}") 85 | print("==============") 86 | 87 | job_queue = mp.Queue() 88 | results_queue = mp.Queue() 89 | processes = [mp.Process(target=consumer, args=(i, job_queue, results_queue, f"CUDA_VISIBLE_DEVICES={gpus[i]}")) for i in range(len(gpus))] 90 | process_pids = [] 91 | active_consumer_counter = len(processes) 92 | finished_job_counter = 0 93 | try: 94 | print("Putting jobs...") 95 | for job in all_jobs: 96 | job_queue.put(job) 97 | 98 | print(f"{os.getpid()} Server - starting consumers") 99 | for p in processes: 100 | p.start() 101 | for _ in range(len(processes)): 102 | job_queue.put(None) 103 | print(f"{os.getpid()} Server - finished putting jobs") 104 | 105 | while(True): 106 | job_results = results_queue.get() 107 | print("Job results", job_results) 108 | if job_results is None: 109 | active_consumer_counter -= 1 110 | if active_consumer_counter == 0: 111 | break 112 | elif job_results['type'] == 'init': 113 | process_pids.append(job_results['data']) 114 | elif job_results['type'] == 'job_finished': 115 | finished_job_counter += 1 116 | 117 | print('Closing workers') 118 | for p in processes: 119 | p.join(60) 120 | except KeyboardInterrupt: 121 | print("Interrupted from Keyboard") 122 | finally: 123 | print("Terminating Processes", processes) 124 | for p in processes: 125 | try: 126 | p.terminate() 127 | except Exception as e: 128 | print(f"Unable to terminate process {p}, processes might still exist.", e) 129 | print("Killing Processes", process_pids) 130 | for pid in process_pids: 131 | try: 132 | # os.kill(pid, signal.SIGTERM) 133 | os.kill(pid, signal.SIGKILL) 134 | except Exception as e: 135 | print(f"Unable to kill process {pid}, processes might still exist.", e) 136 | try: 137 | job_queue.close() 138 | results_queue.close() 139 | except Exception as e: 140 | print("Unable to close job queues, processes might still be 
open", e) 141 | 142 | print(f'Finished, processed {finished_job_counter} jobs') 143 | 144 | if __name__ == "__main__": 145 | args, unknown_args = parser.parse_known_args() 146 | main(args) 147 | 148 | -------------------------------------------------------------------------------- /kaizen/distiller_factories/knowledge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from kaizen.args.utils import strtobool 8 | 9 | def cross_entropy(preds, targets): 10 | return -torch.mean( 11 | torch.sum(F.softmax(targets, dim=-1) * torch.log_softmax(preds, dim=-1), dim=-1) 12 | ) 13 | 14 | def knowledge_distill_factory( 15 | Method=object, class_tag="", 16 | distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256 17 | ): 18 | distill_lamb_name = f"{class_tag}_distill_lamb" 19 | distill_proj_hidden_dim_name = f"{class_tag}_distill_proj_hidden_dim" 20 | distill_temperature_name = f"{class_tag}_distill_temperature" 21 | distill_no_predictior_name = f"{class_tag}_distill_no_predictior" 22 | 23 | frozen_prototypes_name = f"{class_tag}_frozen_prototypes" 24 | distill_predictor_name = f"{class_tag}_distill_predictor" 25 | distill_prototypes_name = f"{class_tag}_distill_prototypes" 26 | class KnowledgeDistillWrapper(Method): 27 | def __init__(self, **kwargs): 28 | distill_lamb: float = kwargs.pop(distill_lamb_name) 29 | distill_proj_hidden_dim: int = kwargs.pop(distill_proj_hidden_dim_name) 30 | distill_temperature: float = kwargs.pop(distill_temperature_name) 31 | distill_no_predictior = kwargs.pop(distill_no_predictior_name) 32 | super().__init__(**kwargs) 33 | 34 | setattr(self, distill_lamb_name, distill_lamb) 35 | setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim) 36 | setattr(self, distill_temperature_name, distill_temperature) 37 | setattr(self, 
distill_no_predictior_name, distill_no_predictior) 38 | # TODO: Allow different num_prototypes for different distillers 39 | # TODO: Verify that these prototypes can be used for classifier distillation 40 | num_prototypes = kwargs["num_prototypes"] 41 | 42 | setattr(self, frozen_prototypes_name, nn.utils.weight_norm( 43 | nn.Linear(output_dim, num_prototypes, bias=False) 44 | )) 45 | for frozen_pg, pg in zip( 46 | getattr(self, frozen_prototypes_name).parameters(), self.prototypes.parameters() # TODO: Check this 47 | ): 48 | frozen_pg.data.copy_(pg.data) 49 | frozen_pg.requires_grad = False 50 | 51 | if distill_no_predictior: 52 | setattr(self, distill_predictor_name, nn.Identity()) 53 | else: 54 | setattr(self, distill_predictor_name, nn.Sequential( 55 | nn.Linear(output_dim, distill_proj_hidden_dim), 56 | nn.BatchNorm1d(distill_proj_hidden_dim), 57 | nn.ReLU(), 58 | nn.Linear(distill_proj_hidden_dim, output_dim), 59 | )) 60 | 61 | 62 | setattr(self, distill_prototypes_name, nn.utils.weight_norm( 63 | nn.Linear(output_dim, num_prototypes, bias=False) 64 | )) 65 | 66 | @staticmethod 67 | def add_model_specific_args( 68 | parent_parser: argparse.ArgumentParser, 69 | ) -> argparse.ArgumentParser: 70 | parser = parent_parser.add_argument_group(f"knowledge_{class_tag}_distiller") 71 | 72 | parser.add_argument(f"--{distill_lamb_name}", type=float, default=1) 73 | parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048) 74 | parser.add_argument(f"--{distill_temperature_name}", type=float, default=0.1) 75 | parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=True) 76 | 77 | return parent_parser 78 | 79 | @property 80 | def learnable_params(self) -> List[dict]: 81 | """Adds distill predictor parameters to the parent's learnable parameters. 82 | 83 | Returns: 84 | List[dict]: list of learnable parameters. 
85 | """ 86 | 87 | extra_learnable_params = [ 88 | {"params": getattr(self, distill_predictor_name).parameters()}, 89 | {"params": getattr(self, distill_prototypes_name).parameters()}, 90 | ] 91 | return super().learnable_params + extra_learnable_params 92 | 93 | def on_train_start(self): 94 | super().on_train_start() 95 | 96 | if self.current_task_idx > 0: 97 | for frozen_pg, pg in zip( 98 | getattr(self, frozen_prototypes_name).parameters(), self.prototypes.parameters() 99 | ): 100 | # TODO: logits and prototypes have different shape 101 | frozen_pg.data.copy_(pg.data) 102 | frozen_pg.requires_grad = False 103 | 104 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 105 | out = super().training_step(batch, batch_idx) 106 | z1, z2 = out[distill_current_key] 107 | frozen_z1, frozen_z2 = out[distill_frozen_key] 108 | 109 | with torch.no_grad(): 110 | frozen_z1 = F.normalize(frozen_z1) 111 | frozen_z2 = F.normalize(frozen_z2) 112 | frozen_p1 = getattr(self, frozen_prototypes_name)(frozen_z1) / getattr(self, distill_temperature_name) 113 | frozen_p2 = getattr(self, frozen_prototypes_name)(frozen_z2) / getattr(self, distill_temperature_name) 114 | 115 | distill_z1 = F.normalize(self.distill_predictor(z1)) 116 | distill_z2 = F.normalize(self.distill_predictor(z2)) 117 | distill_p1 = getattr(self, distill_prototypes_name)(distill_z1) / getattr(self, distill_temperature_name) 118 | distill_p2 = getattr(self, distill_prototypes_name)(distill_z2) / getattr(self, distill_temperature_name) 119 | 120 | distill_loss = ( 121 | cross_entropy(distill_p1, frozen_p1) + cross_entropy(distill_p2, frozen_p2) 122 | ) / 2 123 | 124 | self.log(f"train_{class_tag}_knowledge_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 125 | 126 | out["loss"] += getattr(self, distill_lamb_name) * distill_loss 127 | return out 128 | 129 | return KnowledgeDistillWrapper 130 | -------------------------------------------------------------------------------- 
/kaizen/methods/ressl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, Dict, List, Sequence, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from kaizen.losses.ressl import ressl_loss_func 8 | from kaizen.methods.base import BaseMomentumModel 9 | from kaizen.utils.gather_layer import gather 10 | from kaizen.utils.momentum import initialize_momentum_params 11 | 12 | 13 | class ReSSL(BaseMomentumModel): 14 | def __init__( 15 | self, 16 | output_dim: int, 17 | proj_hidden_dim: int, 18 | temperature_q: float, 19 | temperature_k: float, 20 | queue_size: int, 21 | **kwargs, 22 | ): 23 | """Implements ReSSL (https://arxiv.org/abs/2107.09282v1). 24 | 25 | Args: 26 | output_dim (int): number of dimensions of projected features. 27 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 28 | pred_hidden_dim (int): number of neurons of the hidden layers of the predictor. 29 | temperature_q (float): temperature for the contrastive augmentations. 30 | temperature_k (float): temperature for the weak augmentation. 
31 | """ 32 | 33 | super().__init__(**kwargs) 34 | 35 | # projector 36 | self.projector = nn.Sequential( 37 | nn.Linear(self.features_dim, proj_hidden_dim), 38 | nn.ReLU(), 39 | nn.Linear(proj_hidden_dim, output_dim), 40 | ) 41 | 42 | # momentum projector 43 | self.momentum_projector = nn.Sequential( 44 | nn.Linear(self.features_dim, proj_hidden_dim), 45 | nn.ReLU(), 46 | nn.Linear(proj_hidden_dim, output_dim), 47 | ) 48 | initialize_momentum_params(self.projector, self.momentum_projector) 49 | 50 | self.temperature_q = temperature_q 51 | self.temperature_k = temperature_k 52 | self.queue_size = queue_size 53 | 54 | # queue 55 | self.register_buffer("queue", torch.randn(self.queue_size, output_dim)) 56 | self.queue = F.normalize(self.queue, dim=1) 57 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 58 | 59 | @staticmethod 60 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 61 | parent_parser = super(ReSSL, ReSSL).add_model_specific_args(parent_parser) 62 | parser = parent_parser.add_argument_group("ressl") 63 | 64 | # projector 65 | parser.add_argument("--output_dim", type=int, default=256) 66 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 67 | 68 | # queue settings 69 | parser.add_argument("--queue_size", default=65536, type=int) 70 | 71 | # parameters 72 | parser.add_argument("--temperature_q", type=float, default=0.1) 73 | parser.add_argument("--temperature_k", type=float, default=0.04) 74 | 75 | return parent_parser 76 | 77 | @property 78 | def learnable_params(self) -> List[dict]: 79 | """Adds projector parameters to the parent's learnable parameters. 80 | 81 | Returns: 82 | List[dict]: list of learnable parameters. 
83 | """ 84 | 85 | extra_learnable_params = [ 86 | {"params": self.projector.parameters()}, 87 | ] 88 | return super().learnable_params + extra_learnable_params 89 | 90 | @property 91 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 92 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 93 | 94 | Returns: 95 | List[Tuple[Any, Any]]: list of momentum pairs. 96 | """ 97 | 98 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 99 | return super().momentum_pairs + extra_momentum_pairs 100 | 101 | @torch.no_grad() 102 | def dequeue_and_enqueue(self, k: torch.Tensor): 103 | """Adds new samples and removes old samples from the queue in a FIFO manner. 104 | 105 | Args: 106 | k (torch.Tensor): batch of projected features. 107 | """ 108 | 109 | k = gather(k) 110 | 111 | batch_size = k.shape[0] 112 | 113 | ptr = int(self.queue_ptr) # type: ignore 114 | assert self.queue_size % batch_size == 0 115 | 116 | self.queue[ptr : ptr + batch_size, :] = k 117 | ptr = (ptr + batch_size) % self.queue_size 118 | 119 | self.queue_ptr[0] = ptr # type: ignore 120 | 121 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 122 | """Performs forward pass of the online encoder (encoder and projector). 123 | 124 | Args: 125 | X (torch.Tensor): batch of images in tensor format. 126 | 127 | Returns: 128 | Dict[str, Any]: a dict containing the outputs of the parent and the normalized projected features ("q"). 129 | """ 130 | 131 | out = super().forward(X, *args, **kwargs) 132 | q = F.normalize(self.projector(out["feats"]), dim=-1) 133 | return {**out, "q": q} 134 | 135 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 136 | """Training step for ReSSL reusing BaseModel training step. 137 | 138 | Args: 139 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 140 | [X] is a list of size self.num_crops containing batches of images. 141 | batch_idx (int): index of the batch. 
class SimCLR(BaseModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        temperature: float,
        supervised: bool = False,
        **kwargs,
    ):
        """Implements SimCLR (https://arxiv.org/abs/2002.05709).

        Args:
            output_dim (int): number of dimensions of the projected features.
            proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
            temperature (float): temperature for the softmax in the contrastive loss.
            supervised (bool): whether or not to use supervised contrastive loss. Defaults to False.
        """

        super().__init__(**kwargs)

        self.temperature = temperature
        self.supervised = supervised

        # projector
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Adds the SimCLR command-line options to the parent's options.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the parser returned by the parent class, with a
                "simclr" argument group added to it.
        """

        parent_parser = super(SimCLR, SimCLR).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("simclr")

        # projector
        parser.add_argument("--output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=2048)

        # parameters
        parser.add_argument("--temperature", type=float, default=0.1)

        # supervised-simclr
        parser.add_argument("--supervised", action="store_true")
        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs the forward pass of the parent model and of the projector.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]:
                a dict containing the outputs of the parent
                and the projected features under the key "z".
        """

        out = super().forward(X, *args, **kwargs)
        z = self.projector(out["feats"])
        return {**out, "z": z}

    @torch.no_grad()
    def gen_extra_positives_gt(self, Y: torch.Tensor) -> torch.Tensor:
        """Generates extra positives for supervised contrastive learning.

        Args:
            Y (torch.Tensor): labels of the samples of the batch.

        Returns:
            torch.Tensor: matrix with extra positives generated using the labels.
        """

        if self.multicrop:
            n_augs = self.num_crops + self.num_small_crops
        else:
            n_augs = 2
        # tile the labels so that entry (i, j) compares the label of view i with view j
        labels_matrix = repeat(Y, "b -> c (d b)", c=n_augs * Y.size(0), d=n_augs)
        labels_matrix = (labels_matrix == labels_matrix.t()).fill_diagonal_(False)
        return labels_matrix

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for SimCLR and supervised SimCLR reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
                The entry for the current task is read from
                ``batch[f"task{self.current_task_idx}"]``.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: the parent's output dict, with "loss" increased by the SimCLR
                loss and "z" set to the list of per-crop projected features.
        """

        indexes, *_, target = batch[f"task{self.current_task_idx}"]

        out = super().training_step(batch, batch_idx)

        if self.multicrop:
            n_augs = self.num_crops + self.num_small_crops

            feats = out["feats"]

            # keep the per-crop projections: they are returned under "z" below
            # (previously only the two-crop branch defined the returned values,
            # so the multicrop path failed with a NameError)
            z_crops = [self.projector(f) for f in feats]
            z = torch.cat(z_crops)

            # ------- contrastive loss -------
            if self.supervised:
                pos_mask = self.gen_extra_positives_gt(target)
            else:
                index_matrix = repeat(indexes, "b -> c (d b)", c=n_augs * indexes.size(0), d=n_augs)
                pos_mask = (index_matrix == index_matrix.t()).fill_diagonal_(False)
            neg_mask = (~pos_mask).fill_diagonal_(False)

            nce_loss = manual_simclr_loss_func(
                z,
                pos_mask=pos_mask,
                neg_mask=neg_mask,
                temperature=self.temperature,
            )
        else:
            feats1, feats2 = out["feats"]

            z1 = self.projector(feats1)
            z2 = self.projector(feats2)
            z_crops = [z1, z2]

            # ------- contrastive loss -------
            if self.supervised:
                pos_mask = self.gen_extra_positives_gt(target)
                nce_loss = simclr_loss_func(
                    z1, z2, extra_pos_mask=pos_mask, temperature=self.temperature
                )
            else:
                nce_loss = simclr_loss_func(z1, z2, temperature=self.temperature)

        # compute number of extra positives
        n_positives = (
            (pos_mask != 0).sum().float()
            if self.supervised
            else torch.tensor(0.0, device=self.device)
        )

        metrics = {
            "train_nce_loss": nce_loss,
            "train_n_positives": n_positives,
        }
        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        out.update({"loss": out["loss"] + nce_loss, "z": z_crops})
        return out
32 | """ 33 | self.world_size = world_size 34 | self.rank = rank 35 | self.num_crops = num_crops 36 | self.dataset_size = dataset_size 37 | self.proj_features_dim = proj_features_dim 38 | self.num_prototypes = num_prototypes 39 | self.kmeans_iters = kmeans_iters 40 | 41 | @staticmethod 42 | def get_indices_sparse(data: np.ndarray): 43 | cols = np.arange(data.size) 44 | M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size)) 45 | return [np.unravel_index(row.data, data.shape) for row in M] 46 | 47 | def cluster_memory( 48 | self, 49 | local_memory_index: torch.Tensor, 50 | local_memory_embeddings: torch.Tensor, 51 | ) -> Sequence[Any]: 52 | """Performs K-Means clustering on the hypersphere and returns centroids and 53 | assignments for each sample. 54 | 55 | Args: 56 | local_memory_index (torch.Tensor): memory bank cointaining indices of the 57 | samples. 58 | local_memory_embeddings (torch.Tensor): memory bank cointaining embeddings 59 | of the samples. 60 | 61 | Returns: 62 | Sequence[Any]: assignments and centroids. 
63 | """ 64 | j = 0 65 | device = local_memory_embeddings.device 66 | assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long() 67 | centroids_list = [] 68 | with torch.no_grad(): 69 | for i_K, K in enumerate(self.num_prototypes): 70 | # run distributed k-means 71 | 72 | # init centroids with elements from memory bank of rank 0 73 | centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True) 74 | if self.rank == 0: 75 | random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K] 76 | assert len(random_idx) >= K, "please reduce the number of centroids" 77 | centroids = local_memory_embeddings[j][random_idx] 78 | if dist.is_available() and dist.is_initialized(): 79 | dist.broadcast(centroids, 0) 80 | 81 | for n_iter in range(self.kmeans_iters + 1): 82 | 83 | # E step 84 | dot_products = torch.mm(local_memory_embeddings[j], centroids.t()) 85 | _, local_assignments = dot_products.max(dim=1) 86 | 87 | # finish 88 | if n_iter == self.kmeans_iters: 89 | break 90 | 91 | # M step 92 | where_helper = self.get_indices_sparse(local_assignments.cpu().numpy()) 93 | counts = torch.zeros(K).to(device, non_blocking=True).int() 94 | emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True) 95 | for k in range(len(where_helper)): 96 | if len(where_helper[k][0]) > 0: 97 | emb_sums[k] = torch.sum( 98 | local_memory_embeddings[j][where_helper[k][0]], 99 | dim=0, 100 | ) 101 | counts[k] = len(where_helper[k][0]) 102 | if dist.is_available() and dist.is_initialized(): 103 | dist.all_reduce(counts) 104 | dist.all_reduce(emb_sums) 105 | mask = counts > 0 106 | centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1) 107 | 108 | # normalize centroids 109 | centroids = F.normalize(centroids, dim=1, p=2) 110 | 111 | centroids_list.append(centroids) 112 | 113 | if dist.is_available() and dist.is_initialized(): 114 | # gather the assignments 115 | assignments_all = torch.empty( 116 | self.world_size, 117 | 
local_assignments.size(0), 118 | dtype=local_assignments.dtype, 119 | device=local_assignments.device, 120 | ) 121 | assignments_all = list(assignments_all.unbind(0)) 122 | 123 | dist_process = dist.all_gather( 124 | assignments_all, local_assignments, async_op=True 125 | ) 126 | dist_process.wait() 127 | assignments_all = torch.cat(assignments_all).cpu() 128 | 129 | # gather the indexes 130 | indexes_all = torch.empty( 131 | self.world_size, 132 | local_memory_index.size(0), 133 | dtype=local_memory_index.dtype, 134 | device=local_memory_index.device, 135 | ) 136 | indexes_all = list(indexes_all.unbind(0)) 137 | dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True) 138 | dist_process.wait() 139 | indexes_all = torch.cat(indexes_all).cpu() 140 | 141 | else: 142 | assignments_all = local_assignments 143 | indexes_all = local_memory_index 144 | 145 | # log assignments 146 | assignments[i_K][indexes_all] = assignments_all 147 | 148 | # next memory bank to use 149 | j = (j + 1) % self.num_crops 150 | 151 | return assignments, centroids_list 152 | -------------------------------------------------------------------------------- /kaizen/utils/auto_umap.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from argparse import ArgumentParser, Namespace 4 | from pathlib import Path 5 | from typing import Optional, Union 6 | 7 | import pandas as pd 8 | import pytorch_lightning as pl 9 | import seaborn as sns 10 | import torch 11 | import umap 12 | import wandb 13 | from matplotlib import pyplot as plt 14 | from pytorch_lightning.callbacks import Callback 15 | 16 | from .gather_layer import gather 17 | 18 | 19 | class AutoUMAP(Callback): 20 | def __init__( 21 | self, 22 | args: Namespace, 23 | logdir: Union[str, Path] = Path("auto_umap"), 24 | frequency: int = 1, 25 | keep_previous: bool = False, 26 | color_palette: str = "hls", 27 | ): 28 | """UMAP callback that automatically runs UMAP 
on the validation dataset and uploads the 29 | figure to wandb. 30 | 31 | Args: 32 | args (Namespace): namespace object containing at least an attribute name. 33 | logdir (Union[str, Path], optional): base directory to store checkpoints. 34 | Defaults to Path("auto_umap"). 35 | frequency (int, optional): number of epochs between each UMAP. Defaults to 1. 36 | color_palette (str, optional): color scheme for the classes. Defaults to "hls". 37 | keep_previous (bool, optional): whether to keep previous plots or not. 38 | Defaults to False. 39 | """ 40 | 41 | super().__init__() 42 | 43 | self.args = args 44 | self.logdir = Path(logdir) 45 | self.frequency = frequency 46 | self.color_palette = color_palette 47 | self.keep_previous = keep_previous 48 | 49 | @staticmethod 50 | def add_auto_umap_args(parent_parser: ArgumentParser): 51 | """Adds user-required arguments to a parser. 52 | 53 | Args: 54 | parent_parser (ArgumentParser): parser to add new args to. 55 | """ 56 | 57 | parser = parent_parser.add_argument_group("auto_umap") 58 | parser.add_argument("--auto_umap_dir", default=Path("auto_umap"), type=Path) 59 | parser.add_argument("--auto_umap_frequency", default=1, type=int) 60 | return parent_parser 61 | 62 | def initial_setup(self, trainer: pl.Trainer): 63 | """Creates the directories and does the initial setup needed. 64 | 65 | Args: 66 | trainer (pl.Trainer): pytorch lightning trainer object. 
def plot(self, trainer: pl.Trainer, module: pl.LightningModule):
    """Produces a UMAP visualization by forwarding all data of the
    first validation dataloader through the module.

    The "feats" output of the module and the labels are passed through ``gather``
    for every batch. Only the global-zero process builds the figure; it logs it to
    wandb when the trainer uses a ``WandbLogger`` and always saves it under
    ``self.path``.

    Args:
        trainer (pl.Trainer): pytorch lightning trainer object.
        module (pl.LightningModule): current module object.
    """

    device = module.device
    all_feats = []
    all_labels = []

    # set module to eval mode and collect all feature representations
    module.eval()
    with torch.no_grad():
        for x, y in trainer.val_dataloaders[0]:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            feats = gather(module(x)["feats"])
            y = gather(y)

            all_feats.append(feats.cpu())
            all_labels.append(y.cpu())
    module.train()

    # every process runs the loop above; only rank zero plots, and only if it saw data
    if not (trainer.is_global_zero and len(all_feats)):
        return

    features = torch.cat(all_feats, dim=0).numpy()
    labels = torch.cat(all_labels, dim=0)
    num_classes = len(torch.unique(labels))
    labels = labels.numpy()

    embedding = umap.UMAP(n_components=2).fit_transform(features)

    # passing to dataframe
    df = pd.DataFrame(
        {
            "feat_1": embedding[:, 0],
            "feat_2": embedding[:, 1],
            "Y": labels,
        }
    )
    plt.figure(figsize=(9, 9))
    ax = sns.scatterplot(
        x="feat_1",
        y="feat_2",
        hue="Y",
        palette=sns.color_palette(self.color_palette, num_classes),
        data=df,
        legend="full",
        alpha=0.3,
    )
    ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[])
    ax.tick_params(left=False, right=False, bottom=False, top=False)

    # manually improve quality of imagenet umaps
    anchor = (0.5, 1.8) if num_classes > 100 else (0.5, 1.35)

    plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10))
    plt.tight_layout()

    if isinstance(trainer.logger, pl.loggers.WandbLogger):
        wandb.log(
            {"validation_umap": wandb.Image(ax)},
            commit=False,
        )

    # save plot locally as well
    epoch = trainer.current_epoch  # type: ignore
    plt.savefig(self.path / self.umap_placeholder.format(epoch))
    plt.close()
generate an up-to-date UMAP visualization of the features 171 | at the end of each validation epoch. 172 | 173 | Args: 174 | trainer (pl.Trainer): pytorch lightning trainer object. 175 | """ 176 | 177 | epoch = trainer.current_epoch # type: ignore 178 | if epoch % self.frequency == 0 and not trainer.sanity_checking: 179 | self.plot(trainer, module) 180 | -------------------------------------------------------------------------------- /kaizen/utils/knn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchmetrics.metric import Metric 3 | 4 | 5 | from typing import Sequence 6 | 7 | import torch 8 | from torchmetrics.metric import Metric 9 | 10 | 11 | class WeightedKNNClassifier(Metric): 12 | def __init__( 13 | self, 14 | k: int = 20, 15 | T: float = 0.07, 16 | max_distance_matrix_size: int = int(5e6), 17 | distance_fx: str = "cosine", 18 | epsilon: float = 0.00001, 19 | dist_sync_on_step: bool = False, 20 | ): 21 | """Implements the weighted k-NN classifier used for evaluation. 22 | Args: 23 | k (int, optional): number of neighbors. Defaults to 20. 24 | T (float, optional): temperature for the exponential. Only used with cosine 25 | distance. Defaults to 0.07. 26 | max_distance_matrix_size (int, optional): maximum number of elements in the 27 | distance matrix. Defaults to 5e6. 28 | distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or 29 | "euclidean". Defaults to "cosine". 30 | epsilon (float, optional): Small value for numerical stability. Only used with 31 | euclidean distance. Defaults to 0.00001. 32 | dist_sync_on_step (bool, optional): whether to sync distributed values at every 33 | step. Defaults to False. 
@torch.no_grad()
def compute(self) -> Sequence[float]:
    """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
    the weight is computed using the exponential of the temperature scaled cosine
    distance of the samples. If euclidean distance is selected, the weight corresponds
    to the inverse of the euclidean distance.

    The per-class vote matrix is sized by the largest label id found in the train
    and test targets, so label sets that do not cover ``0..C-1`` (for instance a
    subset of classes) are handled. The stored features and targets are cleared
    with ``self.reset()`` before returning.

    Returns:
        Sequence[float]: k-NN accuracy @1 and @5.

    Raises:
        NotImplementedError: if ``self.distance_fx`` is neither "cosine" nor
            "euclidean".
    """

    train_features = torch.cat(self.train_features)
    train_targets = torch.cat(self.train_targets)
    test_features = torch.cat(self.test_features)
    test_targets = torch.cat(self.test_targets)

    # labels are used as column indices of the vote matrix, so it needs one column
    # per possible label id, not one per label that happens to be in the test set
    num_classes = max(int(train_targets.max()), int(test_targets.max())) + 1
    num_train_images = train_targets.size(0)
    num_test_images = test_targets.size(0)
    # bound the size of the (chunk x train) similarity matrix
    chunk_size = min(
        max(1, self.max_distance_matrix_size // num_train_images),
        num_test_images,
    )
    k = min(self.k, num_train_images)

    top1, top5, total = 0.0, 0.0, 0
    retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
    for idx in range(0, num_test_images, chunk_size):
        # get the features for test images
        features = test_features[idx : idx + chunk_size, :]
        targets = test_targets[idx : idx + chunk_size]
        batch_size = targets.size(0)

        # calculate the dot product and compute top-k neighbors
        if self.distance_fx == "cosine":
            similarity = torch.mm(features, train_features.t())
        elif self.distance_fx == "euclidean":
            similarity = 1 / (torch.cdist(features, train_features) + self.epsilon)
        else:
            raise NotImplementedError

        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = train_targets.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)

        # one-hot encode the label of every retrieved neighbor
        retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)

        if self.distance_fx == "cosine":
            distances = distances.clone().div_(self.T).exp_()

        # weighted vote: sum the neighbor weights per class
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances.view(batch_size, -1, 1),
            ),
            1,
        )
        _, predictions = probs.sort(1, True)

        # find the predictions that match the target
        correct = predictions.eq(targets.data.view(-1, 1))
        top1 = top1 + correct.narrow(1, 0, 1).sum().item()
        top5 = (
            top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
        )  # top5 does not make sense if k < 5
        total += targets.size(0)

    top1 = top1 * 100.0 / total
    top5 = top5 * 100.0 / total

    self.reset()

    return top1, top5
assert args.num_classes % args.num_tasks == 0 34 | torch.manual_seed(args.split_seed) 35 | tasks = torch.randperm(args.num_classes).chunk(args.num_tasks) 36 | 37 | seed_everything(args.global_seed) 38 | 39 | 40 | # Build backbone 41 | if args.encoder == "resnet18": 42 | backbone = resnet18() 43 | elif args.encoder == "resnet50": 44 | backbone = resnet50() 45 | else: 46 | raise ValueError("Only [resnet18, resnet50] are currently supported.") 47 | 48 | if args.cifar: 49 | backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 50 | backbone.maxpool = nn.Identity() 51 | backbone.fc = nn.Identity() 52 | 53 | assert ( 54 | args.pretrained_model.endswith(".ckpt") 55 | or args.pretrained_model.endswith(".pth") 56 | or args.pretrained_model.endswith(".pt") 57 | ) 58 | ckpt_path = args.pretrained_model 59 | 60 | 61 | state = torch.load(ckpt_path, map_location=None if torch.cuda.is_available() else "cpu")["state_dict"] 62 | extracted_state = {} 63 | for k in list(state.keys()): 64 | if "encoder" in k: 65 | extracted_state[k.replace("encoder.", "")] = state[k] 66 | missing_keys_backbone, unexpected_keys_backbone = backbone.load_state_dict(extracted_state, strict=False) 67 | print("Missing keys - Backbone:", missing_keys_backbone) 68 | 69 | # Build Classifier 70 | classifier = MultiLayerClassifier(backbone.inplanes, args.num_classes, args.classifier_layers) 71 | # "--evaluation_mode", choices=["linear_eval", "classifier_eval", "online_classifier_eval"] 72 | if args.evaluation_mode == "linear_eval": 73 | is_model_training = True 74 | elif args.evaluation_mode == "classifier_eval": 75 | extracted_state = {} 76 | for k in state: 77 | if k.startswith("classifier."): 78 | extracted_state[k.replace("classifier.", "")] = state[k] 79 | missing_keys_classifier, unexpected_keys_classifier = classifier.load_state_dict(extracted_state, strict=False) 80 | is_model_training = False 81 | print("Missing keys - Classifier:", missing_keys_classifier) 82 | 
print("Unexpected keys - Classifier:", unexpected_keys_classifier) 83 | elif args.evaluation_mode == "online_classifier_eval": 84 | extracted_state = {} 85 | for k in state: 86 | if k.startswith("online_eval_classifier."): 87 | extracted_state[k.replace("online_eval_classifier.", "")] = state[k] 88 | missing_keys_classifier, unexpected_keys_classifier = classifier.load_state_dict(extracted_state, strict=False) 89 | is_model_training = False 90 | print("Missing keys - Classifier:", missing_keys_classifier) 91 | print("Unexpected keys - Classifier:", unexpected_keys_classifier) 92 | 93 | print(f"Loaded {ckpt_path}") 94 | 95 | if args.dali: 96 | assert _dali_avaliable, "Dali is not currently avaiable, please install it first." 97 | raise NotImplementedError("Dali is not supported") 98 | MethodClass = types.new_class( 99 | f"Dali{LinearModel.__name__}", (ClassificationABC, LinearModel) 100 | ) 101 | # else: 102 | # MethodClass = LinearModel 103 | 104 | model = FullModel(backbone, classifier=classifier, **args.__dict__, tasks=tasks) 105 | 106 | if is_model_training: 107 | train_loader, val_loader = prepare_data( 108 | args.dataset, 109 | data_dir=args.data_dir, 110 | train_dir=args.train_dir, 111 | val_dir=args.val_dir, 112 | batch_size=args.batch_size, 113 | num_workers=args.num_workers, 114 | semi_supervised=args.semi_supervised, 115 | training_data_source=args.linear_classifier_training_data_source, 116 | training_num_tasks=args.num_tasks, 117 | training_tasks=tasks, 118 | training_task_idx=args.task_idx, 119 | training_split_strategy=args.split_strategy, 120 | training_split_seed=args.split_seed, 121 | replay=args.replay, 122 | replay_proportion=args.replay_proportion, 123 | replay_memory_bank_size=args.replay_memory_bank_size 124 | ) 125 | else: 126 | _, val_loader = prepare_data( 127 | args.dataset, 128 | data_dir=args.data_dir, 129 | train_dir=args.train_dir, 130 | val_dir=args.val_dir, 131 | batch_size=args.batch_size, 132 | num_workers=args.num_workers, 133 | 
semi_supervised=args.semi_supervised, 134 | ) 135 | 136 | callbacks = [] 137 | 138 | # wandb logging 139 | if args.wandb: 140 | wandb_logger = WandbLogger( 141 | name=args.name, project=args.project, entity=args.entity, offline=args.offline 142 | ) 143 | wandb_logger.watch(model, log="gradients", log_freq=100) 144 | wandb_logger.log_hyperparams(args) 145 | 146 | # lr logging 147 | lr_monitor = LearningRateMonitor(logging_interval="epoch") 148 | callbacks.append(lr_monitor) 149 | 150 | # save checkpoint on last epoch only 151 | ckpt = Checkpointer( 152 | args, 153 | logdir=os.path.join(args.checkpoint_dir, "linear"), 154 | frequency=args.checkpoint_frequency, 155 | ) 156 | callbacks.append(ckpt) 157 | 158 | trainer = Trainer.from_argparse_args( 159 | args, 160 | logger=wandb_logger if args.wandb else None, 161 | callbacks=callbacks, 162 | plugins=DDPPlugin(find_unused_parameters=True), 163 | checkpoint_callback=False, 164 | terminate_on_nan=True, 165 | accelerator="ddp" if torch.cuda.is_available() else "cpu", 166 | ) 167 | if is_model_training: 168 | if args.dali: 169 | trainer.fit(model, val_dataloaders=val_loader) 170 | else: 171 | trainer.fit(model, train_loader, val_loader) 172 | else: 173 | trainer.validate(model, val_loader) 174 | 175 | if __name__ == "__main__": 176 | main() 177 | -------------------------------------------------------------------------------- /kaizen/methods/mocov2plus.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, Dict, List, Sequence, Tuple 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from kaizen.losses.moco import moco_loss_func 8 | from kaizen.methods.base import BaseMomentumModel 9 | from kaizen.utils.gather_layer import gather 10 | from kaizen.utils.momentum import initialize_momentum_params 11 | 12 | 13 | class MoCoV2Plus(BaseMomentumModel): 14 | queue: torch.Tensor 15 | 16 | def __init__( 17 | self, 
@staticmethod
def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Registers the MoCo V2+ command-line options on top of the parent's options.

    Args:
        parent_parser (argparse.ArgumentParser): parser to extend.

    Returns:
        argparse.ArgumentParser: the parser returned by the parent class, with a
            "mocov2plus" argument group added to it.
    """

    parent_parser = super(MoCoV2Plus, MoCoV2Plus).add_model_specific_args(parent_parser)
    group = parent_parser.add_argument_group("mocov2plus")

    # (flag, type, default), registered in this order
    options = (
        # projector
        ("--output_dim", int, 128),
        ("--proj_hidden_dim", int, 2048),
        # parameters
        ("--temperature", float, 0.1),
        # queue settings
        ("--queue_size", int, 65536),
    )
    for flag, value_type, default in options:
        group.add_argument(flag, type=value_type, default=default)

    return parent_parser
"""Adds projector parameters together with parent's learnable parameters. 73 | 74 | Returns: 75 | List[dict]: list of learnable parameters. 76 | """ 77 | 78 | extra_learnable_params = [{"params": self.projector.parameters()}] 79 | return super().learnable_params + extra_learnable_params 80 | 81 | @property 82 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 83 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 84 | 85 | Returns: 86 | List[Tuple[Any, Any]]: list of momentum pairs. 87 | """ 88 | 89 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 90 | return super().momentum_pairs + extra_momentum_pairs 91 | 92 | @torch.no_grad() 93 | def _dequeue_and_enqueue(self, keys: torch.Tensor): 94 | """Adds new samples and removes old samples from the queue in a fifo manner. 95 | 96 | Args: 97 | keys (torch.Tensor): output features of the momentum encoder. 98 | """ 99 | 100 | batch_size = keys.shape[1] 101 | ptr = int(self.queue_ptr) # type: ignore 102 | keys = keys.permute(0, 2, 1) 103 | 104 | # assert self.queue_size % batch_size == 0 # for simplicity 105 | # Allow non divisible queue size 106 | if ptr + batch_size > self.queue_size: 107 | remaining_slots = self.queue_size - ptr 108 | self.queue[:, :, ptr : ptr + remaining_slots] = keys[:, :, :remaining_slots] 109 | self.queue[:, :, :batch_size - remaining_slots] = keys[:, :, remaining_slots:] 110 | ptr = batch_size - remaining_slots 111 | else: 112 | # replace the keys at ptr (dequeue and enqueue) 113 | self.queue[:, :, ptr : ptr + batch_size] = keys 114 | # ptr = (ptr + batch_size) % self.queue_size # move pointer 115 | ptr += batch_size 116 | self.queue_ptr[0] = ptr # type: ignore 117 | 118 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 119 | """Performs the forward pass of the online encoder and the online projection. 120 | 121 | Args: 122 | X (torch.Tensor): a batch of images in the tensor format. 
123 | 124 | Returns: 125 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features. 126 | """ 127 | 128 | out = super().forward(X, *args, **kwargs) 129 | q = F.normalize(self.projector(out["feats"]), dim=-1) 130 | return {**out, "q": q} 131 | 132 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 133 | """ 134 | Training step for MoCo reusing BaseMomentumModel training step. 135 | 136 | Args: 137 | batch (Sequence[Any]): a batch of data in the 138 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops 139 | containing batches of images. 140 | batch_idx (int): index of the batch. 141 | 142 | Returns: 143 | torch.Tensor: total loss composed of MOCO loss and classification loss. 144 | 145 | """ 146 | 147 | out = super().training_step(batch, batch_idx) 148 | feats1, feats2 = out["feats"] 149 | momentum_feats1, momentum_feats2 = out["momentum_feats"] 150 | 151 | q1 = self.projector(feats1) 152 | q2 = self.projector(feats2) 153 | q1 = F.normalize(q1, dim=-1) 154 | q2 = F.normalize(q2, dim=-1) 155 | 156 | with torch.no_grad(): 157 | k1 = self.momentum_projector(momentum_feats1) 158 | k2 = self.momentum_projector(momentum_feats2) 159 | k1 = F.normalize(k1, dim=-1) 160 | k2 = F.normalize(k2, dim=-1) 161 | 162 | # ------- contrastive loss ------- 163 | # symmetric 164 | queue = self.queue.clone().detach() 165 | nce_loss = ( 166 | moco_loss_func(q1, k2, queue[1], self.temperature) 167 | + moco_loss_func(q2, k1, queue[0], self.temperature) 168 | ) / 2 169 | 170 | # ------- update queue ------- 171 | keys = torch.stack((gather(k1), gather(k2))) 172 | self._dequeue_and_enqueue(keys) 173 | 174 | self.log("train_nce_loss", nce_loss, on_epoch=True, sync_dist=True) 175 | 176 | out.update({"loss": out["loss"] + nce_loss, "z": [q1, q2]}) 177 | return out 178 | -------------------------------------------------------------------------------- /kaizen/args/utils.py: 
N_CLASSES_PER_DATASET = {
    "cifar10": 10,
    "cifar100": 100,
    "stl10": 10,
    "imagenet": 1000,
    "imagenet100": 100,
    "domainnet": 345,
}

# Per-crop augmentation settings, in the order they appear in the transform kwargs.
_AUG_PARAMS = (
    "brightness",
    "contrast",
    "saturation",
    "hue",
    "gaussian_prob",
    "solarization_prob",
    "min_scale",
    "size",
)

# Same vocabulary that distutils.util.strtobool accepted (compared case-insensitively).
_TRUE_STRINGS = frozenset({"y", "yes", "t", "true", "on", "1"})
_FALSE_STRINGS = frozenset({"n", "no", "f", "false", "off", "0"})


def strtobool(v):
    """Converts a textual truth value to a bool.

    Implemented locally instead of delegating to ``distutils.util.strtobool``:
    ``import distutils`` alone does not load the ``distutils.util`` submodule, and
    ``distutils`` no longer exists in Python 3.12+.

    Args:
        v: one of y/yes/t/true/on/1 or n/no/f/false/off/0 (any casing); an actual
            bool is returned unchanged.

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if the string is not a recognized truth value.
    """

    if isinstance(v, bool):
        return v

    value = v.lower()
    if value in _TRUE_STRINGS:
        return True
    if value in _FALSE_STRINGS:
        return False
    raise ValueError("invalid truth value %r" % (value,))


def _parse_gpus(gpus):
    """Turns an int or a comma separated string into a list of gpu ids; other values pass through."""

    if isinstance(gpus, int):
        return [gpus]
    if isinstance(gpus, str):
        # empty fragments (e.g. trailing comma) are ignored
        return [int(gpu) for gpu in gpus.split(",") if gpu]
    return gpus


def _extra_optimizer_args(optimizer):
    """Returns the extra optimizer kwargs: momentum 0.9 for sgd, nothing otherwise."""

    return {"momentum": 0.9} if optimizer == "sgd" else {}


def _as_kwargs_list(transform_kwargs):
    """Returns the transform kwargs as a list of dicts, wrapping a single dict."""

    if isinstance(transform_kwargs, dict):
        return [transform_kwargs]
    return transform_kwargs


def additional_setup_pretrain(args: Namespace):
    """Provides final setup for pretraining to non-user given parameters by changing args.

    Parses arguments to extract the number of classes of a dataset, create
    transformations kwargs, correctly parse gpus, identify if a cifar dataset
    is being used and adjust the lr.

    Args:
        args (Namespace): object that needs to contain, at least:
            - dataset: dataset name.
            - data_dir, train_dir: only read when the dataset is not a known one.
            - brightness, contrast, saturation, hue, gaussian_prob, solarization_prob,
              min_scale, size: augmentation settings, each a sequence holding either one
              value or one value per unique augmentation pipeline.
            - num_crops: number of crops.
            - multicrop: flag to use multicrop.
            - dali: flag to use dali.
            - optimizer: optimizer name being used.
            - gpus: gpus to use (int, comma separated string or list).
            - lr: learning rate.
            - batch_size: batch size, used to scale the lr.
            - mean, std: only read when the dataset is "custom".
    """

    args.transform_kwargs = {}

    if args.dataset in N_CLASSES_PER_DATASET:
        args.num_classes = N_CLASSES_PER_DATASET[args.dataset]
    else:
        # hack to maintain the current pipeline
        # even if the custom dataset doesn't have any labels
        dir_path = os.path.join(args.data_dir, args.train_dir)
        with os.scandir(dir_path) as entries:
            # is_dir must be *called*; the bare attribute is always truthy and
            # would count regular files as classes too
            n_class_dirs = sum(1 for entry in entries if entry.is_dir())
        args.num_classes = max(1, n_class_dirs)

    unique_augs = max(len(getattr(args, name)) for name in _AUG_PARAMS)
    assert unique_augs == args.num_crops or unique_augs == 1

    # assert that either all unique augmentation pipelines have a unique
    # parameter or that a single parameter is replicated to all pipelines
    for name in _AUG_PARAMS:
        values = getattr(args, name)
        n = len(values)
        assert n == unique_augs or n == 1

        if n == 1:
            setattr(args, name, values * unique_augs)

    args.unique_augs = unique_augs

    if unique_augs > 1:
        # one kwargs dict per augmentation pipeline
        per_param_values = [getattr(args, name) for name in _AUG_PARAMS]
        args.transform_kwargs = [
            dict(zip(_AUG_PARAMS, per_pipeline)) for per_pipeline in zip(*per_param_values)
        ]
    elif not args.multicrop:
        args.transform_kwargs = {name: getattr(args, name)[0] for name in _AUG_PARAMS}
    else:
        # multicrop does not take min_scale / size from here
        args.transform_kwargs = {
            name: getattr(args, name)[0]
            for name in _AUG_PARAMS
            if name not in ("min_scale", "size")
        }

    # add support for custom mean and std
    if args.dataset == "custom":
        for kwargs in _as_kwargs_list(args.transform_kwargs):
            kwargs["mean"] = args.mean
            kwargs["std"] = args.std

    if args.dataset in ["cifar10", "cifar100", "stl10"]:
        for kwargs in _as_kwargs_list(args.transform_kwargs):
            # the multicrop kwargs carry no "size" key, so a plain `del` would raise
            kwargs.pop("size", None)

    args.cifar = args.dataset in ["cifar10", "cifar100"]

    if args.dali:
        assert args.dataset in ["imagenet100", "imagenet", "domainnet", "custom"]

    args.extra_optimizer_args = _extra_optimizer_args(args.optimizer)
    args.gpus = _parse_gpus(args.gpus)

    # adjust lr according to batch size
    args.lr = args.lr * args.batch_size * len(args.gpus) / 256


def additional_setup_linear(args: Namespace):
    """Provides final setup for linear evaluation to non-user given parameters by changing args.

    Parses arguments to extract the number of classes of a dataset, correctly parse gpus and
    identify if a cifar dataset is being used. The lr is left untouched.

    Args:
        args: Namespace object that needs to contain, at least:
            - dataset: dataset name (must be a key of N_CLASSES_PER_DATASET).
            - dali: flag to use dali.
            - optimizer: optimizer name being used.
            - gpus: gpus to use (int, comma separated string or list).
    """

    assert args.dataset in N_CLASSES_PER_DATASET
    args.num_classes = N_CLASSES_PER_DATASET[args.dataset]

    args.cifar = args.dataset in ["cifar10", "cifar100"]

    if args.dali:
        assert args.dataset in ["imagenet100", "imagenet", "domainnet"]

    args.extra_optimizer_args = _extra_optimizer_args(args.optimizer)
    args.gpus = _parse_gpus(args.gpus)