├── setup_commands.sh
├── imgs
├── method.png
├── results_final.png
└── results_continual.png
├── kaizen
├── __init__.py
├── args
│ ├── __init__.py
│ ├── continual.py
│ ├── dataset.py
│ └── utils.py
├── losses
│ ├── wmse.py
│ ├── byol.py
│ ├── simsiam.py
│ ├── nnclr.py
│ ├── __init__.py
│ ├── moco.py
│ ├── swav.py
│ ├── deepclusterv2.py
│ ├── ressl.py
│ ├── barlow.py
│ ├── vicreg.py
│ ├── dino.py
│ └── simclr.py
├── utils
│ ├── __init__.py
│ ├── gather_layer.py
│ ├── whitening.py
│ ├── metrics.py
│ ├── trunc_normal.py
│ ├── datasets.py
│ ├── sinkhorn_knopp.py
│ ├── momentum.py
│ ├── lars.py
│ ├── checkpointer.py
│ ├── kmeans.py
│ ├── auto_umap.py
│ └── knn.py
├── methods
│ ├── multi_layer_classifier.py
│ ├── __init__.py
│ ├── barlow_twins.py
│ ├── vicreg.py
│ ├── wmse.py
│ ├── simsiam.py
│ ├── byol.py
│ ├── ressl.py
│ ├── simclr.py
│ └── mocov2plus.py
├── distillers
│ ├── __init__.py
│ ├── base.py
│ ├── predictive_mse.py
│ ├── predictive.py
│ ├── contrastive.py
│ ├── decorrelative.py
│ └── knowledge.py
└── distiller_factories
│ ├── __init__.py
│ ├── base.py
│ ├── predictive_mse.py
│ ├── predictive.py
│ ├── soft_label.py
│ ├── contrastive.py
│ ├── decorrelative.py
│ └── knowledge.py
├── requirements.txt
├── LICENSE.md
├── bash_files
├── byol_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh
├── mocov2plus_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh
├── simclr_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh
└── vicreg_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh
├── job_launcher.py
├── .gitignore
├── main_continual.py
├── main_linear.py
├── README.md
├── evaluate_folder.py
└── main_eval.py
/setup_commands.sh:
--------------------------------------------------------------------------------
1 | export LC_ALL=C.UTF-8
2 | export LANG=C.UTF-8
3 |
--------------------------------------------------------------------------------
/imgs/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nokia-Bell-Labs/kaizen/HEAD/imgs/method.png
--------------------------------------------------------------------------------
/imgs/results_final.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nokia-Bell-Labs/kaizen/HEAD/imgs/results_final.png
--------------------------------------------------------------------------------
/imgs/results_continual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Nokia-Bell-Labs/kaizen/HEAD/imgs/results_continual.png
--------------------------------------------------------------------------------
/kaizen/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen import args, losses, methods, utils
2 |
3 | __all__ = ["args", "losses", "methods", "utils"]
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.10
2 | torchvision
3 | pytorch-lightning
4 | lightning-bolts
5 | wandb
6 | scikit-learn
7 | einops
8 | torchaudio
9 |
--------------------------------------------------------------------------------
/kaizen/args/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen.args import dataset, setup, utils, continual
2 |
3 | __all__ = ["dataset", "setup", "utils", "continual"]
4 |
--------------------------------------------------------------------------------
/kaizen/losses/wmse.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def wmse_loss_func(z1: torch.Tensor, z2: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Return the W-MSE loss between two batches of whitened features.

    The value is ``2 - 2 * mean(cos(z1, z2))``, with the cosine taken along the
    last dimension.

    Args:
        z1 (torch.Tensor): NxD Tensor containing whitened features from view 1.
        z2 (torch.Tensor): NxD Tensor containing whitened features from view 2.
        simplified (bool): if True, compute the cosine with
            ``F.cosine_similarity`` and detach ``z2``; otherwise normalize both
            inputs explicitly, without detaching either. Defaults to True.

    Returns:
        torch.Tensor: W-MSE loss.
    """
    if not simplified:
        unit_z1 = F.normalize(z1, dim=-1)
        unit_z2 = F.normalize(z2, dim=-1)
        cosine = (unit_z1 * unit_z2).sum(dim=-1)
        return 2 - 2 * cosine.mean()

    # Fast path: z2 is treated as a constant target here.
    cosine = F.cosine_similarity(z1, z2.detach(), dim=-1)
    return 2 - 2 * cosine.mean()
24 |
--------------------------------------------------------------------------------
/kaizen/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen.utils import (
2 | checkpointer,
3 | classification_dataloader,
4 | datasets,
5 | gather_layer,
6 | knn,
7 | lars,
8 | metrics,
9 | momentum,
10 | pretrain_dataloader,
11 | sinkhorn_knopp,
12 | )
13 |
14 | __all__ = [
15 | "classification_dataloader",
16 | "pretrain_dataloader",
17 | "checkpointer",
18 | "datasets",
19 | "gather_layer",
20 | "knn",
21 | "lars",
22 | "metrics",
23 | "momentum",
24 | "sinkhorn_knopp",
25 | ]
26 |
27 | try:
28 | from kaizen.utils import dali_dataloader # noqa: F401
29 | except ImportError:
30 | pass
31 | else:
32 | __all__.append("dali_dataloader")
33 |
34 | try:
35 | from kaizen.utils import auto_umap # noqa: F401
36 | except ImportError:
37 | pass
38 | else:
39 | __all__.append("auto_umap")
40 |
--------------------------------------------------------------------------------
/kaizen/losses/byol.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Return BYOL's loss between predictions ``p`` and momentum projections ``z``.

    The value is ``2 - 2 * mean(cos(p, z))``; ``z`` is detached in both code
    paths, so it acts as a fixed target.

    Args:
        p (torch.Tensor): NxD Tensor containing predicted features from view 1.
        z (torch.Tensor): NxD Tensor containing projected momentum features from view 2.
        simplified (bool): if True, use ``F.cosine_similarity``; otherwise
            normalize both inputs and take the row-wise dot product.
            Defaults to True.

    Returns:
        torch.Tensor: BYOL's loss.
    """
    target = z.detach()
    if not simplified:
        unit_p = F.normalize(p, dim=-1)
        # Detaching before or after normalization yields the same values and gradients.
        unit_target = F.normalize(z, dim=-1).detach()
        return 2 - 2 * (unit_p * unit_target).sum(dim=1).mean()

    return 2 - 2 * F.cosine_similarity(p, target, dim=-1).mean()
24 |
--------------------------------------------------------------------------------
/kaizen/losses/simsiam.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def simsiam_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
    """Return SimSiam's loss: the negative mean cosine similarity of ``p`` and ``z``.

    ``z`` is detached in both code paths, so no gradient flows into it.

    Args:
        p (torch.Tensor): Tensor containing predicted features from view 1.
        z (torch.Tensor): Tensor containing projected features from view 2.
        simplified (bool): if True, use ``F.cosine_similarity``; otherwise
            normalize both inputs and take the row-wise dot product.
            Defaults to True.

    Returns:
        torch.Tensor: SimSiam loss.
    """
    if not simplified:
        unit_p = F.normalize(p, dim=-1)
        unit_z = F.normalize(z, dim=-1)
        similarity = (unit_p * unit_z.detach()).sum(dim=1)
        return -similarity.mean()

    similarity = F.cosine_similarity(p, z.detach(), dim=-1)
    return -similarity.mean()
25 |
--------------------------------------------------------------------------------
/kaizen/methods/multi_layer_classifier.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
class MultiLayerClassifier(nn.Module):
    """MLP classifier: optional ReLU hidden layers followed by a linear output layer.

    Hidden layers are registered as attributes ``fc_0``, ``fc_1``, ... and the
    final layer as ``fc_output``.
    """

    def __init__(self, input_dim, num_classes, layer_units=None):
        """Build the classifier.

        Args:
            input_dim (int): size of the input features.
            num_classes (int): number of output logits.
            layer_units (Sequence[int], optional): width of each hidden layer,
                in order. ``None`` (the default) or an empty sequence yields a
                single linear layer.
        """
        super().__init__()
        # Use a fresh list per instance: a shared mutable default would be
        # stored on ``self`` and leak modifications between classifiers.
        if layer_units is None:
            layer_units = []

        self.input_dim = input_dim
        self.num_classes = num_classes
        self.layer_units = layer_units

        in_dim = self.input_dim
        # Plain list on purpose: each layer is already registered through
        # ``setattr`` below, so state_dict keys stay ``fc_<i>.*``.
        self.all_layers = []
        for i, num_units in enumerate(layer_units):
            layer = nn.Linear(in_dim, num_units)
            setattr(self, f"fc_{i}", layer)
            self.all_layers.append(layer)
            in_dim = num_units
        self.fc_output = nn.Linear(in_dim, self.num_classes)

    def forward(self, x):
        """Apply every hidden layer with ReLU, then return the output logits."""
        for layer in self.all_layers:
            x = F.relu(layer(x))
        return self.fc_output(x)
26 |
27 |
28 |
--------------------------------------------------------------------------------
/kaizen/losses/nnclr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def nnclr_loss_func(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
    """Return NNCLR's contrastive loss between nearest neighbors and predictions.

    Both inputs are L2-normalized, their pairwise similarities are scaled by
    ``temperature``, and row ``i`` of ``nn`` is matched against row ``i`` of ``p``
    with a cross-entropy loss.

    Args:
        nn (torch.Tensor): NxD Tensor containing nearest neighbors' features from view 1.
        p (torch.Tensor): NxD Tensor containing predicted features from view 2.
        temperature (float, optional): temperature of the softmax in the contrastive loss.
            Defaults to 0.1.

    Returns:
        torch.Tensor: NNCLR loss.
    """
    neighbors = F.normalize(nn, dim=-1)
    predictions = F.normalize(p, dim=-1)

    similarity = neighbors @ predictions.T / temperature

    # The positive for row i is column i.
    targets = torch.arange(predictions.size(0), device=predictions.device)
    return F.cross_entropy(similarity, targets)
29 |
--------------------------------------------------------------------------------
/kaizen/distillers/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen.distillers.base import base_distill_wrapper
2 | from kaizen.distillers.contrastive import contrastive_distill_wrapper
3 | from kaizen.distillers.decorrelative import decorrelative_distill_wrapper
4 | from kaizen.distillers.knowledge import knowledge_distill_wrapper
5 | from kaizen.distillers.predictive import predictive_distill_wrapper
6 | from kaizen.distillers.predictive_mse import predictive_mse_distill_wrapper
7 |
8 |
# Public API of the package. Every entry must be a name bound in this module,
# otherwise ``from kaizen.distillers import *`` raises AttributeError.
__all__ = [
    "base_distill_wrapper",
    "contrastive_distill_wrapper",
    "decorrelative_distill_wrapper",
    "knowledge_distill_wrapper",
    "predictive_distill_wrapper",
    "predictive_mse_distill_wrapper",
]
17 |
18 | DISTILLERS = {
19 | "base": base_distill_wrapper,
20 | "contrastive": contrastive_distill_wrapper,
21 | "decorrelative": decorrelative_distill_wrapper,
22 | "knowledge": knowledge_distill_wrapper,
23 | "predictive": predictive_distill_wrapper,
24 | "predictive_mse": predictive_mse_distill_wrapper,
25 | }
26 |
--------------------------------------------------------------------------------
/kaizen/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen.losses.barlow import barlow_loss_func
2 | from kaizen.losses.byol import byol_loss_func
3 | from kaizen.losses.deepclusterv2 import deepclusterv2_loss_func
4 | from kaizen.losses.dino import DINOLoss
5 | from kaizen.losses.moco import moco_loss_func
6 | from kaizen.losses.nnclr import nnclr_loss_func
7 | from kaizen.losses.ressl import ressl_loss_func
8 | from kaizen.losses.simclr import manual_simclr_loss_func, simclr_loss_func, simclr_distill_loss_func
9 | from kaizen.losses.simsiam import simsiam_loss_func
10 | from kaizen.losses.swav import swav_loss_func
11 | from kaizen.losses.vicreg import vicreg_loss_func
12 | from kaizen.losses.wmse import wmse_loss_func
13 |
14 | __all__ = [
15 | "barlow_loss_func",
16 | "byol_loss_func",
17 | "deepclusterv2_loss_func",
18 | "DINOLoss",
19 | "moco_loss_func",
20 | "nnclr_loss_func",
21 | "ressl_loss_func",
22 | "simclr_loss_func",
23 | "manual_simclr_loss_func",
24 | "simclr_distill_loss_func",
25 | "simsiam_loss_func",
26 | "swav_loss_func",
27 | "vicreg_loss_func",
28 | "wmse_loss_func",
29 | ]
30 |
--------------------------------------------------------------------------------
/kaizen/utils/gather_layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 |
class GatherLayer(torch.autograd.Function):
    """Gathers tensors from all processes, supporting backward propagation."""

    @staticmethod
    def forward(ctx, input):
        """Return a tuple with ``input`` as seen by every process.

        When torch.distributed is unavailable or not initialized, the tuple
        holds only the local ``input``.
        """
        ctx.save_for_backward(input)
        if not (dist.is_available() and dist.is_initialized()):
            return (input,)

        gathered = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, input)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        """Return the gradient of the slot that belongs to the local process."""
        (input,) = ctx.saved_tensors
        if not (dist.is_available() and dist.is_initialized()):
            return grads[0]

        # Only the gradient at this rank's position is propagated to ``input``.
        local_grad = torch.zeros_like(input)
        local_grad[:] = grads[dist.get_rank()]
        return local_grad
27 |
28 |
def gather(X, dim=0):
    """Gather ``X`` from all processes and concatenate the pieces along ``dim``.

    Gradients propagate through the gather via ``GatherLayer``.
    """
    pieces = GatherLayer.apply(X)
    return torch.cat(pieces, dim=dim)
32 |
--------------------------------------------------------------------------------
/kaizen/losses/moco.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def moco_loss_func(
    query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1
) -> torch.Tensor:
    """Return MoCo's contrastive loss for a batch of queries, keys and a queue.

    Each query is scored against its own key (the positive, placed in column 0)
    and against every queue entry (the negatives); the result is the
    cross-entropy towards column 0.

    Args:
        query (torch.Tensor): NxD Tensor containing the queries from view 1.
        key (torch.Tensor): NxD Tensor containing the keys from view 2.
        queue (torch.Tensor): DxK Tensor holding K negative samples as columns.
        temperature (float, optional): temperature of the softmax in the contrastive
            loss. Defaults to 0.1.

    Returns:
        torch.Tensor: MoCo loss.
    """
    positive_logits = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
    negative_logits = torch.einsum("nc,ck->nk", [query, queue])

    logits = torch.cat([positive_logits, negative_logits], dim=1) / temperature

    # The positive always sits in column 0.
    positive_index = torch.zeros(query.size(0), dtype=torch.long, device=query.device)
    return F.cross_entropy(logits, positive_index)
28 |
--------------------------------------------------------------------------------
/kaizen/losses/swav.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
def swav_loss_func(
    preds: List[torch.Tensor], assignments: List[torch.Tensor], temperature: float = 0.1
) -> torch.Tensor:
    """Return SwAV's swapped-prediction loss averaged over all ordered view pairs.

    For every pair of distinct views ``(v1, v2)``, the assignments of ``v1`` are
    used as targets for the temperature-scaled log-softmax of the predictions
    of ``v2``.

    Args:
        preds (List[torch.Tensor]): one NxC Tensor of prediction scores per view.
        assignments (List[torch.Tensor]): one NxC Tensor of cluster assignments per view.
        temperature (float): softmax temperature for the loss. Defaults to 0.1.

    Returns:
        torch.Tensor: SwAV loss.
    """
    num_views = len(preds)
    pair_losses = []
    for target_view in range(num_views):
        for pred_view in range(num_views):
            if pred_view == target_view:
                continue
            codes = assignments[target_view]
            log_probs = torch.log_softmax(preds[pred_view] / temperature, dim=1)
            pair_losses.append(-torch.mean(torch.sum(codes * log_probs, dim=1)))
    return sum(pair_losses) / len(pair_losses)
31 |
--------------------------------------------------------------------------------
/kaizen/losses/deepclusterv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def deepclusterv2_loss_func(
    outputs: torch.Tensor, assignments: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """Return DeepClusterV2's loss averaged over the prototype layers.

    For each prototype layer, the logits of all views are flattened into rows,
    scaled by ``temperature`` and compared by cross-entropy with that layer's
    assignments repeated once per view. Targets equal to ``-1`` are ignored.

    Args:
        outputs (torch.Tensor): tensor of size PxVxNxC where P is the number of prototype
            layers and V is the number of views.
        assignments (torch.Tensor): assignments generated using k-means, indexed by
            prototype layer; each entry is used as class-index targets
            (presumably of size PxN -- confirm against the caller).
        temperature (float, optional): softmax temperature for the loss. Defaults to 0.1.

    Returns:
        torch.Tensor: DeepClusterV2 loss.
    """
    num_heads = outputs.size(0)
    num_views = outputs.size(1)
    num_prototypes = outputs.size(-1)

    total = 0
    for head in range(num_heads):
        head_logits = outputs[head].view(-1, num_prototypes) / temperature
        head_targets = assignments[head].repeat(num_views)
        head_targets = head_targets.to(outputs.device, non_blocking=True)
        total += F.cross_entropy(head_logits, head_targets, ignore_index=-1)
    return total / num_heads
27 |
--------------------------------------------------------------------------------
/kaizen/args/continual.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from .utils import strtobool
3 | DISTILLER_LIBRARIES = ["default", "factory"]
4 |
def continual_args(parser: ArgumentParser):
    """Adds continual learning arguments to a parser.

    Registers the task-splitting, distillation and replay options on ``parser``
    in place.

    Args:
        parser (ArgumentParser): parser to add continual learning args to.
    """
    # Task setup.
    parser.add_argument("--num_tasks", default=2, type=int)
    parser.add_argument("--task_idx", required=True, type=int)

    split_strategies = ["class", "data", "domain"]
    parser.add_argument(
        "--split_strategy", type=str, choices=split_strategies, required=True
    )

    # Distillation.
    parser.add_argument("--distiller", default=None, type=str)
    parser.add_argument("--distiller_classifier", default=None, type=str)
    parser.add_argument(
        "--distiller_library",
        type=str,
        choices=DISTILLER_LIBRARIES,
        default=DISTILLER_LIBRARIES[0],
    )

    # Memory bank / replay.
    parser.add_argument("--replay", default=False, type=strtobool)
    parser.add_argument("--replay_proportion", default=1.0, type=float)
    parser.add_argument("--replay_memory_bank_size", default=None, type=int)
    parser.add_argument("--replay_batch_size", default=64, type=int)
28 |
--------------------------------------------------------------------------------
/kaizen/losses/ressl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def ressl_loss_func(
    q: torch.Tensor,
    k: torch.Tensor,
    queue: torch.Tensor,
    temperature_q: float = 0.1,
    temperature_k: float = 0.04,
) -> torch.Tensor:
    """Return ReSSL's loss for a batch of queries, keys and a queue of past elements.

    Both ``q`` and ``k`` are compared with every queue entry. The softmax of the
    key similarities (detached) is the target distribution for the log-softmax
    of the query similarities.

    Args:
        q (torch.Tensor): NxD Tensor containing the queries from view 1.
        k (torch.Tensor): NxD Tensor containing the keys from view 2.
        queue (torch.Tensor): KxD Tensor holding K past elements as rows.
        temperature_q (float, optional): temperature of the softmax for the query.
            Defaults to 0.1.
        temperature_k (float, optional): temperature of the softmax for the key.
            Defaults to 0.04.

    Returns:
        torch.Tensor: ReSSL loss.
    """
    query_logits = torch.einsum("nc,kc->nk", [q, queue])
    key_logits = torch.einsum("nc,kc->nk", [k, queue])

    # No gradient flows through the key branch.
    target_probs = F.softmax(key_logits.detach() / temperature_k, dim=1)
    query_log_probs = F.log_softmax(query_logits / temperature_q, dim=1)

    per_sample = torch.sum(target_probs * query_log_probs, dim=1)
    return -per_sample.mean()
38 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen.distiller_factories.base import base_frozen_model_factory
2 | from kaizen.distiller_factories.contrastive import contrastive_distill_factory
3 | from kaizen.distiller_factories.decorrelative import decorrelative_distill_factory
4 | from kaizen.distiller_factories.knowledge import knowledge_distill_factory
5 | from kaizen.distiller_factories.predictive import predictive_distill_factory
6 | from kaizen.distiller_factories.predictive_mse import predictive_mse_distill_factory
7 | from kaizen.distiller_factories.soft_label import soft_label_distill_factory
8 |
9 |
# Public API of the package. Every entry must be a name bound in this module,
# otherwise ``from kaizen.distiller_factories import *`` raises AttributeError.
__all__ = [
    "base_frozen_model_factory",
    "contrastive_distill_factory",
    "decorrelative_distill_factory",
    "knowledge_distill_factory",
    "predictive_distill_factory",
    "predictive_mse_distill_factory",
    "soft_label_distill_factory",
]
20 |
21 | DISTILLER_FACTORIES = {
22 | "base": base_frozen_model_factory,
23 | "contrastive": contrastive_distill_factory,
24 | "decorrelative": decorrelative_distill_factory,
25 | "knowledge": knowledge_distill_factory,
26 | "predictive": predictive_distill_factory,
27 | "predictive_mse": predictive_mse_distill_factory,
28 | "soft_label": soft_label_distill_factory
29 | }
30 |
--------------------------------------------------------------------------------
/kaizen/losses/barlow.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import torch.distributed as dist
4 |
5 |
def barlow_loss_func(
    z1: torch.Tensor, z2: torch.Tensor, lamb: float = 5e-3, scale_loss: float = 0.025
) -> torch.Tensor:
    """Return Barlow Twins' loss for two batches of projected features.

    Both inputs go through a freshly created, non-affine ``BatchNorm1d``, their
    cross-correlation matrix is computed (and averaged across processes when
    torch.distributed is initialized), and its squared distance to the identity
    is summed with off-diagonal terms weighted by ``lamb``.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        lamb (float, optional): off-diagonal scaling factor for the cross-covariance matrix.
            Defaults to 5e-3.
        scale_loss (float, optional): final scaling factor of the loss. Defaults to 0.025.

    Returns:
        torch.Tensor: Barlow Twins' loss.
    """
    batch_size, feat_dim = z1.size()

    # to match the original code
    batch_norm = torch.nn.BatchNorm1d(feat_dim, affine=False).to(z1.device)
    z1_normed = batch_norm(z1)
    z2_normed = batch_norm(z2)

    cross_corr = torch.einsum("bi, bj -> ij", z1_normed, z2_normed) / batch_size

    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(cross_corr)
        cross_corr /= dist.get_world_size()

    identity = torch.eye(feat_dim, device=cross_corr.device)
    squared_diff = (cross_corr - identity).pow(2)
    # Down-weight every off-diagonal entry.
    off_diagonal = ~identity.bool()
    squared_diff[off_diagonal] *= lamb
    return scale_loss * squared_diff.sum()
42 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 Nokia
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without modification, are permitted (subject to the limitations in the disclaimer below) provided that the following conditions are met:
5 |
6 | * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7 | * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8 | * Neither the name of [Owner Organization] nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
9 |
10 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/kaizen/methods/__init__.py:
--------------------------------------------------------------------------------
1 | from kaizen.methods.barlow_twins import BarlowTwins
2 | from kaizen.methods.base import BaseModel
3 | from kaizen.methods.byol import BYOL
4 | from kaizen.methods.deepclusterv2 import DeepClusterV2
5 | from kaizen.methods.dino import DINO
6 | from kaizen.methods.linear import LinearModel
7 | from kaizen.methods.mocov2plus import MoCoV2Plus
8 | from kaizen.methods.nnclr import NNCLR
9 | from kaizen.methods.ressl import ReSSL
10 | from kaizen.methods.simclr import SimCLR
11 | from kaizen.methods.simsiam import SimSiam
12 | from kaizen.methods.swav import SwAV
13 | from kaizen.methods.vicreg import VICReg
14 | from kaizen.methods.wmse import WMSE
15 | from kaizen.methods.full_model import FullModel
16 |
17 | METHODS = {
18 | # base classes
19 | "base": BaseModel,
20 | "linear": LinearModel,
21 | "full_model": FullModel,
22 | # methods
23 | "barlow_twins": BarlowTwins,
24 | "byol": BYOL,
25 | "deepclusterv2": DeepClusterV2,
26 | "dino": DINO,
27 | "mocov2plus": MoCoV2Plus,
28 | "nnclr": NNCLR,
29 | "ressl": ReSSL,
30 | "simclr": SimCLR,
31 | "simsiam": SimSiam,
32 | "swav": SwAV,
33 | "vicreg": VICReg,
34 | "wmse": WMSE,
35 | }
36 | __all__ = [
37 | "BarlowTwins",
38 | "BYOL",
39 | "BaseModel",
40 | "DeepClusterV2",
41 | "DINO",
42 | "LinearModel",
43 | "FullModel",
44 | "MoCoV2Plus",
45 | "NNCLR",
46 | "ReSSL",
47 | "SimCLR",
48 | "SimSiam",
49 | "SwAV",
50 | "VICReg",
51 | "WMSE",
52 | ]
53 |
54 | try:
55 | from kaizen.methods import dali # noqa: F401
56 | except ImportError:
57 | pass
58 | else:
59 | __all__.append("dali")
60 |
--------------------------------------------------------------------------------
/kaizen/distillers/base.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Any, Sequence
3 | import torch
4 |
5 |
def base_distill_wrapper(Method=object):
    """Build a subclass of ``Method`` that also keeps frozen copies of its networks.

    The returned class copies ``self.encoder`` and ``self.projector`` into
    ``frozen_encoder`` / ``frozen_projector`` and adds their outputs to the
    dictionary returned by ``training_step``.

    Args:
        Method: base class to extend. It must provide ``encoder``, ``projector``,
            ``current_task_idx``, ``on_train_start`` and a ``training_step`` that
            returns a dict (presumably one of the kaizen methods -- confirm).

    Returns:
        type: the ``BaseDistillWrapper`` class derived from ``Method``.
    """

    class BaseDistillWrapper(Method):
        def __init__(self, **kwargs) -> None:
            """Initialize ``Method`` and snapshot its encoder and projector.

            ``kwargs`` must contain ``output_dim``.
            """
            super().__init__(**kwargs)

            self.output_dim = kwargs["output_dim"]

            self.frozen_encoder = deepcopy(self.encoder)
            self.frozen_projector = deepcopy(self.projector)

        def on_train_start(self):
            """Refresh the frozen copies (for tasks after the first) and freeze them."""
            super().on_train_start()

            # After the first task, re-snapshot the current encoder and projector.
            if self.current_task_idx > 0:

                self.frozen_encoder = deepcopy(self.encoder)
                self.frozen_projector = deepcopy(self.projector)

            # The frozen copies never receive gradients, whichever task is running.
            for pg in self.frozen_encoder.parameters():
                pg.requires_grad = False
            for pg in self.frozen_projector.parameters():
                pg.requires_grad = False

        @torch.no_grad()
        def frozen_forward(self, X):
            """Return ``(features, projections)`` of ``X`` from the frozen networks."""
            feats = self.frozen_encoder(X)
            return feats, self.frozen_projector(feats)

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Run the parent training step and attach the frozen outputs of both views.

            Adds the keys ``"frozen_feats"`` and ``"frozen_z"`` (each a two-element
            list, one entry per view) to the parent's output dict.
            """
            # The current task's entry unpacks into three items; the middle one is the two views.
            _, (X1, X2), _ = batch[f"task{self.current_task_idx}"]

            out = super().training_step(batch, batch_idx)

            frozen_feats1, frozen_z1 = self.frozen_forward(X1)
            frozen_feats2, frozen_z2 = self.frozen_forward(X2)

            out.update(
                {"frozen_feats": [frozen_feats1, frozen_feats2], "frozen_z": [frozen_z1, frozen_z2]}
            )
            return out

    return BaseDistillWrapper
48 |
--------------------------------------------------------------------------------
/bash_files/byol_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh:
--------------------------------------------------------------------------------
# Launch main_continual.py for BYOL on CIFAR-100 (class split, 5 tasks, starting at
# task 0) with the predictive distiller, a soft-label distilled classifier and replay.
# DATA_DIR must be set; the expansion below aborts with a clear message otherwise
# and is quoted so that paths containing spaces are passed as a single argument.
python3 main_continual.py \
    --dataset cifar100 \
    --encoder resnet18 \
    --data_dir "${DATA_DIR:?DATA_DIR must be set}" \
    --split_strategy class \
    --max_epochs 500 \
    --num_tasks 5 \
    --task_idx 0 \
    --gpus 0 \
    --precision 16 \
    --optimizer sgd \
    --lars \
    --grad_clip_lars \
    --eta_lars 0.02 \
    --exclude_bias_n_norm \
    --scheduler warmup_cosine \
    --lr 1.0 \
    --classifier_lr 0.1 \
    --weight_decay 1e-5 \
    --batch_size 256 \
    --num_workers 2 \
    --brightness 0.4 \
    --contrast 0.4 \
    --saturation 0.2 \
    --hue 0.1 \
    --gaussian_prob 0.0 0.0 \
    --solarization_prob 0.0 0.2 \
    --name cifar100-byol-predictive-distill-classifier-l1000-soft-label-replay-0.01-b32 \
    --project ever-learn-2 \
    --entity your_entity \
    --offline \
    --wandb \
    --save_checkpoint \
    --output_dim 256 \
    --proj_hidden_dim 4096 \
    --pred_hidden_dim 4096 \
    --base_tau_momentum 0.99 \
    --final_tau_momentum 1.0 \
    --momentum_classifier \
    --disable_knn_eval \
    --online_eval_classifier_lr 0.1 \
    --classifier_training True \
    --classifier_stop_gradient True \
    --classifier_layers 1000 \
    --distiller_library factory \
    --method byol \
    --distiller predictive \
    --distiller_classifier soft_label \
    --classifier_distill_lamb 2.0 \
    --classifier_distill_no_predictior True \
    --replay True \
    --replay_proportion 0.01 \
    --replay_batch_size 32 \
    --online_evaluation True \
    --online_evaluation_training_data_source seen_tasks
56 |
--------------------------------------------------------------------------------
/kaizen/utils/whitening.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.cuda.amp import custom_fwd
4 | from torch.nn.functional import conv2d
5 |
6 |
class Whitening2d(nn.Module):
    def __init__(self, output_dim: int, eps: float = 0.0):
        """Layer that computes hard whitening for W-MSE using the Cholesky decomposition.

        Args:
            output_dim (int): number of dimension of projected features.
            eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults
                to 0.0.
        """

        super(Whitening2d, self).__init__()
        self.output_dim = output_dim
        self.eps = eps

    # Inputs are cast to float32 before the forward pass runs under autocast.
    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs whitening using the Cholesky decomposition.

        Args:
            x (torch.Tensor): a batch or slice of projected features.

        Returns:
            torch.Tensor: a batch or slice of whitened features.
        """

        # Add two trailing singleton dims so the whitening matrix can be applied
        # with conv2d (assumes x is 2-D, N x output_dim -- TODO confirm).
        x = x.unsqueeze(2).unsqueeze(3)
        # Per-feature mean over the batch, shaped (1, output_dim, 1, 1) for broadcasting.
        m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1)
        xn = x - m

        # Put features on rows, then form the covariance with an (n - 1) denominator.
        T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1)
        f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)

        # Identity matrix converted to the same tensor type as the covariance.
        eye = torch.eye(self.output_dim).type(f_cov.type())

        # Blend the covariance with the identity; eps = 0 leaves it unchanged.
        f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

        # Solve L @ X = I for the lower-triangular Cholesky factor L, i.e. X = L^-1.
        inv_sqrt = torch.triangular_solve(eye, torch.cholesky(f_cov_shrinked), upper=False)[0]
        # Reshape into 1x1 convolution weights.
        inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1)

        decorrelated = conv2d(xn, inv_sqrt)

        # Drop the two singleton dims added at the top.
        return decorrelated.squeeze(2).squeeze(2)
49 |
--------------------------------------------------------------------------------
/bash_files/mocov2plus_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh:
--------------------------------------------------------------------------------
# Launch main_continual.py for MoCo v2+ on CIFAR-100 (class split, 5 tasks, starting at
# task 0) with the contrastive distiller, a soft-label distilled classifier and replay.
# DATA_DIR must be set; the expansion below aborts with a clear message otherwise
# and is quoted so that paths containing spaces are passed as a single argument.
# NOTE(review): --max_epochs is 2 here -- confirm this is the intended training length.
python3 main_continual.py \
    --dataset cifar100 \
    --encoder resnet18 \
    --data_dir "${DATA_DIR:?DATA_DIR must be set}" \
    --split_strategy class \
    --max_epochs 2 \
    --num_tasks 5 \
    --task_idx 0 \
    --gpus 0 \
    --precision 16 \
    --optimizer sgd \
    --lars \
    --grad_clip_lars \
    --eta_lars 0.02 \
    --exclude_bias_n_norm \
    --scheduler warmup_cosine \
    --lr 1.0 \
    --classifier_lr 0.1 \
    --weight_decay 1e-5 \
    --batch_size 256 \
    --num_workers 2 \
    --brightness 0.4 \
    --contrast 0.4 \
    --saturation 0.2 \
    --hue 0.1 \
    --gaussian_prob 0.0 0.0 \
    --solarization_prob 0.0 0.2 \
    --name cifar100-mocov2plus-contrastive-distill-classifier-l1000-soft-label-replay-0.01-b32 \
    --project ever-learn \
    --entity your_entity \
    --offline \
    --wandb \
    --save_checkpoint \
    --output_dim 256 \
    --proj_hidden_dim 2048 \
    --queue_size 65536 \
    --temperature 0.2 \
    --base_tau_momentum 0.99 \
    --final_tau_momentum 0.999 \
    --momentum_classifier \
    --disable_knn_eval \
    --online_eval_classifier_lr 0.1 \
    --classifier_training True \
    --classifier_stop_gradient True \
    --classifier_layers 1000 \
    --distiller_library factory \
    --method mocov2plus \
    --distiller contrastive \
    --classifier_distill_lamb 2.0 \
    --distiller_classifier soft_label \
    --classifier_distill_no_predictior True \
    --replay True \
    --replay_proportion 0.01 \
    --replay_batch_size 32 \
    --online_evaluation True \
    --online_evaluation_training_data_source seen_tasks
--------------------------------------------------------------------------------
/kaizen/utils/metrics.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Sequence
2 |
3 | import torch
4 |
5 |
def accuracy_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
) -> Sequence[torch.Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k.

    Values of k larger than the number of classes are treated as "all classes", so a
    classifier with fewer outputs than ``max(top_k)`` no longer makes ``topk`` fail.

    Args:
        outputs (torch.Tensor): output of a classifier (logits or probabilities), with the
            classes along dimension 1.
        targets (torch.Tensor): ground truth labels.
        top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
            Defaults to (1, 5).

    Returns:
        Sequence[torch.Tensor]: one single-element tensor per requested k, holding the
            accuracy in percent.
    """

    with torch.no_grad():
        # torch.topk raises if k exceeds the size of the class dimension, so cap it
        maxk = min(max(top_k), outputs.size(1))
        batch_size = targets.size(0)

        _, pred = outputs.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(targets.view(1, -1).expand_as(pred))

        res = []
        for k in top_k:
            # rows 0..k-1 hold the k best predictions of every sample
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
34 |
35 |
def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
    """Averages a metric over validation steps, weighting every step by its batch size.

    Args:
        outputs (List[Dict]): list of dicts containing the outputs of a validation step.
        key (str): key of the metric of interest.
        batch_size_key (str): key of batch size values.

    Returns:
        float: batch-size-weighted mean of the metric, with its leading dimension squeezed.
    """

    weighted_total = sum(step[batch_size_key] * step[key] for step in outputs)
    sample_count = sum(step[batch_size_key] for step in outputs)
    mean_value = weighted_total / sample_count
    return mean_value.squeeze(0)
55 |
--------------------------------------------------------------------------------
/kaizen/utils/trunc_normal.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 |
4 | import torch
5 |
6 |
7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
8 | """Copy & paste from PyTorch official master until it's in a few official releases - RW
9 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
10 | """
11 |
12 | def norm_cdf(x):
13 | """Computes standard normal cumulative distribution function"""
14 |
15 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
16 |
17 | if (mean < a - 2 * std) or (mean > b + 2 * std):
18 | warnings.warn(
19 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
20 | "The distribution of values may be incorrect.",
21 | stacklevel=2,
22 | )
23 |
24 | with torch.no_grad():
25 | # Values are generated by using a truncated uniform distribution and
26 | # then using the inverse CDF for the normal distribution.
27 | # Get upper and lower cdf values
28 | l = norm_cdf((a - mean) / std)
29 | u = norm_cdf((b - mean) / std)
30 |
31 | # Uniformly fill tensor with values from [l, u], then translate to
32 | # [2l-1, 2u-1].
33 | tensor.uniform_(2 * l - 1, 2 * u - 1)
34 |
35 | # Use inverse cdf transform for normal distribution to get truncated
36 | # standard normal
37 | tensor.erfinv_()
38 |
39 | # Transform to proper mean, std
40 | tensor.mul_(std * math.sqrt(2.0))
41 | tensor.add_(mean)
42 |
43 | # Clamp to ensure it's in the proper range
44 | tensor.clamp_(min=a, max=b)
45 | return tensor
46 |
47 |
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Fills ``tensor`` in place with a truncated normal distribution and returns it.

    Copy & paste from PyTorch official master until it's in a few official releases - RW
    Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to fill.
        mean: mean of the normal distribution. Defaults to 0.0.
        std: standard deviation of the normal distribution. Defaults to 1.0.
        a: lower truncation bound. Defaults to -2.0.
        b: upper truncation bound. Defaults to 2.0.
    """

    return _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
54 |
--------------------------------------------------------------------------------
/bash_files/simclr_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh:
--------------------------------------------------------------------------------
1 | python3 main_continual.py \
2 | --dataset cifar100 \
3 | --encoder resnet18 \
4 | --data_dir $DATA_DIR \
5 | --split_strategy class \
6 | --max_epochs 500 \
7 | --num_tasks 5 \
8 | --task_idx 1 \
9 | --gpus 0 \
10 | --precision 16 \
11 | --optimizer sgd \
12 | --lars \
13 | --grad_clip_lars \
14 | --eta_lars 0.02 \
15 | --exclude_bias_n_norm \
16 | --scheduler warmup_cosine \
17 | --lr 1.0 \
18 | --classifier_lr 0.1 \
19 | --weight_decay 1e-5 \
20 | --batch_size 256 \
21 | --num_workers 3 \
22 | --brightness 0.4 \
23 | --contrast 0.4 \
24 | --saturation 0.2 \
25 | --hue 0.1 \
26 | --gaussian_prob 0.0 0.0 \
27 | --solarization_prob 0.0 0.2 \
28 | --name cifar100-simclr-contrastive-distill-classifier-l1000-soft-label-replay-0.01-b32 \
29 | --project ever-learn-2 \
30 | --entity your_entity \
31 | --offline \
32 | --wandb \
33 | --save_checkpoint \
34 | --output_dim 256 \
35 | --proj_hidden_dim 4096 \
36 | --pred_hidden_dim 4096 \
37 | --base_tau_momentum 0.99 \
38 | --final_tau_momentum 1.0 \
39 | --momentum_classifier \
40 | --disable_knn_eval \
41 | --online_eval_classifier_lr 0.1 \
42 | --classifier_training True \
43 | --classifier_stop_gradient True \
44 | --classifier_layers 1000 \
45 | --distiller_library factory \
46 | --method simclr \
47 | --distiller contrastive \
48 | --classifier_distill_lamb 2.0 \
49 | --distiller_classifier soft_label \
50 | --classifier_distill_no_predictior True \
51 | --replay True \
52 | --replay_proportion 0.01 \
53 | --replay_batch_size 32 \
54 | --online_evaluation True \
55 | --online_evaluation_training_data_source seen_tasks \
56 | --pretrained_model experiments/2022_08_31_14_49_21-cifar100-simclr-contrastive-distill-classifier-l1000-lamb5-soft-label-replay-0.1/task0-1o6mzete/cifar100-simclr-contrastive-distill-classifier-l1000-lamb5-soft-label-replay-0.1-task0-ep=499-1o6mzete.ckpt
57 |
--------------------------------------------------------------------------------
/bash_files/vicreg_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh:
--------------------------------------------------------------------------------
1 | python3 main_continual.py \
2 | --dataset cifar100 \
3 | --encoder resnet18 \
4 | --data_dir $DATA_DIR \
5 | --split_strategy class \
6 | --max_epochs 500 \
7 | --num_tasks 5 \
8 | --task_idx 4 \
9 | --gpus 0 \
10 | --precision 16 \
11 | --optimizer sgd \
12 | --lars \
13 | --grad_clip_lars \
14 | --eta_lars 0.02 \
15 | --exclude_bias_n_norm \
16 | --scheduler warmup_cosine \
17 | --lr 0.3 \
18 | --classifier_lr 0.05 \
19 | --weight_decay 1e-4 \
20 | --batch_size 256 \
21 | --num_workers 2 \
22 | --min_scale 0.2 \
23 | --brightness 0.4 \
24 | --contrast 0.4 \
25 | --saturation 0.2 \
26 | --hue 0.1 \
27 | --solarization_prob 0.1 \
28 | --gaussian_prob 0.0 0.0 \
29 | --name cifar100-vicreg-predictive_mse-distill-classifier-l1000-soft-label-replay-0.01-b32 \
30 | --project ever-learn-2 \
31 | --entity your_entity \
32 | --offline \
33 | --wandb \
34 | --save_checkpoint \
35 | --proj_hidden_dim 2048 \
36 | --output_dim 2048 \
37 | --sim_loss_weight 25.0 \
38 | --var_loss_weight 25.0 \
39 | --cov_loss_weight 1.0 \
40 | --disable_knn_eval \
41 | --online_eval_classifier_lr 0.1 \
42 | --classifier_training True \
43 | --classifier_stop_gradient True \
44 | --classifier_layers 1000 \
45 | --distiller_library factory \
46 | --method vicreg \
47 | --distiller predictive_mse \
48 | --classifier_distill_lamb 2.0 \
49 | --distiller_classifier soft_label \
50 | --classifier_distill_no_predictior True \
51 | --replay True \
52 | --replay_proportion 0.01 \
53 | --replay_batch_size 32 \
54 | --online_evaluation True \
55 | --online_evaluation_training_data_source seen_tasks \
56 | --pretrained_model experiments/2023_03_05_04_19_14-cifar100-vicreg-predictive_mse-distill-classifier-l1000-lamb2-soft-label-replay-0.01-b32/task3-6348uzl8/cifar100-vicreg-predictive_mse-distill-classifier-l1000-lamb2-soft-label-replay-0.01-b32-task3-ep=499-6348uzl8.ckpt
57 |
--------------------------------------------------------------------------------
/kaizen/utils/datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data.dataset import Dataset
4 | from PIL import Image
5 |
6 |
class DomainNetDataset(Dataset):
    """Dataset that reads DomainNet-style image lists for one or more domains.

    For every selected domain a text file ``<domain>_<split>.txt`` is read from
    ``image_list_root``; each non-empty line holds an image path (relative to
    ``data_root``) followed by an integer label.
    """

    def __init__(
        self,
        data_root,
        image_list_root,
        domain_names,
        split="train",
        transform=None,
        return_domain=False,
    ):
        """Collects the (path, label) pairs of all requested domains.

        Args:
            data_root: directory the image paths of the list files are relative to.
            image_list_root: directory containing the ``<domain>_<split>.txt`` list files.
            domain_names: a list of domain names, a single domain name given as a string,
                or None to use all six domains.
            split: split name used in the list file names. Defaults to "train".
            transform: optional callable applied to every loaded image.
            return_domain: if True, ``__getitem__`` returns the domain name instead of the
                sample index as first element.
        """

        self.data_root = data_root
        self.transform = transform
        self.return_domain = return_domain

        if domain_names is None:
            domain_names = [
                "clipart",
                "infograph",
                "painting",
                "quickdraw",
                "real",
                "sketch",
            ]
        elif isinstance(domain_names, str):
            # a bare string must be wrapped, otherwise the loop below would iterate over
            # its characters and look for list files such as "r_train.txt"
            domain_names = [domain_names]
        self.domain_names = domain_names

        image_list_paths = [
            os.path.join(image_list_root, d + "_" + split + ".txt") for d in self.domain_names
        ]
        self.imgs = self._make_dataset(image_list_paths)

    def _make_dataset(self, image_list_paths):
        """Parses the list files into a flat list of ``(relative_path, label)`` tuples."""

        images = []
        for image_list_path in image_list_paths:
            # the context manager closes the list file as soon as it has been parsed
            with open(image_list_path) as list_file:
                for line in list_file:
                    fields = line.split()
                    if not fields:
                        # tolerate blank lines (e.g. a trailing newline at the end of the file)
                        continue
                    images.append((fields[0], int(fields[1])))
        return images

    def _rgb_loader(self, path):
        """Opens the image at ``path`` and returns it converted to RGB."""

        with open(path, "rb") as f:
            with Image.open(f) as img:
                return img.convert("RGB")

    def __getitem__(self, index):
        """Returns ``(domain or index, image, target)`` for the sample at ``index``.

        The first element is the domain name when ``return_domain`` is set, otherwise the
        sample index.
        """

        path, target = self.imgs[index]
        img = self._rgb_loader(os.path.join(self.data_root, path))

        if self.transform is not None:
            img = self.transform(img)

        if self.return_domain:
            # the domain is recovered from the path: exactly one domain name must occur in it
            matches = [d for d in self.domain_names if d in path]
            assert len(matches) == 1
            return matches[0], img, target

        return index, img, target

    def __len__(self):
        """Returns the total number of images over all selected domains."""

        return len(self.imgs)
68 |
--------------------------------------------------------------------------------
/kaizen/distillers/predictive_mse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.distillers.base import base_distill_wrapper
7 | from kaizen.losses.vicreg import invariance_loss
8 |
9 |
def predictive_mse_distill_wrapper(Method=object):
    """Creates a distillation wrapper class around ``Method``.

    The returned class extends ``base_distill_wrapper(Method)`` with a small MLP predictor
    and adds a distillation term, computed with ``invariance_loss`` between the predicted
    and the frozen projections, to the training loss.

    Args:
        Method: class to wrap. Defaults to object.

    Returns:
        the ``PredictiveMSEDistillWrapper`` class.
    """

    class PredictiveMSEDistillWrapper(base_distill_wrapper(Method)):
        def __init__(self, distill_lamb: float, distill_proj_hidden_dim, **kwargs):
            """Builds the wrapped method and the distillation predictor.

            Args:
                distill_lamb (float): weight of the distillation loss.
                distill_proj_hidden_dim: hidden size of the distillation predictor.
                **kwargs: forwarded to the parent class; must contain "output_dim".
            """

            super().__init__(**kwargs)

            self.distill_lamb = distill_lamb
            feature_dim = kwargs["output_dim"]

            # two-layer MLP mapping the current projections onto the frozen ones
            predictor_layers = [
                nn.Linear(feature_dim, distill_proj_hidden_dim),
                nn.BatchNorm1d(distill_proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(distill_proj_hidden_dim, feature_dim),
            ]
            self.distill_predictor = nn.Sequential(*predictor_layers)

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the distiller options on ``parent_parser`` and returns it."""

            # NOTE(review): the group title says "contrastive" although this is the
            # predictive MSE distiller - looks like a copy-paste leftover; kept unchanged.
            group = parent_parser.add_argument_group("contrastive_distiller")

            group.add_argument("--distill_lamb", type=float, default=25)
            group.add_argument("--distill_proj_hidden_dim", type=int, default=2048)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            Returns:
                List[dict]: list of learnable parameters.
            """

            predictor_group = {"params": self.distill_predictor.parameters()}
            return super().learnable_params + [predictor_group]

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent's training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch forwarded to the parent's training step.
                batch_idx (int): index of the batch.

            Returns:
                torch.Tensor: parent loss plus ``distill_lamb`` times the distillation loss.
            """

            out = super().training_step(batch, batch_idx)
            online_z1, online_z2 = out["z"]
            target_z1, target_z2 = out["frozen_z"]

            pred1 = self.distill_predictor(online_z1)
            pred2 = self.distill_predictor(online_z2)

            # average of the per-view losses between predictions and frozen projections
            loss_view1 = invariance_loss(pred1, target_z1)
            loss_view2 = invariance_loss(pred2, target_z2)
            distill_loss = (loss_view1 + loss_view2) / 2

            self.log("train_predictive_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            return out["loss"] + self.distill_lamb * distill_loss

    return PredictiveMSEDistillWrapper
64 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/base.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Any, Sequence
3 | import torch
4 |
5 |
def base_frozen_model_factory(MethodClass=object):
    """Creates a subclass of ``MethodClass`` that keeps frozen copies of its networks.

    The frozen encoder, projector and (when ``classifier_training`` is set) classifier are
    deep copies with gradients disabled. Their outputs on the training views are added to
    the dictionary returned by the parent's ``training_step``.

    Args:
        MethodClass: class to extend. Defaults to object.

    Returns:
        the ``BaseFrozenModel`` class.
    """

    def _frozen_clone(module):
        """Returns a deep copy of ``module`` whose parameters do not require gradients."""

        clone = deepcopy(module)
        for param in clone.parameters():
            param.requires_grad = False
        return clone

    class BaseFrozenModel(MethodClass):
        def __init__(self, **kwargs) -> None:
            """Builds the parent method and takes a first frozen snapshot of it.

            Args:
                **kwargs: forwarded to the parent class; must contain "output_dim".
            """

            super().__init__(**kwargs)

            self.output_dim = kwargs["output_dim"]
            self.store_model_frozen_copy()

        def store_model_frozen_copy(self):
            """Replaces the frozen networks with copies of the current ones."""

            self.frozen_encoder = _frozen_clone(self.encoder)
            self.frozen_projector = _frozen_clone(self.projector)

            # the classifier is only snapshotted when it is being trained
            if self.classifier_training:
                self.frozen_classifier = _frozen_clone(self.classifier)
            else:
                self.frozen_classifier = None

        def on_train_start(self):
            """Refreshes the frozen snapshot after the parent's own start-of-training hook."""

            super().on_train_start()
            self.store_model_frozen_copy()

        @torch.no_grad()
        def frozen_forward(self, X):
            """Runs ``X`` through the frozen networks without tracking gradients.

            Returns:
                tuple of (encoder features, projector features, classifier logits); the
                logits are None when no frozen classifier is stored.
            """

            feats_encoder = self.frozen_encoder(X)
            feats_projector = self.frozen_projector(feats_encoder)

            logits_classifier = None
            if self.frozen_classifier is not None:
                logits_classifier = self.frozen_classifier(feats_encoder)

            return feats_encoder, feats_projector, logits_classifier

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Extends the parent's training output with the frozen networks' outputs.

            Args:
                batch (Sequence[Any]): mapping with a "task<current_task_idx>" entry and an
                    optional "replay" entry; the two views are the second-to-last element
                    of each entry.
                batch_idx (int): index of the batch.

            Returns:
                the parent's output dict with "frozen_feats", "frozen_z" and
                "frozen_logits" added (one item per view).
            """

            _, current_views, _ = batch[f"task{self.current_task_idx}"]
            X1, X2 = current_views

            if "replay" in batch:
                # replayed samples are appended after the samples of the current task
                *_, replay_views, _ = batch["replay"]
                X1R, X2R = replay_views
                X1 = torch.cat([X1, X1R])
                X2 = torch.cat([X2, X2R])

            out = super().training_step(batch, batch_idx)

            frozen_view1 = self.frozen_forward(X1)
            frozen_view2 = self.frozen_forward(X2)

            out["frozen_feats"] = [frozen_view1[0], frozen_view2[0]]
            out["frozen_z"] = [frozen_view1[1], frozen_view2[1]]
            out["frozen_logits"] = [frozen_view1[2], frozen_view2[2]]
            return out

    return BaseFrozenModel
63 |
--------------------------------------------------------------------------------
/job_launcher.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import subprocess
4 | import argparse
5 | from datetime import datetime
6 | import inspect
7 | import shutil
8 |
9 |
10 | from main_continual import str_to_dict
11 |
# Command line options of the launcher itself (not of the launched training script).
parser = argparse.ArgumentParser()
parser.add_argument("--script", type=str, required=True)
parser.add_argument("--mode", type=str, default="normal")
parser.add_argument("--experiment_dir", type=str, default=None)
parser.add_argument("--base_experiment_dir", type=str, default="./experiments")
parser.add_argument("--gpu", type=str, default="v100-16g")
parser.add_argument("--num_gpus", type=int, default=2)
parser.add_argument("--hours", type=int, default=20)
parser.add_argument("--requeue", type=int, default=0)

args = parser.parse_args()

# load file: every script line becomes one entry, stripped of whitespace and of the
# trailing line-continuation backslash
if os.path.exists(args.script):
    with open(args.script) as f:
        command = [line.strip().strip("\\").strip() for line in f]
else:
    print(f"{args.script} does not exist.")
    # a missing script is an error, so report it through a non-zero exit status
    sys.exit(1)

# `command` holds whole lines (e.g. "--checkpoint_dir some/dir"), so the flag has to be
# searched inside every entry; comparing against the entries themselves never matches
assert not any(
    "--checkpoint_dir" in line for line in command
), "Please remove the --checkpoint_dir argument, it will be added automatically"

# collect args (the first two tokens, interpreter and script name, are skipped)
command_args = str_to_dict(" ".join(command).split(" ")[2:])

# create experiment directory
if args.experiment_dir is None:
    args.experiment_dir = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
    args.experiment_dir += f"-{command_args['--name']}"
full_experiment_dir = os.path.join(args.base_experiment_dir, args.experiment_dir)
os.makedirs(full_experiment_dir, exist_ok=True)  # Moved to main_continual.py
print(f"Experiment directory: {full_experiment_dir}")
shutil.copy(args.script, full_experiment_dir)
# add experiment directory to the command
command.extend(["--checkpoint_dir", full_experiment_dir])
command = " ".join(command)

print(command)

# run command
if args.mode == "normal":
    # shell=True is kept on purpose: the scripts rely on shell expansion (e.g. $DATA_DIR)
    p = subprocess.Popen(command, shell=True, stdout=sys.stdout, stderr=sys.stdout)
    p.wait()

elif args.mode == "slurm":
    # infer qos
    # NOTE(review): qos is computed but not used below, and stays undefined for
    # hours outside [0, 100] - confirm whether it should be passed to sbatch
    if 0 <= args.hours <= 2:
        qos = "qos_gpu-dev"
    elif args.hours <= 20:
        qos = "qos_gpu-t3"
    elif args.hours <= 100:
        qos = "qos_gpu-t4"

    # write command
    command_path = os.path.join(full_experiment_dir, "command.sh")
    with open(command_path, "w") as f:
        f.write(command)

    # run command
    p = subprocess.Popen(f"sbatch {command_path}", shell=True, stdout=sys.stdout, stderr=sys.stdout)
    p.wait()
75 |
--------------------------------------------------------------------------------
/kaizen/utils/sinkhorn_knopp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 |
4 |
class SinkhornKnopp(torch.nn.Module):
    def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1):
        """Approximates optimal transport using the Sinkhorn-Knopp algorithm.

        The matrix is pushed towards a doubly stochastic one by alternately rescaling its
        rows and its columns.

        Args:
            num_iters (int, optional): number of times to perform row and column normalization.
                Defaults to 3.
            epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05.
            world_size (int, optional): number of nodes for distributed training. Defaults to 1.
        """

        super().__init__()
        self.num_iters = num_iters
        self.epsilon = epsilon
        self.world_size = world_size

    @torch.no_grad()
    def forward(self, Q: torch.Tensor) -> torch.Tensor:
        """Produces assignments using Sinkhorn-Knopp algorithm.

        Exponentiates ``Q / epsilon``, normalizes the whole matrix, then alternates row and
        column normalization ``num_iters`` times. A final rescaling makes the assignment of
        every sample sum to 1.

        Args:
            Q (torch.Tensor): cosine similarities between the features of the
                samples and the prototypes.

        Returns:
            torch.Tensor: assignment of samples to prototypes according to optimal transport.
        """

        # sums that span several processes are only reduced when torch.distributed is running
        distributed = dist.is_available() and dist.is_initialized()

        # after the transpose: one row per prototype, one column per sample
        Q = torch.exp(Q / self.epsilon).t()
        num_prototypes, local_batch = Q.shape
        total_batch = local_batch * self.world_size

        # make the whole matrix sum to 1
        total_mass = torch.sum(Q)
        if distributed:
            dist.all_reduce(total_mass)
        Q /= total_mass

        for _ in range(self.num_iters):
            # rows: total weight per prototype must be 1/K
            prototype_mass = torch.sum(Q, dim=1, keepdim=True)
            if distributed:
                dist.all_reduce(prototype_mass)
            Q /= prototype_mass
            Q /= num_prototypes

            # columns: total weight per sample must be 1/B
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= total_batch

        # undo the 1/B scaling so that every column sums to 1 (Q becomes an assignment)
        Q *= total_batch
        return Q.t()
64 |
--------------------------------------------------------------------------------
/kaizen/utils/momentum.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
@torch.no_grad()
def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module):
    """Initializes the momentum network with the weights of the online network.

    Parameters are paired in iteration order; every momentum parameter receives the value
    of its online counterpart and stops requiring gradients.

    Args:
        online_net (nn.Module): online network (e.g. online encoder, online projection, etc...).
        momentum_net (nn.Module): momentum network (e.g. momentum encoder,
            momentum projection, etc...).
    """

    param_pairs = zip(online_net.parameters(), momentum_net.parameters())
    for source, target in param_pairs:
        target.data.copy_(source.data)
        target.requires_grad = False
22 |
23 |
class MomentumUpdater:
    def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0):
        """Updates momentum parameters using exponential moving average.

        Args:
            base_tau (float, optional): base value of the weight decrease coefficient
                (should be in [0,1]). Defaults to 0.996.
            final_tau (float, optional): final value of the weight decrease coefficient
                (should be in [0,1]). Defaults to 1.0.
        """

        super().__init__()

        # both coefficients must lie in [0, 1] and tau may only grow during training
        assert 0 <= base_tau <= 1
        assert 0 <= final_tau <= 1
        assert base_tau <= final_tau

        self.base_tau = base_tau
        self.cur_tau = base_tau
        self.final_tau = final_tau

    @torch.no_grad()
    def update(self, online_net: nn.Module, momentum_net: nn.Module):
        """Moves every momentum parameter towards its online counterpart.

        Each momentum parameter becomes ``tau * momentum + (1 - tau) * online`` with the
        current value of tau.

        Args:
            online_net (nn.Module): online network (e.g. online encoder, online projection, etc...).
            momentum_net (nn.Module): momentum network (e.g. momentum encoder,
                momentum projection, etc...).
        """

        tau = self.cur_tau
        param_pairs = zip(online_net.parameters(), momentum_net.parameters())
        for online_param, momentum_param in param_pairs:
            momentum_param.data = tau * momentum_param.data + (1 - tau) * online_param.data

    def update_tau(self, cur_step: int, max_steps: int):
        """Computes the next value for the weighting decrease coefficient tau using cosine annealing.

        tau equals ``base_tau`` at step 0 and reaches ``final_tau`` at ``max_steps``.

        Args:
            cur_step (int): number of gradient steps so far.
            max_steps (int): overall number of gradient steps in the whole training.
        """

        tau_span = self.final_tau - self.base_tau
        cosine = math.cos(math.pi * cur_step / max_steps)
        self.cur_tau = self.final_tau - tau_span * (cosine + 1) / 2
69 |
--------------------------------------------------------------------------------
/kaizen/distillers/predictive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.distillers.base import base_distill_wrapper
7 | from kaizen.losses.byol import byol_loss_func
8 |
9 |
def predictive_distill_wrapper(Method=object):
    """Creates a distillation wrapper class around ``Method``.

    The returned class extends ``base_distill_wrapper(Method)`` with a small MLP predictor
    and adds a distillation term, computed with ``byol_loss_func`` between the predicted
    and the frozen projections, to the training loss.

    Args:
        Method: class to wrap. Defaults to object.

    Returns:
        the ``PredictiveDistillWrapper`` class.
    """

    class PredictiveDistillWrapper(base_distill_wrapper(Method)):
        def __init__(self, distill_lamb: float, distill_proj_hidden_dim, **kwargs):
            """Builds the wrapped method and the distillation predictor.

            Args:
                distill_lamb (float): weight of the distillation loss.
                distill_proj_hidden_dim: hidden size of the distillation predictor.
                **kwargs: forwarded to the parent class; must contain "output_dim".
            """

            super().__init__(**kwargs)

            self.distill_lamb = distill_lamb
            feature_dim = kwargs["output_dim"]

            # two-layer MLP mapping the current projections onto the frozen ones
            predictor_layers = [
                nn.Linear(feature_dim, distill_proj_hidden_dim),
                nn.BatchNorm1d(distill_proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(distill_proj_hidden_dim, feature_dim),
            ]
            self.distill_predictor = nn.Sequential(*predictor_layers)

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the distiller options on ``parent_parser`` and returns it."""

            # NOTE(review): the group title says "contrastive" although this is the
            # predictive distiller - looks like a copy-paste leftover; kept unchanged.
            group = parent_parser.add_argument_group("contrastive_distiller")

            group.add_argument("--distill_lamb", type=float, default=1)
            group.add_argument("--distill_proj_hidden_dim", type=int, default=2048)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            The predictor uses the base learning rate when ``distill_lamb >= 1`` and the
            base learning rate divided by ``distill_lamb`` otherwise.

            Returns:
                List[dict]: list of learnable parameters.
            """

            if self.distill_lamb >= 1:
                predictor_lr = self.lr
            else:
                # NOTE(review): raises ZeroDivisionError for distill_lamb == 0
                predictor_lr = self.lr / self.distill_lamb

            predictor_group = {
                "name": "distill_predictor",
                "params": self.distill_predictor.parameters(),
                "lr": predictor_lr,
                "weight_decay": self.weight_decay,
            }
            return super().learnable_params + [predictor_group]

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent's training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch forwarded to the parent's training step.
                batch_idx (int): index of the batch.

            Returns:
                torch.Tensor: parent loss plus ``distill_lamb`` times the distillation loss.
            """

            out = super().training_step(batch, batch_idx)
            online_z1, online_z2 = out["z"]
            target_z1, target_z2 = out["frozen_z"]

            pred1 = self.distill_predictor(online_z1)
            pred2 = self.distill_predictor(online_z2)

            # average of the per-view losses between predictions and frozen projections
            loss_view1 = byol_loss_func(pred1, target_z1)
            loss_view2 = byol_loss_func(pred2, target_z2)
            distill_loss = (loss_view1 + loss_view2) / 2

            self.log("train_predictive_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            return out["loss"] + self.distill_lamb * distill_loss

    return PredictiveDistillWrapper
69 |
--------------------------------------------------------------------------------
/kaizen/args/dataset.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from pathlib import Path
3 | from .utils import strtobool
4 |
def dataset_args(parser: ArgumentParser):
    """Registers the dataset-related command line options on a parser.

    Args:
        parser (ArgumentParser): parser to add dataset args to.
    """

    SUPPORTED_DATASETS = [
        "cifar10",
        "cifar100",
        "stl10",
        "imagenet",
        "imagenet100",
        "domainnet",
        "custom",
    ]

    # (flag, add_argument keyword arguments), registered in this order
    option_specs = [
        ("--dataset", dict(choices=SUPPORTED_DATASETS, type=str, required=True)),
        # dataset path
        ("--data_dir", dict(type=Path, required=True)),
        ("--train_dir", dict(type=Path, default=None)),
        ("--val_dir", dict(type=Path, default=None)),
        # dali (imagenet-100/imagenet/custom only)
        ("--dali", dict(action="store_true")),
        ("--dali_device", dict(type=str, default="gpu")),
        # custom dataset only
        ("--no_labels", dict(action="store_true")),
        ("--semi_supervised", dict(default=None, type=float)),
        # seeds
        ("--split_seed", dict(type=int, default=5)),
        ("--global_seed", dict(type=int, default=5)),
    ]

    for flag, options in option_specs:
        parser.add_argument(flag, **options)
39 |
40 |
def augmentations_args(parser: ArgumentParser):
    """Registers the augmentation-related command line options on a parser.

    Args:
        parser (ArgumentParser): parser to add augmentation args to.
    """

    add = parser.add_argument

    # cropping
    add("--multicrop", action="store_true")
    add("--num_crops", type=int, default=2)
    add("--num_small_crops", type=int, default=0)

    # augmentations: the colour jitter strengths are mandatory, one or more values each
    for jitter in ("brightness", "contrast", "saturation", "hue"):
        add(f"--{jitter}", type=float, required=True, nargs="+")

    # augmentations with a default, also given as one or more values
    optional_float_lists = (
        ("--gaussian_prob", [0.5]),
        ("--solarization_prob", [0.0]),
        ("--min_scale", [0.08]),
    )
    for flag, default in optional_float_lists:
        add(flag, type=float, default=default, nargs="+")

    # for imagenet or custom dataset
    add("--size", type=int, default=[224], nargs="+")

    # for custom dataset
    # NOTE(review): the first std value is 0.228, not the 0.229 usually quoted for
    # ImageNet - kept as in the original; confirm whether this is intentional
    add("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+")
    add("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+")

    # debug
    add("--debug_augmentations", action="store_true")
71 |
--------------------------------------------------------------------------------
/kaizen/distillers/contrastive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.distillers.base import base_distill_wrapper
7 | from kaizen.losses.simclr import simclr_distill_loss_func
8 |
9 |
def contrastive_distill_wrapper(Method=object):
    """Creates a distillation wrapper class around ``Method``.

    The returned class extends ``base_distill_wrapper(Method)`` with a small MLP predictor
    and adds a distillation term, computed with ``simclr_distill_loss_func`` in both
    directions between the predicted and the frozen projections, to the training loss.

    Args:
        Method: class to wrap. Defaults to object.

    Returns:
        the ``ContrastiveDistillWrapper`` class.
    """

    class ContrastiveDistillWrapper(base_distill_wrapper(Method)):
        def __init__(
            self,
            distill_lamb: float,
            distill_proj_hidden_dim: int,
            distill_temperature: float,
            **kwargs
        ):
            """Builds the wrapped method and the distillation predictor.

            Args:
                distill_lamb (float): weight of the distillation loss.
                distill_proj_hidden_dim (int): hidden size of the distillation predictor.
                distill_temperature (float): temperature passed to the distillation loss.
                **kwargs: forwarded to the parent class; must contain "output_dim".
            """

            super().__init__(**kwargs)

            self.distill_lamb = distill_lamb
            self.distill_temperature = distill_temperature
            feature_dim = kwargs["output_dim"]

            # two-layer MLP mapping the current projections onto the frozen ones
            predictor_layers = [
                nn.Linear(feature_dim, distill_proj_hidden_dim),
                nn.BatchNorm1d(distill_proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(distill_proj_hidden_dim, feature_dim),
            ]
            self.distill_predictor = nn.Sequential(*predictor_layers)

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the distiller options on ``parent_parser`` and returns it."""

            group = parent_parser.add_argument_group("contrastive_distiller")

            group.add_argument("--distill_lamb", type=float, default=1)
            group.add_argument("--distill_proj_hidden_dim", type=int, default=2048)
            group.add_argument("--distill_temperature", type=float, default=0.2)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            Returns:
                List[dict]: list of learnable parameters.
            """

            predictor_group = {"params": self.distill_predictor.parameters()}
            return super().learnable_params + [predictor_group]

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent's training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch forwarded to the parent's training step.
                batch_idx (int): index of the batch.

            Returns:
                torch.Tensor: parent loss plus ``distill_lamb`` times the distillation loss.
            """

            out = super().training_step(batch, batch_idx)
            online_z1, online_z2 = out["z"]
            target_z1, target_z2 = out["frozen_z"]

            pred1 = self.distill_predictor(online_z1)
            pred2 = self.distill_predictor(online_z2)

            temperature = self.distill_temperature
            # the loss is evaluated in both directions and averaged
            pred_to_frozen = simclr_distill_loss_func(
                pred1, pred2, target_z1, target_z2, temperature
            )
            frozen_to_pred = simclr_distill_loss_func(
                target_z1, target_z2, pred1, pred2, temperature
            )
            distill_loss = (pred_to_frozen + frozen_to_pred) / 2

            self.log("train_contrastive_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            return out["loss"] + self.distill_lamb * distill_loss

    return ContrastiveDistillWrapper
75 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | experiments/
2 | trained_models/
3 |
4 | # checkpoint
5 | last_checkpoint.txt
6 |
7 | # tensorboard dir
8 | runs/
9 |
10 | # wandb dir
11 | wandb/
12 | wandb*/
13 |
14 | # umap dir
15 | auto_umap/
16 |
17 | # datasets dir
18 | datasets/
19 | # saved models
20 | *.pt
21 | *.pth
22 | *.ckpt
23 | *.tar
24 |
25 | *.png
26 | !imgs/*.png
27 | *.jpg
28 | *.jpeg
29 |
30 | saved_models/
31 | model_storage/
32 | model_storage*/
33 | lightning_logs/
34 |
35 | *.json
36 |
37 | *logs*/
38 |
39 | # Created by https://www.gitignore.io/api/python,visualstudiocode
40 | # Edit at https://www.gitignore.io/?templates=python,visualstudiocode
41 |
42 | ### Python ###
43 | # Byte-compiled / optimized / DLL files
44 | __pycache__/
45 | *.py[cod]
46 | *$py.class
47 |
48 | # C extensions
49 | *.so
50 |
51 | # Distribution / packaging
52 | .Python
53 | build/
54 | develop-eggs/
55 | dist/
56 | downloads/
57 | eggs/
58 | .eggs/
59 | lib/
60 | lib64/
61 | parts/
62 | sdist/
63 | var/
64 | wheels/
65 | pip-wheel-metadata/
66 | share/python-wheels/
67 | *.egg-info/
68 | .installed.cfg
69 | *.egg
70 | MANIFEST
71 |
72 | # PyInstaller
73 | # Usually these files are written by a python script from a template
74 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
75 | *.manifest
76 | *.spec
77 |
78 | # Installer logs
79 | pip-log.txt
80 | pip-delete-this-directory.txt
81 |
82 | # Unit test / coverage reports
83 | htmlcov/
84 | .tox/
85 | .nox/
86 | .coverage
87 | .coverage.*
88 | .cache
89 | nosetests.xml
90 | coverage.xml
91 | *.cover
92 | .hypothesis/
93 | .pytest_cache/
94 |
95 | # Translations
96 | *.mo
97 | *.pot
98 |
99 | # Scrapy stuff:
100 | .scrapy
101 |
102 | # Sphinx documentation
103 | docs/_build/
104 |
105 | # PyBuilder
106 | target/
107 |
108 | # pyenv
109 | .python-version
110 |
111 | # pipenv
112 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
113 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
114 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
115 | # install all needed dependencies.
116 | #Pipfile.lock
117 |
118 | # celery beat schedule file
119 | celerybeat-schedule
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # Mr Developer
132 | .mr.developer.cfg
133 | .project
134 | .pydevproject
135 |
136 | # mkdocs documentation
137 | /site
138 |
139 | # mypy
140 | .mypy_cache/
141 | .dmypy.json
142 | dmypy.json
143 |
144 | # Pyre type checker
145 | .pyre/
146 |
147 | ### VisualStudioCode ###
148 | .vscode/*
149 | !.vscode/settings.json
150 | !.vscode/tasks.json
151 | !.vscode/launch.json
152 | !.vscode/extensions.json
153 |
154 | ### VisualStudioCode Patch ###
155 | # Ignore all local history of files
156 | .history
157 |
158 | # End of https://www.gitignore.io/api/python,visualstudiocode
159 |
--------------------------------------------------------------------------------
/kaizen/losses/vicreg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
def invariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Mean squared error between the projected features of two views.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: invariance loss (mean squared error).
    """

    mse = F.mse_loss(z1, z2, reduction="mean")
    return mse
18 |
19 |
def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Hinge penalty on the per-dimension standard deviation of both views.

    Each feature dimension whose standard deviation across the batch is
    below 1 contributes ``1 - std``; the per-view means are summed.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: variance regularization loss.
    """

    eps = 1e-4  # keeps the square root away from zero variance

    def _hinge_on_std(z: torch.Tensor) -> torch.Tensor:
        std = torch.sqrt(z.var(dim=0) + eps)
        return torch.mean(F.relu(1 - std))

    return _hinge_on_std(z1) + _hinge_on_std(z2)
37 |
38 |
def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    """Penalizes the off-diagonal entries of each view's covariance matrix.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: covariance regularization loss.
    """

    num_samples, num_dims = z1.size()

    # Center each view along the batch dimension.
    centered_1 = z1 - z1.mean(dim=0)
    centered_2 = z2 - z2.mean(dim=0)

    cov_1 = (centered_1.T @ centered_1) / (num_samples - 1)
    cov_2 = (centered_2.T @ centered_2) / (num_samples - 1)

    # Boolean mask selecting everything except the main diagonal.
    off_diagonal = ~torch.eye(num_dims, device=z1.device).bool()

    penalty_1 = cov_1[off_diagonal].pow(2).sum() / num_dims
    penalty_2 = cov_2[off_diagonal].pow(2).sum() / num_dims
    return penalty_1 + penalty_2
61 |
62 |
def vicreg_loss_func(
    z1: torch.Tensor,
    z2: torch.Tensor,
    sim_loss_weight: float = 25.0,
    var_loss_weight: float = 25.0,
    cov_loss_weight: float = 1.0,
) -> torch.Tensor:
    """Weighted sum of the invariance, variance and covariance terms of VICReg.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        sim_loss_weight (float): invariance loss weight.
        var_loss_weight (float): variance loss weight.
        cov_loss_weight (float): covariance loss weight.

    Returns:
        torch.Tensor: VICReg loss.
    """

    weighted_sim = sim_loss_weight * invariance_loss(z1, z2)
    weighted_var = var_loss_weight * variance_loss(z1, z2)
    weighted_cov = cov_loss_weight * covariance_loss(z1, z2)

    return weighted_sim + weighted_var + weighted_cov
90 |
--------------------------------------------------------------------------------
/kaizen/distillers/decorrelative.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.distillers.base import base_distill_wrapper
7 | from kaizen.losses.barlow import barlow_loss_func
8 |
9 |
def decorrelative_distill_wrapper(Method=object):
    """Builds a distillation wrapper class that uses the Barlow loss.

    Args:
        Method: class handed to ``base_distill_wrapper``; the returned class
            derives from the result.

    Returns:
        The ``DecorrelativeDistillWrapper`` class.
    """

    class DecorrelativeDistillWrapper(base_distill_wrapper(Method)):
        def __init__(
            self,
            distill_lamb: float,
            distill_proj_hidden_dim: int,
            distill_barlow_lamb: float,
            distill_scale_loss: float,
            **kwargs
        ):
            """Stores the distillation hyper-parameters and builds the predictor.

            Args:
                distill_lamb (float): weight of the distillation loss in the total loss.
                distill_proj_hidden_dim (int): hidden width of the distill predictor.
                distill_barlow_lamb (float): ``lamb`` passed to ``barlow_loss_func``.
                distill_scale_loss (float): ``scale_loss`` passed to ``barlow_loss_func``.
                **kwargs: forwarded to the parent class; must contain ``output_dim``.
            """

            super().__init__(**kwargs)

            output_dim = kwargs["output_dim"]
            self.distill_lamb = distill_lamb
            self.distill_barlow_lamb = distill_barlow_lamb
            self.distill_scale_loss = distill_scale_loss

            self.distill_predictor = nn.Sequential(
                nn.Linear(output_dim, distill_proj_hidden_dim),
                nn.BatchNorm1d(distill_proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(distill_proj_hidden_dim, output_dim),
            )

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the decorrelative distiller options on ``parent_parser``.

            Args:
                parent_parser (argparse.ArgumentParser): parser to extend.

            Returns:
                argparse.ArgumentParser: the same ``parent_parser`` object.
            """

            # The group title only affects --help output; it previously read
            # "contrastive_distiller", a leftover from the contrastive distiller.
            parser = parent_parser.add_argument_group("decorrelative_distiller")

            parser.add_argument("--distill_lamb", type=float, default=1)
            parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048)
            parser.add_argument("--distill_barlow_lamb", type=float, default=5e-3)
            parser.add_argument("--distill_scale_loss", type=float, default=0.1)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            Returns:
                List[dict]: list of learnable parameters.
            """

            extra_learnable_params = [
                {"params": self.distill_predictor.parameters()},
            ]
            return super().learnable_params + extra_learnable_params

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent training step and adds the decorrelative distill loss.

            Args:
                batch (Sequence[Any]): batch forwarded unchanged to the parent step.
                batch_idx (int): index of the batch.

            Returns:
                torch.Tensor: parent loss plus ``distill_lamb`` times the distill loss.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out["z"]
            frozen_z1, frozen_z2 = out["frozen_z"]

            p1 = self.distill_predictor(z1)
            p2 = self.distill_predictor(z2)

            # Each predicted view is matched against the frozen features of
            # the same view; the two terms are averaged.
            distill_loss = (
                barlow_loss_func(
                    p1,
                    frozen_z1,
                    lamb=self.distill_barlow_lamb,
                    scale_loss=self.distill_scale_loss,
                )
                + barlow_loss_func(
                    p2,
                    frozen_z2,
                    lamb=self.distill_barlow_lamb,
                    scale_loss=self.distill_scale_loss,
                )
            ) / 2

            self.log(
                "train_decorrelative_distill_loss", distill_loss, on_epoch=True, sync_dist=True
            )

            return out["loss"] + self.distill_lamb * distill_loss

    return DecorrelativeDistillWrapper
90 |
--------------------------------------------------------------------------------
/main_continual.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import itertools
3 | import subprocess
4 | import sys
5 | import os
6 | import json
7 |
8 |
def str_to_dict(command):
    """Converts a flat list of command-line tokens into a flag -> value dict.

    A flag is any token starting with ``--``. The resulting mapping holds:

    * ``flag -> value`` when the flag is followed by exactly one value;
    * ``flag -> [value, ...]`` when it is followed by several values;
    * ``flag -> flag`` when it has no value (the next token is another flag
      or the flag is the last token).

    Args:
        command: sequence of string tokens, e.g. ``sys.argv[1:]``.

    Returns:
        dict: parsed arguments, in order of first appearance.
    """

    d = {}
    # Pair every token with its successor. The last token is paired with the
    # flag-like sentinel "--" so that a trailing value-less flag is kept
    # instead of being silently dropped.
    pairs = itertools.zip_longest(command, command[1:], fillvalue="--")
    for part, part_next in pairs:
        if part[:2] == "--":
            # value-less flags map to themselves
            d[part] = part_next if part_next[:2] != "--" else part
        elif part_next[:2] != "--":
            # two consecutive values: grow the most recently added flag's
            # entry into a list and append the extra value
            part_prev = list(d)[-1]
            if not isinstance(d[part_prev], list):
                d[part_prev] = [d[part_prev]]
            d[part_prev].append(part_next)
    return d
24 |
25 |
def dict_to_list(command):
    """Flattens a flag -> value mapping back into a list of tokens.

    Every key is emitted; its value follows unless the value equals the key
    or itself starts with ``--``. List values are emitted as a single list
    element.

    Args:
        command: mapping of flags to values.

    Returns:
        list: flat token list.
    """

    tokens = []
    for flag, value in command.items():
        tokens.append(flag)
        has_no_value = flag == value or value[:2] == "--"
        if not has_no_value:
            tokens.append(value)
    return tokens
33 |
34 |
def run_bash_command(args):
    """Runs ``python3 main_pretrain.py`` with ``args`` through the shell and waits.

    List entries of ``args`` are replaced in place by their space-joined
    string before the command line is assembled.

    Args:
        args: mutable sequence of tokens; entries are strings or lists of strings.
    """

    for idx, token in enumerate(args):
        if not isinstance(token, list):
            continue
        args[idx] = " ".join(token)

    command_line = " ".join(("python3 main_pretrain.py", *args))
    # Block until the child process exits; its exit status is not inspected.
    subprocess.Popen(command_line, shell=True).wait()
43 |
44 |
if __name__ == "__main__":
    # Sequentially launches one pretraining run per task, chaining checkpoints
    # between tasks through <checkpoint_dir>/last_checkpoint.txt.
    args = sys.argv[1:]
    args = str_to_dict(args)
    os.makedirs(args['--checkpoint_dir'], exist_ok=True)

    # parse args from the script
    num_tasks = int(args["--num_tasks"])
    start_task_idx = int(args.get("--task_idx", 0))
    distill_args = {k: v for k, v in args.items() if "distill" in k}

    # delete things that shouldn't be used for task_idx 0
    args.pop("--task_idx", None)
    for k in distill_args:
        args.pop(k, None)

    # check if this experiment is being resumed
    # look for the file last_checkpoint.txt
    last_checkpoint_file = os.path.join(args["--checkpoint_dir"], "last_checkpoint.txt")
    if os.path.exists(last_checkpoint_file):
        # the file is expected to hold exactly two lines:
        # the checkpoint path and the path of the saved args (json)
        with open(last_checkpoint_file) as f:
            ckpt_path, args_path = [line.rstrip() for line in f.readlines()]
        # use a context manager so the args file handle is closed
        with open(args_path) as args_file:
            start_task_idx = json.load(args_file)["task_idx"]
        args["--resume_from_checkpoint"] = ckpt_path

    # main task loop
    for task_idx in range(start_task_idx, num_tasks):
        print(f"\n#### Starting Task {task_idx} ####")

        task_args = copy.deepcopy(args)

        print(task_idx, start_task_idx, task_args)

        # add pretrained model arg
        if task_args.get("--no_continual_learning", "False").startswith("True"):
            # every task starts from scratch
            task_args.pop("--resume_from_checkpoint", None)
            task_args.pop("--pretrained_model", None)

        elif task_idx != 0 and task_idx != start_task_idx:
            # later tasks start from the checkpoint written by the previous one
            task_args.pop("--resume_from_checkpoint", None)
            task_args.pop("--pretrained_model", None)
            if not os.path.exists(last_checkpoint_file):
                # explicit check instead of `assert`, which is stripped under -O
                raise FileNotFoundError(
                    f"{last_checkpoint_file} not found; cannot start task {task_idx}"
                )
            with open(last_checkpoint_file) as f:
                ckpt_path = f.readlines()[0].rstrip()
            task_args["--pretrained_model"] = ckpt_path

        if not task_args.get("--no_distill", "False").startswith("True"):
            # distillation arguments are only passed from the second task on
            if task_idx != 0 and distill_args:
                task_args.update(distill_args)

        task_args["--task_idx"] = str(task_idx)
        task_args = dict_to_list(task_args)

        run_bash_command(task_args)
97 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/predictive_mse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.losses.vicreg import invariance_loss
7 | from kaizen.args.utils import strtobool
8 |
def predictive_mse_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Builds a subclass of ``Method`` that adds an MSE distillation loss.

    All hyper-parameters, CLI flags and the predictor attribute are prefixed
    with ``class_tag`` so several distillers can be stacked on one method.

    Args:
        Method: base class to extend.
        class_tag (str): prefix for attribute, flag and log names.
        distill_current_key (str): key of the training-step output holding
            the two current representations.
        distill_frozen_key (str): key of the training-step output holding
            the two frozen representations.
        output_dim (int): input and output width of the distill predictor.

    Returns:
        The ``PredictiveMSEDistillWrapper`` class.
    """

    lamb_attr = f"{class_tag}_distill_lamb"
    hidden_dim_attr = f"{class_tag}_distill_proj_hidden_dim"
    no_predictor_attr = f"{class_tag}_distill_no_predictior"
    predictor_attr = f"{class_tag}_distill_predictor"

    class PredictiveMSEDistillWrapper(Method):
        def __init__(self, **kwargs):
            # Remove this wrapper's own options before the parent sees kwargs.
            lamb = kwargs.pop(lamb_attr)
            hidden_dim = kwargs.pop(hidden_dim_attr)
            skip_predictor = kwargs.pop(no_predictor_attr)
            super().__init__(**kwargs)

            setattr(self, lamb_attr, lamb)
            setattr(self, hidden_dim_attr, hidden_dim)
            setattr(self, no_predictor_attr, skip_predictor)

            if skip_predictor:
                predictor = nn.Identity()
            else:
                predictor = nn.Sequential(
                    nn.Linear(output_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, output_dim),
                )
            setattr(self, predictor_attr, predictor)

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers this distiller's tagged options on ``parent_parser``."""

            group = parent_parser.add_argument_group(f"predictive_mse_{class_tag}_distiller")

            group.add_argument(f"--{lamb_attr}", type=float, default=25)
            group.add_argument(f"--{hidden_dim_attr}", type=int, default=2048)
            group.add_argument(f"--{no_predictor_attr}", type=strtobool, default=False)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            Returns:
                List[dict]: list of learnable parameters.
            """

            predictor_group = {"params": getattr(self, predictor_attr).parameters()}
            return super().learnable_params + [predictor_group]

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent step and adds the weighted MSE distillation loss.

            Returns:
                The parent's output with ``out["loss"]`` increased in place.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            predictor = getattr(self, predictor_attr)
            p1 = predictor(z1)
            p2 = predictor(z2)

            first_view = invariance_loss(p1, frozen_z1)
            second_view = invariance_loss(p2, frozen_z2)
            distill_loss = (first_view + second_view) / 2

            self.log(f"train_{class_tag}_predictive_mse_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            out["loss"] += getattr(self, lamb_attr) * distill_loss
            return out

    return PredictiveMSEDistillWrapper
81 |
--------------------------------------------------------------------------------
/kaizen/methods/barlow_twins.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | import torch.nn as nn
6 | from kaizen.losses.barlow import barlow_loss_func
7 | from kaizen.methods.base import BaseModel
8 |
9 |
class BarlowTwins(BaseModel):
    def __init__(
        self, proj_hidden_dim: int, output_dim: int, lamb: float, scale_loss: float, **kwargs
    ):
        """Implements Barlow Twins (https://arxiv.org/abs/2103.03230)

        Args:
            proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
            output_dim (int): number of dimensions of projected features.
            lamb (float): off-diagonal scaling factor for the cross-covariance matrix.
            scale_loss (float): scaling factor of the loss.
        """

        super().__init__(**kwargs)

        self.lamb = lamb
        self.scale_loss = scale_loss

        # projector: two hidden Linear-BatchNorm-ReLU stages and a linear output
        projector_layers = [
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Registers the parent's options followed by the Barlow Twins ones.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: parser returned by the parent's registration.
        """

        parent_parser = super(BarlowTwins, BarlowTwins).add_model_specific_args(parent_parser)
        group = parent_parser.add_argument_group("barlow_twins")

        # projector
        group.add_argument("--output_dim", type=int, default=2048)
        group.add_argument("--proj_hidden_dim", type=int, default=2048)

        # parameters
        group.add_argument("--lamb", type=float, default=5e-3)
        group.add_argument("--scale_loss", type=float, default=0.025)
        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters to parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        projector_group = {"params": self.projector.parameters()}
        return super().learnable_params + [projector_group]

    def forward(self, X, *args, **kwargs):
        """Runs the parent forward pass and adds the projection under ``"z"``."""

        out = super().forward(X, *args, **kwargs)
        projected = self.projector(out["feats"])
        return {**out, "z": projected}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for Barlow Twins reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: total loss composed of Barlow loss and classification loss.
        """

        out = super().training_step(batch, batch_idx)

        feats1, feats2 = out["feats"]

        # project both views, first view first
        z1 = self.projector(feats1)
        z2 = self.projector(feats2)

        # ------- barlow twins loss -------
        barlow_loss = barlow_loss_func(z1, z2, lamb=self.lamb, scale_loss=self.scale_loss)

        self.log("train_barlow_loss", barlow_loss, on_epoch=True, sync_dist=True)

        total_loss = out["loss"] + barlow_loss
        out.update({"loss": total_loss, "z": [z1, z2]})
        return out
95 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/predictive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.losses.byol import byol_loss_func
7 | from kaizen.args.utils import strtobool
8 |
def predictive_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Builds a subclass of ``Method`` that adds a BYOL-loss distillation term.

    All hyper-parameters, CLI flags and the predictor attribute are prefixed
    with ``class_tag`` so several distillers can be stacked on one method.

    Args:
        Method: base class to extend.
        class_tag (str): prefix for attribute, flag and log names.
        distill_current_key (str): key of the training-step output holding
            the two current representations.
        distill_frozen_key (str): key of the training-step output holding
            the two frozen representations.
        output_dim (int): input and output width of the distill predictor.

    Returns:
        The ``PredictiveDistillWrapper`` class.
    """

    distill_lamb_name = f"{class_tag}_distill_lamb"
    distill_proj_hidden_dim_name = f"{class_tag}_distill_proj_hidden_dim"
    distill_no_predictior_name = f"{class_tag}_distill_no_predictior"

    distill_predictor_name = f"{class_tag}_distill_predictor"

    class PredictiveDistillWrapper(Method):
        def __init__(self, **kwargs):
            # Remove this wrapper's own options before the parent sees kwargs.
            distill_lamb = kwargs.pop(distill_lamb_name)
            distill_proj_hidden_dim = kwargs.pop(distill_proj_hidden_dim_name)
            distill_no_predictior = kwargs.pop(distill_no_predictior_name)
            super().__init__(**kwargs)

            setattr(self, distill_lamb_name, distill_lamb)
            setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim)
            setattr(self, distill_no_predictior_name, distill_no_predictior)
            if distill_no_predictior:
                setattr(self, distill_predictor_name, nn.Identity())
            else:
                setattr(self, distill_predictor_name, nn.Sequential(
                    nn.Linear(output_dim, distill_proj_hidden_dim),
                    nn.BatchNorm1d(distill_proj_hidden_dim),
                    nn.ReLU(),
                    nn.Linear(distill_proj_hidden_dim, output_dim),
                ))

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers this distiller's tagged options on ``parent_parser``."""

            parser = parent_parser.add_argument_group(f"predictive_{class_tag}_distiller")

            parser.add_argument(f"--{distill_lamb_name}", type=float, default=1)
            parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048)
            parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            The predictor's learning rate is ``self.lr`` divided by the
            distillation weight when that weight lies strictly between 0 and 1,
            and plain ``self.lr`` otherwise. A weight of 0 used to raise
            ZeroDivisionError and a negative weight produced a negative
            learning rate; both now fall back to ``self.lr``.

            Returns:
                List[dict]: list of learnable parameters.
            """

            distill_lamb = getattr(self, distill_lamb_name)
            predictor_lr = self.lr / distill_lamb if 0 < distill_lamb < 1 else self.lr

            extra_learnable_params = [
                {
                    "name": f"{class_tag}_distill_predictor",
                    "params": getattr(self, distill_predictor_name).parameters(),
                    "lr": predictor_lr,
                    "weight_decay": self.weight_decay,
                },
            ]
            return super().learnable_params + extra_learnable_params

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent step and adds the weighted distillation loss.

            Returns:
                The parent's output with ``out["loss"]`` increased in place.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            p1 = getattr(self, distill_predictor_name)(z1)
            p2 = getattr(self, distill_predictor_name)(z2)

            # each predicted view is matched against the frozen features of the same view
            distill_loss = (byol_loss_func(p1, frozen_z1) + byol_loss_func(p2, frozen_z2)) / 2

            self.log(f"train_{class_tag}_predictive_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            out["loss"] += getattr(self, distill_lamb_name) * distill_loss
            return out

    return PredictiveDistillWrapper
85 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/soft_label.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from torch.functional import F
7 | from kaizen.args.utils import strtobool
8 |
def soft_label_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Builds a subclass of ``Method`` that adds a soft-label distillation term.

    The frozen outputs are turned into probabilities with a softmax over
    dim 1 and used as cross-entropy targets for the predicted outputs.
    All hyper-parameters, CLI flags and the predictor attribute are prefixed
    with ``class_tag`` so several distillers can be stacked on one method.

    Args:
        Method: base class to extend.
        class_tag (str): prefix for attribute, flag and log names.
        distill_current_key (str): key of the training-step output holding
            the two current outputs.
        distill_frozen_key (str): key of the training-step output holding
            the two frozen outputs.
        output_dim (int): input and output width of the distill predictor.

    Returns:
        The wrapper class (still named ``PredictiveDistillWrapper``).
    """

    distill_lamb_name = f"{class_tag}_distill_lamb"
    distill_proj_hidden_dim_name = f"{class_tag}_distill_proj_hidden_dim"
    distill_no_predictior_name = f"{class_tag}_distill_no_predictior"

    distill_predictor_name = f"{class_tag}_distill_predictor"

    # NOTE: the class name is kept as-is so that its ``__name__`` is unchanged.
    class PredictiveDistillWrapper(Method):
        def __init__(self, **kwargs):
            # Remove this wrapper's own options before the parent sees kwargs.
            distill_lamb = kwargs.pop(distill_lamb_name)
            distill_proj_hidden_dim = kwargs.pop(distill_proj_hidden_dim_name)
            distill_no_predictior = kwargs.pop(distill_no_predictior_name)
            super().__init__(**kwargs)

            setattr(self, distill_lamb_name, distill_lamb)
            setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim)
            setattr(self, distill_no_predictior_name, distill_no_predictior)
            if distill_no_predictior:
                setattr(self, distill_predictor_name, nn.Identity())
            else:
                setattr(self, distill_predictor_name, nn.Sequential(
                    nn.Linear(output_dim, distill_proj_hidden_dim),
                    nn.BatchNorm1d(distill_proj_hidden_dim),
                    nn.ReLU(),
                    nn.Linear(distill_proj_hidden_dim, output_dim),
                ))

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers this distiller's tagged options on ``parent_parser``."""

            # The group title only affects --help output; it previously read
            # "predictive_...", a leftover from the predictive distiller.
            parser = parent_parser.add_argument_group(f"soft_label_{class_tag}_distiller")

            parser.add_argument(f"--{distill_lamb_name}", type=float, default=1)
            parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048)
            parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=False)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            The predictor's learning rate is ``self.lr`` divided by the
            distillation weight when that weight lies strictly between 0 and 1,
            and plain ``self.lr`` otherwise. A weight of 0 used to raise
            ZeroDivisionError and a negative weight produced a negative
            learning rate; both now fall back to ``self.lr``.

            Returns:
                List[dict]: list of learnable parameters.
            """

            distill_lamb = getattr(self, distill_lamb_name)
            predictor_lr = self.lr / distill_lamb if 0 < distill_lamb < 1 else self.lr

            extra_learnable_params = [
                {
                    "name": f"{class_tag}_distill_predictor",
                    "params": getattr(self, distill_predictor_name).parameters(),
                    "lr": predictor_lr,
                    "weight_decay": self.weight_decay,
                },
            ]
            return super().learnable_params + extra_learnable_params

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent step and adds the weighted soft-label loss.

            Returns:
                The parent's output with ``out["loss"]`` increased in place.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            p1 = getattr(self, distill_predictor_name)(z1)
            p2 = getattr(self, distill_predictor_name)(z2)

            # cross-entropy against the softmax of the frozen outputs, averaged over both views
            distill_loss = (F.cross_entropy(p1, frozen_z1.softmax(dim=1)) + F.cross_entropy(p2, frozen_z2.softmax(dim=1))) / 2

            self.log(f"train_{class_tag}_soft_label_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            out["loss"] += getattr(self, distill_lamb_name) * distill_loss
            return out

    return PredictiveDistillWrapper
85 |
--------------------------------------------------------------------------------
/kaizen/utils/lars.py:
--------------------------------------------------------------------------------
1 | """
2 | References:
3 | - https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py
4 | - https://arxiv.org/pdf/1708.03888.pdf
5 | - https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py
6 | """
7 | import torch
8 | from torch.optim import Optimizer
9 |
10 |
class LARSWrapper:
    """Wraps a torch optimizer and rescales gradients LARS-style before each step."""

    def __init__(
        self,
        optimizer: Optimizer,
        eta: float = 1e-3,
        clip: bool = False,
        eps: float = 1e-8,
        exclude_bias_n_norm: bool = False,
    ):
        """Wrapper that adds LARS scheduling to any optimizer.
        This helps stability with huge batch sizes.

        Args:
            optimizer (Optimizer): torch optimizer.
            eta (float, optional): trust coefficient. Defaults to 1e-3.
            clip (bool, optional): clip gradient values. Defaults to False.
            eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
            exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
                Defaults to False.
        """

        self.optim = optimizer
        self.eta = eta
        self.eps = eps
        self.clip = clip
        self.exclude_bias_n_norm = exclude_bias_n_norm

        # transfer optim methods
        self.state_dict = self.optim.state_dict
        self.load_state_dict = self.optim.load_state_dict
        self.zero_grad = self.optim.zero_grad
        self.add_param_group = self.optim.add_param_group

        self.__setstate__ = self.optim.__setstate__  # type: ignore
        self.__getstate__ = self.optim.__getstate__  # type: ignore
        self.__repr__ = self.optim.__repr__  # type: ignore

    @property
    def defaults(self):
        return self.optim.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optim.defaults = defaults

    @property  # type: ignore
    def __class__(self):
        # report the wrapper as an Optimizer so isinstance checks succeed
        return Optimizer

    @property
    def state(self):
        return self.optim.state

    @state.setter
    def state(self, state):
        self.optim.state = state

    @property
    def param_groups(self):
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    @torch.no_grad()
    def step(self, closure=None):
        """Rescales the gradients, then steps the wrapped optimizer.

        Weight decay is folded into the gradients here, so every group's
        ``weight_decay`` is set to 0 while the wrapped optimizer runs and is
        restored afterwards, even if the wrapped step raises.

        Args:
            closure (optional): forwarded to the wrapped optimizer's ``step``.

        Returns:
            Whatever the wrapped optimizer's ``step`` returns.
        """

        groups = self.optim.param_groups
        weight_decays = [group.get("weight_decay", 0) for group in groups]

        try:
            for group, weight_decay in zip(groups, weight_decays):
                # reset weight decay
                group["weight_decay"] = 0

                # update the parameters
                for p in group["params"]:
                    # 1-D tensors are skipped when exclude_bias_n_norm is set
                    if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
                        self.update_p(p, group, weight_decay)

            # update the optimizer
            loss = self.optim.step(closure=closure)
        finally:
            # return weight decay control to optimizer
            for group, weight_decay in zip(groups, weight_decays):
                group["weight_decay"] = weight_decay

        return loss

    def update_p(self, p, group, weight_decay):
        """Rewrites ``p.grad`` in place with the LARS-scaled, weight-decayed gradient.

        Args:
            p: parameter whose gradient is rescaled.
            group: parameter group of ``p``; its ``"lr"`` is read when clipping.
            weight_decay: weight decay coefficient of the group.
        """

        # calculate new norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        # gradients are left untouched when either norm is zero
        if p_norm != 0 and g_norm != 0:
            # calculate new lr
            new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)

            # clip lr
            if self.clip:
                new_lr = min(new_lr / group["lr"], 1)

            # update params with clipped lr
            p.grad.data += weight_decay * p.data
            p.grad.data *= new_lr
115 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/contrastive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.losses.simclr import simclr_distill_loss_func
7 | from kaizen.args.utils import strtobool
8 |
def contrastive_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Builds a subclass of ``Method`` that adds a contrastive distillation loss.

    All hyper-parameters, CLI flags and the predictor attribute are prefixed
    with ``class_tag`` so several distillers can be stacked on one method.

    Args:
        Method: base class to extend.
        class_tag (str): prefix for attribute, flag and log names.
        distill_current_key (str): key of the training-step output holding
            the two current representations.
        distill_frozen_key (str): key of the training-step output holding
            the two frozen representations.
        output_dim (int): input and output width of the distill predictor.

    Returns:
        The ``ContrastiveDistillWrapper`` class.
    """

    lamb_attr = f"{class_tag}_distill_lamb"
    hidden_dim_attr = f"{class_tag}_distill_proj_hidden_dim"
    temperature_attr = f"{class_tag}_distill_temperature"
    no_predictor_attr = f"{class_tag}_distill_no_predictior"
    predictor_attr = f"{class_tag}_distill_predictor"

    class ContrastiveDistillWrapper(Method):
        def __init__(self, **kwargs):
            # Remove this wrapper's own options before the parent sees kwargs.
            lamb: float = kwargs.pop(lamb_attr)
            hidden_dim: int = kwargs.pop(hidden_dim_attr)
            temperature: float = kwargs.pop(temperature_attr)
            skip_predictor = kwargs.pop(no_predictor_attr)
            super().__init__(**kwargs)

            setattr(self, lamb_attr, lamb)
            setattr(self, hidden_dim_attr, hidden_dim)
            setattr(self, temperature_attr, temperature)
            setattr(self, no_predictor_attr, skip_predictor)

            if skip_predictor:
                predictor = nn.Identity()
            else:
                predictor = nn.Sequential(
                    nn.Linear(output_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, output_dim),
                )
            setattr(self, predictor_attr, predictor)

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers this distiller's tagged options on ``parent_parser``."""

            group = parent_parser.add_argument_group(f"contrastive_{class_tag}_distiller")

            group.add_argument(f"--{lamb_attr}", type=float, default=1)
            group.add_argument(f"--{hidden_dim_attr}", type=int, default=2048)
            group.add_argument(f"--{temperature_attr}", type=float, default=0.2)
            group.add_argument(f"--{no_predictor_attr}", type=strtobool, default=False)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor parameters to the parent's learnable parameters.

            Returns:
                List[dict]: list of learnable parameters.
            """

            predictor_group = {"params": getattr(self, predictor_attr).parameters()}
            return super().learnable_params + [predictor_group]

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent step and adds the weighted contrastive distill loss.

            Returns:
                The parent's output with ``out["loss"]`` increased in place.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            p1 = getattr(self, predictor_attr)(z1)
            p2 = getattr(self, predictor_attr)(z2)

            # The loss is evaluated with both argument orders and averaged.
            first_term = simclr_distill_loss_func(
                p1, p2, frozen_z1, frozen_z2, getattr(self, temperature_attr)
            )
            swapped_term = simclr_distill_loss_func(
                frozen_z1, frozen_z2, p1, p2, getattr(self, temperature_attr)
            )
            distill_loss = (first_term + swapped_term) / 2

            self.log(f"train_{class_tag}_contrastive_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            out["loss"] += getattr(self, lamb_attr) * distill_loss
            return out

    return ContrastiveDistillWrapper
86 |
--------------------------------------------------------------------------------
/main_linear.py:
--------------------------------------------------------------------------------
1 | import os
2 | import types
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning import Trainer, seed_everything
7 | from pytorch_lightning.callbacks import LearningRateMonitor
8 | from pytorch_lightning.loggers import WandbLogger
9 | from pytorch_lightning.plugins import DDPPlugin
10 | from torchvision.models import resnet18, resnet50
11 |
12 | from kaizen.args.setup import parse_args_linear
13 |
14 | try:
15 | from kaizen.methods.dali import ClassificationABC
16 | except ImportError:
17 | _dali_avaliable = False
18 | else:
19 | _dali_avaliable = True
20 | from kaizen.methods.linear import LinearModel
21 | from kaizen.utils.classification_dataloader import prepare_data
22 | from kaizen.utils.checkpointer import Checkpointer
23 |
24 |
def main():
    """Runs linear evaluation on top of a pretrained feature extractor.

    Parses the command line arguments, builds a ResNet backbone, loads the encoder weights
    found in ``args.pretrained_feature_extractor``, wraps the backbone in ``LinearModel``
    (combined with ``ClassificationABC`` when ``args.dali`` is set) and fits it with a
    PyTorch Lightning ``Trainer``.

    Raises:
        ValueError: if ``args.encoder`` is neither ``resnet18`` nor ``resnet50``.
        AssertionError: if the classes cannot be split evenly across tasks, if the checkpoint
            path does not end in ``.ckpt``, ``.pth`` or ``.pt``, or if DALI is requested but
            could not be imported.
    """

    args = parse_args_linear()

    # split classes into tasks
    tasks = None
    if args.split_strategy == "class":
        assert args.num_classes % args.num_tasks == 0
        # a dedicated seed keeps the class split independent from the global seed
        torch.manual_seed(args.split_seed)
        tasks = torch.randperm(args.num_classes).chunk(args.num_tasks)

    seed_everything(args.global_seed)

    if args.encoder == "resnet18":
        backbone = resnet18()
    elif args.encoder == "resnet50":
        backbone = resnet50()
    else:
        raise ValueError("Only [resnet18, resnet50] are currently supported.")

    if args.cifar:
        # replace the stem with a 3x3 convolution and drop the max-pool
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
        backbone.maxpool = nn.Identity()
    # the classification head is owned by the linear model, not by the backbone
    backbone.fc = nn.Identity()

    assert args.pretrained_feature_extractor.endswith((".ckpt", ".pth", ".pt"))
    ckpt_path = args.pretrained_feature_extractor

    # map_location="cpu" keeps the load independent from the device the checkpoint was saved
    # on: it works on CPU-only hosts and avoids every DDP process deserialising onto the
    # same GPU. load_state_dict copies the values into the backbone's own parameters.
    state = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    for k in list(state.keys()):
        if "encoder" in k:
            # strip the "encoder." prefix so the keys match the bare backbone
            state[k.replace("encoder.", "")] = state[k]
        del state[k]
    # strict=False tolerates the keys that do not belong to the backbone
    backbone.load_state_dict(state, strict=False)

    print(f"Loaded {ckpt_path}")

    if args.dali:
        assert _dali_avaliable, "Dali is not currently avaiable, please install it first."
        MethodClass = types.new_class(
            f"Dali{LinearModel.__name__}", (ClassificationABC, LinearModel)
        )
    else:
        MethodClass = LinearModel

    model = MethodClass(backbone, **args.__dict__, tasks=tasks)

    train_loader, val_loader = prepare_data(
        args.dataset,
        data_dir=args.data_dir,
        train_dir=args.train_dir,
        val_dir=args.val_dir,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        semi_supervised=args.semi_supervised,
    )

    callbacks = []

    # wandb logging
    wandb_logger = None
    if args.wandb:
        wandb_logger = WandbLogger(
            name=args.name, project=args.project, entity=args.entity, offline=args.offline
        )
        wandb_logger.watch(model, log="gradients", log_freq=100)
        wandb_logger.log_hyperparams(args)

        # lr logging (only registered together with a logger)
        lr_monitor = LearningRateMonitor(logging_interval="epoch")
        callbacks.append(lr_monitor)

    # save checkpoint on last epoch only
    ckpt = Checkpointer(
        args,
        logdir=os.path.join(args.checkpoint_dir, "linear"),
        frequency=args.checkpoint_frequency,
    )
    callbacks.append(ckpt)

    trainer = Trainer.from_argparse_args(
        args,
        logger=wandb_logger,
        callbacks=callbacks,
        plugins=DDPPlugin(find_unused_parameters=True),
        checkpoint_callback=False,
        terminate_on_nan=True,
        accelerator="ddp",
    )
    if args.dali:
        # no train loader is passed in the DALI case
        trainer.fit(model, val_dataloaders=val_loader)
    else:
        trainer.fit(model, train_loader, val_loader)
120 |
121 |
122 | if __name__ == "__main__":
123 | main()
124 |
--------------------------------------------------------------------------------
/kaizen/losses/dino.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.distributed as dist
5 | import numpy as np
6 |
7 |
class DINOLoss(nn.Module):
    def __init__(
        self,
        num_prototypes: int,
        warmup_teacher_temp: float,
        teacher_temp: float,
        warmup_teacher_temp_epochs: int,
        num_epochs: int,
        student_temp: float = 0.1,
        num_crops: int = 2,
        center_momentum: float = 0.9,
    ):
        """Auxiliary module to compute DINO's loss.

        Args:
            num_prototypes (int): number of prototypes.
            warmup_teacher_temp (float): base temperature for the temperature schedule
                of the teacher.
            teacher_temp (float): final temperature for the teacher.
            warmup_teacher_temp_epochs (int): number of epochs over which the teacher
                temperature is linearly increased from ``warmup_teacher_temp`` to
                ``teacher_temp``. Float values are truncated to an integer.
            num_epochs (int): total number of epochs.
            student_temp (float, optional): temperature for the student. Defaults to 0.1.
            num_crops (int, optional): number of crops/views. Defaults to 2.
            center_momentum (float, optional): momentum for the EMA update of the center of
                mass of the teacher. Defaults to 0.9.
        """

        super().__init__()
        self.epoch = 0
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.num_crops = num_crops
        self.register_buffer("center", torch.zeros(1, num_prototypes))

        # np.linspace / np.ones only accept integer sizes, so epoch counts parsed as
        # floats are converted here; a warm-up longer than the whole training simply
        # leaves no constant phase instead of requesting a negative-sized array.
        warmup_epochs = int(warmup_teacher_temp_epochs)
        constant_epochs = max(int(num_epochs) - warmup_epochs, 0)

        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        schedule = np.concatenate(
            (
                np.linspace(warmup_teacher_temp, teacher_temp, warmup_epochs),
                np.ones(constant_epochs) * teacher_temp,
            )
        )
        if len(schedule) == 0:
            # keep at least one entry so that forward() always has a temperature
            schedule = np.array([teacher_temp], dtype=float)
        self.teacher_temp_schedule = schedule

    def forward(self, student_output: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor:
        """Computes DINO's loss given a batch of logits of the student and a batch of logits of the
        teacher.

        Args:
            student_output (torch.Tensor): NxP Tensor containing student logits for all views.
            teacher_output (torch.Tensor): NxP Tensor containing teacher logits for all views.

        Returns:
            torch.Tensor: DINO loss.
        """

        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.num_crops)

        # teacher centering and sharpening; epochs past the end of the schedule reuse
        # its last temperature instead of raising an IndexError
        schedule = self.teacher_temp_schedule
        temp = schedule[min(self.epoch, len(schedule) - 1)]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        # the teacher output is always split in two, independently of num_crops
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v, student_view in enumerate(student_out):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_view, dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor):
        """Updates the center for DINO's loss using exponential moving average.

        Args:
            teacher_output (torch.Tensor): NxP Tensor containing teacher logits of all views.
        """

        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_available() and dist.is_initialized():
            # sum over all processes, then average over the number of processes
            dist.all_reduce(batch_center)
            batch_center = batch_center / dist.get_world_size()
        batch_center = batch_center / len(teacher_output)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
100 |
--------------------------------------------------------------------------------
/kaizen/methods/vicreg.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, Dict, List, Sequence
3 |
4 | import torch
5 | import torch.nn as nn
6 | from kaizen.losses.vicreg import vicreg_loss_func
7 | from kaizen.methods.base import BaseModel
8 |
9 |
class VICReg(BaseModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        sim_loss_weight: float,
        var_loss_weight: float,
        cov_loss_weight: float,
        **kwargs
    ):
        """VICReg model (https://arxiv.org/abs/2105.04906).

        Args:
            output_dim (int): size of the projector output.
            proj_hidden_dim (int): width of the two hidden layers of the projector.
            sim_loss_weight (float): weight passed to the loss for the invariance term.
            var_loss_weight (float): weight passed to the loss for the variance term.
            cov_loss_weight (float): weight passed to the loss for the covariance term.
        """

        super().__init__(**kwargs)

        self.sim_loss_weight = sim_loss_weight
        self.var_loss_weight = var_loss_weight
        self.cov_loss_weight = cov_loss_weight

        # three-layer MLP projector applied on top of the backbone features
        projector_layers = [
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        ]
        self.projector = nn.Sequential(*projector_layers)

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Adds the VICReg options on top of the parent's options.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the parser returned by the parent's hook.
        """

        parent_parser = super(VICReg, VICReg).add_model_specific_args(parent_parser)
        group = parent_parser.add_argument_group("vicreg")

        # projector sizes
        group.add_argument("--output_dim", type=int, default=2048)
        group.add_argument("--proj_hidden_dim", type=int, default=2048)

        # loss weights
        group.add_argument("--sim_loss_weight", default=25, type=float)
        group.add_argument("--var_loss_weight", default=25, type=float)
        group.add_argument("--cov_loss_weight", default=1.0, type=float)
        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Parent's parameter groups followed by the projector's parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        projector_group = {"params": self.projector.parameters()}
        return super().learnable_params + [projector_group]

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Runs the parent's forward pass and projects the resulting features.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: the parent's outputs plus the projected features under ``"z"``.
        """

        parent_out = super().forward(X, *args, **kwargs)
        projected = self.projector(parent_out["feats"])
        return {**parent_out, "z": projected}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for VICReg built on top of the parent's training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            the parent's output dict, with the VICReg loss added to ``"loss"`` and the two
            projections stored under ``"z"``.
        """

        out = super().training_step(batch, batch_idx)
        feats1, feats2 = out["feats"]

        z1 = self.projector(feats1)
        z2 = self.projector(feats2)

        # ------- vicreg loss -------
        loss_weights = dict(
            sim_loss_weight=self.sim_loss_weight,
            var_loss_weight=self.var_loss_weight,
            cov_loss_weight=self.cov_loss_weight,
        )
        vicreg_loss = vicreg_loss_func(z1, z2, **loss_weights)

        self.log("train_vicreg_loss", vicreg_loss, on_epoch=True, sync_dist=True)

        out.update({"loss": out["loss"] + vicreg_loss, "z": [z1, z2]})
        return out
118 |
--------------------------------------------------------------------------------
/kaizen/distillers/knowledge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from kaizen.distillers.base import base_distill_wrapper
8 |
9 |
def cross_entropy(preds, targets):
    """Cross-entropy between two batches of logits, averaged over the batch.

    Args:
        preds: logits of the predicted distribution; log-softmax is taken over the last dim.
        targets: logits of the target distribution; softmax is taken over the last dim.

    Returns:
        scalar tensor with the mean cross-entropy.
    """

    target_probs = F.softmax(targets, dim=-1)
    pred_log_probs = torch.log_softmax(preds, dim=-1)
    per_sample = torch.sum(target_probs * pred_log_probs, dim=-1)
    return -torch.mean(per_sample)
14 |
15 |
def knowledge_distill_wrapper(Method=object):
    """Creates a subclass of ``base_distill_wrapper(Method)`` with a prototype-based distillation loss.

    Args:
        Method: base class handed to ``base_distill_wrapper``. It is expected to be built with
            ``output_dim`` and ``num_prototypes`` keyword arguments and to expose
            ``self.prototypes``.

    Returns:
        type: the ``KnowledgeDistillWrapper`` class.
    """

    def _copy_and_freeze(source, target):
        # copy parameter values pairwise and stop gradients on the copies
        for target_param, source_param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(source_param.data)
            target_param.requires_grad = False

    class KnowledgeDistillWrapper(base_distill_wrapper(Method)):
        def __init__(
            self,
            distill_lamb: float,
            distill_proj_hidden_dim: int,
            distill_temperature: float,
            **kwargs
        ):
            super().__init__(**kwargs)

            self.distill_lamb = distill_lamb
            self.distill_temperature = distill_temperature
            feature_dim = kwargs["output_dim"]
            prototype_count = kwargs["num_prototypes"]

            # frozen snapshot of the current prototypes
            self.frozen_prototypes = nn.utils.weight_norm(
                nn.Linear(feature_dim, prototype_count, bias=False)
            )
            _copy_and_freeze(self.prototypes, self.frozen_prototypes)

            self.distill_predictor = nn.Sequential(
                nn.Linear(feature_dim, distill_proj_hidden_dim),
                nn.BatchNorm1d(distill_proj_hidden_dim),
                nn.ReLU(),
                nn.Linear(distill_proj_hidden_dim, feature_dim),
            )

            self.distill_prototypes = nn.utils.weight_norm(
                nn.Linear(feature_dim, prototype_count, bias=False)
            )

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the distiller's command line flags in their own argument group.

            Args:
                parent_parser (argparse.ArgumentParser): parser to extend.

            Returns:
                argparse.ArgumentParser: the same parser that was passed in.
            """

            group = parent_parser.add_argument_group("knowledge_distiller")

            group.add_argument("--distill_lamb", type=float, default=1)
            group.add_argument("--distill_proj_hidden_dim", type=int, default=2048)
            group.add_argument("--distill_temperature", type=float, default=0.1)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Appends the distillation predictor and prototypes to the parent's parameter groups.

            Returns:
                List[dict]: list of learnable parameters.
            """

            distill_groups = [
                {"params": self.distill_predictor.parameters()},
                {"params": self.distill_prototypes.parameters()},
            ]
            return super().learnable_params + distill_groups

        def on_train_start(self):
            """Refreshes the frozen prototypes from the current ones for every task after the first."""

            super().on_train_start()

            if self.current_task_idx > 0:
                _copy_and_freeze(self.prototypes, self.frozen_prototypes)

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent's training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch forwarded unchanged to the parent.
                batch_idx (int): index of the batch.

            Returns:
                torch.Tensor: parent loss plus ``distill_lamb`` times the distillation loss.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out["z"]
            frozen_z1, frozen_z2 = out["frozen_z"]
            temperature = self.distill_temperature

            # targets: frozen prototype logits, computed without gradients
            with torch.no_grad():
                frozen_z1 = F.normalize(frozen_z1)
                frozen_z2 = F.normalize(frozen_z2)
                frozen_p1 = self.frozen_prototypes(frozen_z1) / temperature
                frozen_p2 = self.frozen_prototypes(frozen_z2) / temperature

            # predictions: current features through the predictor and its own prototypes
            distill_z1 = F.normalize(self.distill_predictor(z1))
            distill_z2 = F.normalize(self.distill_predictor(z2))
            distill_p1 = self.distill_prototypes(distill_z1) / temperature
            distill_p2 = self.distill_prototypes(distill_z2) / temperature

            loss_view1 = cross_entropy(distill_p1, frozen_p1)
            loss_view2 = cross_entropy(distill_p2, frozen_p2)
            distill_loss = (loss_view1 + loss_view2) / 2

            self.log("train_knowledge_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            return out["loss"] + self.distill_lamb * distill_loss

    return KnowledgeDistillWrapper
113 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Kaizen: Practical Self-Supervised Continual Learning With Continual Fine-Tuning
2 |
3 | Official implementation of the algorithm outlined in
4 | > **[Kaizen: Practical Self-Supervised Continual Learning With Continual Fine-Tuning](https://openaccess.thecvf.com/content/WACV2024/html/Tang_Kaizen_Practical_Self-Supervised_Continual_Learning_With_Continual_Fine-Tuning_WACV_2024_paper.html)**
5 | > Chi Ian Tang, Lorena Qendro, Dimitris Spathis, Fahim Kawsar, Cecilia Mascolo, Akhil Mathur
6 | > **WACV 2024**
7 |
8 | > **Abstract:** *Self-supervised learning (SSL) has shown remarkable performance in computer vision tasks when trained offline. However, in a Continual Learning (CL) scenario where new data is introduced progressively, models still suffer from catastrophic forgetting. Retraining a model from scratch to adapt to newly generated data is time-consuming and inefficient. Previous approaches suggested re-purposing self-supervised objectives with knowledge distillation to mitigate forgetting across tasks, assuming that labels from all tasks are available during fine-tuning. In this paper, we generalize self-supervised continual learning in a practical setting where available labels can be leveraged in any step of the SSL process. With an increasing number of continual tasks, this offers more flexibility in the pre-training and fine-tuning phases. With Kaizen, we introduce a training architecture that is able to mitigate catastrophic forgetting for both the feature extractor and classifier with a carefully designed loss function. By using a set of comprehensive evaluation metrics reflecting different aspects of continual learning, we demonstrated that Kaizen significantly outperforms previous SSL models in competitive vision benchmarks, with up to 16.5% accuracy improvement on split CIFAR-100. Kaizen is able to balance the trade-off between knowledge retention and learning from new data with an end-to-end model, paving the way for practical deployment of continual learning systems.*
9 |
10 | 
11 | 
12 | 
13 |
14 | # Installation
15 | Please run the following command to install the required packages:
16 | ```
17 | pip install -r requirements.txt
18 | ```
19 | In order to work with the [WandB](https://wandb.ai/site) logging library, you may need to run the following:
20 | ```
21 | source setup_commands.sh
22 | ```
23 |
24 | # Commands
25 |
Bash files for launching the experiments are provided in the `bash_files` folder. The `job_launcher.py` script launches an experiment that automatically trains a model continually for the number of tasks specified in the corresponding bash file.
27 |
28 | For example, to launch the MoCoV2+ experiment, run:
29 |
30 | ```
31 | DATA_DIR=/YOUR/DATA/DIR/ CUDA_VISIBLE_DEVICES=0 python job_launcher.py --script bash_files/mocov2plus_cifar_distill_classifier_l1000_soft_label_replay_0.01_b32.sh
32 | ```
33 |
Note that each script uses gpu:0 by default, so setting the `CUDA_VISIBLE_DEVICES` environment variable (to `0` in the example above) controls which GPU is actually used by the script.
35 |
36 | # Copyright, Acknowledgment & Warranty
37 | This repository is provided for research reproducibility only. The authors reserve all rights. (Copyright (c) 2023 Chi Ian Tang, Lorena Qendro, Dimitris Spathis, Fahim Kawsar, Cecilia Mascolo, Akhil Mathur).
38 |
39 | Part of this repository is based on the implementation of [cassle](https://github.com/DonkeyShot21/cassle) ([https://github.com/DonkeyShot21/cassle](https://github.com/DonkeyShot21/cassle)), used under the MIT License, Copyright (c) 2021 Enrico Fini, Victor Turrisi, Xavier Alameda-Pineda, Elisa Ricci, Karteek Alahari, Julien Mairal.
40 |
41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
42 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
43 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
44 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
45 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
46 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
47 | SOFTWARE.
48 |
49 | # Citation
50 |
51 | Please cite our paper if any part of this repository is used in any research work:
52 | ```
53 | @inproceedings{tang2024kaizen,
54 | title={Kaizen: Practical Self-Supervised Continual Learning With Continual Fine-Tuning},
55 | author={Tang, Chi Ian and Qendro, Lorena and Spathis, Dimitris and Kawsar, Fahim and Mascolo, Cecilia and Mathur, Akhil},
56 | booktitle={Proceedings of the IEEE/CVF Winter Conference on Applications of Computer Vision},
57 | pages={2841--2850},
58 | year={2024}
59 | }
60 | ```
61 |
62 | # License
63 |
64 | This project is licensed under the [BSD 3-Clause Clear License](LICENSE.md).
65 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/decorrelative.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | from kaizen.losses.barlow import barlow_loss_func
7 | from kaizen.args.utils import strtobool
8 |
9 |
def decorrelative_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Creates a subclass of ``Method`` that adds a Barlow-loss distillation term to the loss.

    Every hyper-parameter, command line flag and attribute created by the returned class is
    prefixed with ``class_tag`` so that several distillers can be stacked on the same model.

    Args:
        Method: base class to extend. Its ``training_step`` must return a dict containing
            ``"loss"`` and the two entries named by the key arguments below.
        class_tag (str): prefix used for all names created by the wrapper.
        distill_current_key (str): key of the pair of current-model outputs in that dict.
        distill_frozen_key (str): key of the pair of frozen-model outputs in that dict.
        output_dim (int): input and output width of the distillation predictor.

    Returns:
        type: the ``DecorrelativeDistillWrapper`` class.
    """

    lamb_attr = f"{class_tag}_distill_lamb"
    hidden_dim_attr = f"{class_tag}_distill_proj_hidden_dim"
    barlow_lamb_attr = f"{class_tag}_distill_barlow_lamb"
    scale_loss_attr = f"{class_tag}_distill_scale_loss"
    no_predictor_attr = f"{class_tag}_distill_no_predictior"
    predictor_attr = f"{class_tag}_distill_predictor"

    class DecorrelativeDistillWrapper(Method):
        def __init__(self, **kwargs):
            # the distiller's own arguments are removed before the parent sees kwargs
            own_args = {
                name: kwargs.pop(name)
                for name in (
                    lamb_attr,
                    hidden_dim_attr,
                    barlow_lamb_attr,
                    scale_loss_attr,
                    no_predictor_attr,
                )
            }
            super().__init__(**kwargs)

            for name, value in own_args.items():
                setattr(self, name, value)

            if own_args[no_predictor_attr]:
                predictor = nn.Identity()
            else:
                hidden_dim = own_args[hidden_dim_attr]
                predictor = nn.Sequential(
                    nn.Linear(output_dim, hidden_dim),
                    nn.BatchNorm1d(hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, output_dim),
                )
            setattr(self, predictor_attr, predictor)

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the distiller's command line flags in their own argument group.

            Args:
                parent_parser (argparse.ArgumentParser): parser to extend.

            Returns:
                argparse.ArgumentParser: the same parser that was passed in.
            """

            group = parent_parser.add_argument_group(f"decorrelative_{class_tag}_distiller")

            group.add_argument(f"--{lamb_attr}", type=float, default=1)
            group.add_argument(f"--{hidden_dim_attr}", type=int, default=2048)
            group.add_argument(f"--{barlow_lamb_attr}", type=float, default=5e-3)
            group.add_argument(f"--{scale_loss_attr}", type=float, default=0.1)
            group.add_argument(f"--{no_predictor_attr}", type=strtobool, default=False)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Appends the distillation predictor to the parent's parameter groups.

            Returns:
                List[dict]: list of learnable parameters.
            """

            predictor_group = {"params": getattr(self, predictor_attr).parameters()}
            return super().learnable_params + [predictor_group]

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent's training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch forwarded unchanged to the parent.
                batch_idx (int): index of the batch.

            Returns:
                the parent's output dict, with the distillation term added to ``"loss"``.
            """

            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            predictor = getattr(self, predictor_attr)
            p1 = predictor(z1)
            p2 = predictor(z2)

            # each view is compared with the frozen output of the same view
            barlow_lamb = getattr(self, barlow_lamb_attr)
            scale_loss = getattr(self, scale_loss_attr)
            loss_view1 = barlow_loss_func(p1, frozen_z1, lamb=barlow_lamb, scale_loss=scale_loss)
            loss_view2 = barlow_loss_func(p2, frozen_z2, lamb=barlow_lamb, scale_loss=scale_loss)
            distill_loss = (loss_view1 + loss_view2) / 2

            self.log(
                f"train_{class_tag}_decorrelative_distill_loss", distill_loss, on_epoch=True, sync_dist=True
            )

            out["loss"] += getattr(self, lamb_attr) * distill_loss
            return out

    return DecorrelativeDistillWrapper
104 |
--------------------------------------------------------------------------------
/kaizen/methods/wmse.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Sequence
2 |
3 | import torch
4 | import torch.nn as nn
5 | from kaizen.losses.wmse import wmse_loss_func
6 | from kaizen.methods.base import BaseModel
7 | from kaizen.utils.whitening import Whitening2d
8 |
9 |
class WMSE(BaseModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        whitening_iters: int,
        whitening_size: int,
        whitening_eps: float,
        **kwargs
    ):
        """Implements W-MSE (https://arxiv.org/abs/2007.06346)

        Args:
            output_dim (int): number of dimensions of the projected features.
            proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
            whitening_iters (int): number of times to perform whitening.
            whitening_size (int): size of the batch slice for whitening.
            whitening_eps (float): epsilon for numerical stability in whitening.
        """

        super().__init__(**kwargs)

        self.whitening_iters = whitening_iters
        self.whitening_size = whitening_size

        # a whitening slice cannot be larger than the batch it is drawn from
        assert self.whitening_size <= self.batch_size

        # projector
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

        self.whitening = Whitening2d(output_dim, eps=whitening_eps)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Adds the W-MSE options on top of the parent's options.

        Args:
            parent_parser: argument parser to extend.

        Returns:
            the parser returned by the parent's hook, with the W-MSE group added.
        """

        parent_parser = super(WMSE, WMSE).add_model_specific_args(parent_parser)
        # the group title only affects --help output; it used to be mislabelled "simclr"
        parser = parent_parser.add_argument_group("wmse")

        # projector
        parser.add_argument("--output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=1024)

        # wmse
        parser.add_argument("--whitening_iters", type=int, default=1)
        parser.add_argument("--whitening_size", type=int, default=256)
        parser.add_argument("--whitening_eps", type=float, default=0)

        return parent_parser

    @property
    def learnable_params(self) -> List[Dict]:
        """Adds projector parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs the forward pass of the encoder and the projector.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing the outputs of the parent and the projected
                features under the key ``"v"``.
        """

        out = super().forward(X, *args, **kwargs)
        v = self.projector(out["feats"])
        return {**out, "v": v}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """Training step for W-MSE reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images
            batch_idx (int): index of the batch

        Returns:
            torch.Tensor: total loss composed of W-MSE loss and classification loss
        """

        out = super().training_step(batch, batch_idx)
        class_loss = out["loss"]
        feats = out["feats"]

        # projections of all crops stacked along the batch dimension
        v = torch.cat([self.projector(f) for f in feats])

        # ------- wmse loss -------
        # NOTE(review): the indexing below assumes every crop contributes exactly
        # self.batch_size rows and that batch_size is a multiple of whitening_size.
        bs = self.batch_size
        num_losses, wmse_loss = 0, 0
        for _ in range(self.whitening_iters):
            z = torch.empty_like(v)
            # random partition of the batch indexes into slices of whitening_size
            perm = torch.randperm(bs).view(-1, self.whitening_size)
            for idx in perm:
                # the same slice of samples is whitened separately for each crop
                for i in range(self.num_crops):
                    z[idx + i * bs] = self.whitening(v[idx + i * bs]).type_as(z)
            # accumulate the loss over every unordered pair of crops
            for i in range(self.num_crops - 1):
                for j in range(i + 1, self.num_crops):
                    x0 = z[i * bs : (i + 1) * bs]
                    x1 = z[j * bs : (j + 1) * bs]
                    wmse_loss += wmse_loss_func(x0, x1)
                    num_losses += 1
        wmse_loss /= num_losses

        self.log("train_neg_cos_sim", wmse_loss, on_epoch=True, sync_dist=True)

        return wmse_loss + class_loss
126 |
--------------------------------------------------------------------------------
/kaizen/methods/simsiam.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, Dict, List, Sequence
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from kaizen.losses.simsiam import simsiam_loss_func
8 | from kaizen.methods.base import BaseModel
9 |
10 |
class SimSiam(BaseModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        pred_hidden_dim: int,
        **kwargs,
    ):
        """Implements SimSiam (https://arxiv.org/abs/2011.10566).

        Args:
            output_dim (int): number of dimensions of projected features.
            proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
            pred_hidden_dim (int): number of neurons of the hidden layers of the predictor.
            **kwargs: forwarded unchanged to the parent constructor.
        """

        super().__init__(**kwargs)

        # projector: three linear layers, each followed by batch norm; the last
        # batch norm has no affine parameters
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim, bias=False),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, proj_hidden_dim, bias=False),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
            nn.BatchNorm1d(output_dim, affine=False),
        )
        self.projector[6].bias.requires_grad = False  # hack: not use bias as it is followed by BN

        # predictor
        self.predictor = nn.Sequential(
            nn.Linear(output_dim, pred_hidden_dim, bias=False),
            nn.BatchNorm1d(pred_hidden_dim),
            nn.ReLU(),
            nn.Linear(pred_hidden_dim, output_dim),
        )

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Registers the SimSiam options on top of the parent's options.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the parser returned by the parent, with a "simsiam"
                argument group added.
        """

        parent_parser = super(SimSiam, SimSiam).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("simsiam")

        # projector
        parser.add_argument("--output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=2048)

        # predictor
        parser.add_argument("--pred_hidden_dim", type=int, default=512)
        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector and predictor parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        # the predictor group carries "static_lr": True; how that flag is consumed is
        # decided outside this class (presumably by the parent's optimizer setup)
        extra_learnable_params: List[dict] = [
            {"params": self.projector.parameters()},
            {"params": self.predictor.parameters(), "static_lr": True},
        ]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs the forward pass of the encoder, the projector and the predictor.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]:
                a dict containing the outputs of the parent
                and the projected ("z") and predicted ("p") features.
        """

        out = super().forward(X, *args, **kwargs)
        z = self.projector(out["feats"])
        p = self.predictor(z)
        return {**out, "z": z, "p": p}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for SimSiam reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images
            batch_idx (int): index of the batch

        Returns:
            Dict[str, Any]: the parent's output dict, with "loss" increased by the SimSiam
                loss and "z" set to the list [z1, z2] of projected features.
        """

        out = super().training_step(batch, batch_idx)
        # unpacking requires exactly two feature tensors (one per view)
        feats1, feats2 = out["feats"]

        z1 = self.projector(feats1)
        z2 = self.projector(feats2)

        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        # ------- contrastive loss -------
        # symmetric: each view's prediction is compared with the other view's projection
        neg_cos_sim = simsiam_loss_func(p1, z2) / 2 + simsiam_loss_func(p2, z1) / 2

        # calculate std of features (logged only, not part of the loss)
        z1_std = F.normalize(z1, dim=-1).std(dim=0).mean()
        z2_std = F.normalize(z2, dim=-1).std(dim=0).mean()
        z_std = (z1_std + z2_std) / 2

        metrics = {
            "train_neg_cos_sim": neg_cos_sim,
            "train_z_std": z_std,
        }
        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        out.update({"loss": out["loss"] + neg_cos_sim, "z": [z1, z2]})
        return out
131 |
--------------------------------------------------------------------------------
/kaizen/losses/simclr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from typing import Optional
4 |
5 |
def simclr_distill_loss_func(
    p1: torch.Tensor,
    p2: torch.Tensor,
    z1: torch.Tensor,
    z2: torch.Tensor,
    temperature: float = 0.1,
) -> torch.Tensor:
    """Computes a SimCLR-style loss between two sets of features p and z.

    The rows of cat([p1, p2]) are compared with the rows of cat([z1, z2]). Row i of p
    has row i of z as its only positive; every other row of z acts as a negative.
    No entry of the similarity matrix is masked out.

    Args:
        p1 (torch.Tensor): NxD Tensor, first half of the p features.
        p2 (torch.Tensor): NxD Tensor, second half of the p features.
        z1 (torch.Tensor): NxD Tensor, first half of the z features.
        z2 (torch.Tensor): NxD Tensor, second half of the z features.
        temperature (float): temperature factor for the loss. Defaults to 0.1.

    Returns:
        torch.Tensor: scalar loss (mean negative log-probability of the positives).
    """

    p = F.normalize(torch.cat([p1, p2]), dim=-1)
    z = F.normalize(torch.cat([z1, z2]), dim=-1)

    logits = torch.einsum("if, jf -> ij", p, z) / temperature
    # subtract the row maximum for numerical stability (does not change the softmax)
    logits_max, _ = torch.max(logits, dim=1, keepdim=True)
    logits = logits - logits_max.detach()

    # log-softmax over all columns; the previous all-ones logit mask was a no-op
    log_prob = logits - torch.log(torch.exp(logits).sum(1, keepdim=True))

    # each row has exactly one positive, located on the main diagonal
    loss = -torch.diagonal(log_prob).mean()
    return loss
43 |
44 |
def simclr_loss_func(
    z1: torch.Tensor,
    z2: torch.Tensor,
    temperature: float = 0.1,
    extra_pos_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """SimCLR loss for projected features of two views of the same batch.

    Args:
        z1 (torch.Tensor): NxD Tensor with the projected features of view 1.
        z2 (torch.Tensor): NxD Tensor with the projected features of view 2.
        temperature (float): temperature dividing the similarities. Defaults to 0.1.
        extra_pos_mask (Optional[torch.Tensor]): boolean mask OR-ed into the default
            cross-view positive mask. Defaults to None.

    Returns:
        torch.Tensor: SimCLR loss.
    """

    dev = z1.device
    batch = z1.size(0)
    total = 2 * batch

    feats = F.normalize(torch.cat((z1, z2), dim=0), dim=-1)

    # pairwise similarities, shifted by the (detached) row maximum
    sim = torch.einsum("if, jf -> ij", feats, feats) / temperature
    sim = sim - sim.max(dim=1, keepdim=True)[0].detach()

    # row i of view 1 is positive with row i of view 2, and vice versa
    positives = torch.zeros((total, total), dtype=torch.bool, device=dev)
    positives[:, batch:].fill_diagonal_(True)
    positives[batch:, :].fill_diagonal_(True)

    if extra_pos_mask is not None:
        positives = torch.bitwise_or(positives, extra_pos_mask)

    # the denominator uses every pair except a sample with itself
    keep = torch.ones_like(positives, device=dev).fill_diagonal_(0)
    denominator = (torch.exp(sim) * keep).sum(1, keepdim=True)
    log_prob = sim - torch.log(denominator)

    # average log-likelihood over each row's positives, then over rows
    per_row = (positives * log_prob).sum(1) / positives.sum(1)
    return -per_row.mean()
95 |
96 |
def manual_simclr_loss_func(
    z: torch.Tensor, pos_mask: torch.Tensor, neg_mask: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """SimCLR loss with caller-supplied positive and negative masks.

    Args:
        z (torch.Tensor): 2-D Tensor (rows x D) of projected features; the einsum below
            only accepts two dimensions.
        pos_mask (torch.Tensor): boolean mask marking the positives in z * z.T.
        neg_mask (torch.Tensor): boolean mask marking the negatives in z * z.T.
        temperature (float): temperature factor for the loss.

    Return:
        torch.Tensor: manual SimCLR loss, averaged over the rows that have positives.
    """

    feats = F.normalize(z, dim=-1)

    sim = torch.einsum("if, jf -> ij", feats, feats) / temperature
    sim = sim - sim.max(dim=1, keepdim=True)[0].detach()

    exp_sim = torch.exp(sim)
    # per-row sum over the negatives, shared by every candidate positive of that row
    neg_total = (exp_sim * neg_mask).sum(dim=1, keepdim=True)
    log_prob = torch.log(exp_sim / (exp_sim + neg_total))

    # rows without any positive are left out of the average
    pos_count = pos_mask.sum(1)
    has_pos = pos_count > 0
    summed = (pos_mask * log_prob).sum(1)
    per_row = summed[has_pos] / pos_count[has_pos]

    return -per_row.mean()
134 |
--------------------------------------------------------------------------------
/kaizen/utils/checkpointer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import string
5 | import time
6 | from argparse import ArgumentParser, Namespace
7 | from pathlib import Path
8 | from typing import Optional, Union
9 |
10 | import pytorch_lightning as pl
11 | from pytorch_lightning.callbacks import Callback
12 |
13 |
def random_string(letter_count=4, digit_count=4):
    """Return a shuffled string of random lowercase letters and digits.

    Args:
        letter_count: number of lowercase ASCII letters to draw. Defaults to 4.
        digit_count: number of decimal digits to draw. Defaults to 4.

    Returns:
        str: a string of length letter_count + digit_count.
    """
    # a private generator seeded with the current time; the global generator is untouched
    rng = random.Random(time.time())
    chars = [rng.choice(string.ascii_lowercase) for _ in range(letter_count)]
    chars += [rng.choice(string.digits) for _ in range(digit_count)]
    rng.shuffle(chars)
    return "".join(chars)
21 |
22 |
class Checkpointer(Callback):
    def __init__(
        self,
        args: Namespace,
        logdir: Union[str, Path] = Path("trained_models"),
        frequency: int = 1,
        keep_previous_checkpoints: bool = False,
    ):
        """Custom checkpointer callback that stores checkpoints in an easier to access way.

        Args:
            args (Namespace): namespace object containing at least an attribute name.
            logdir (Union[str, Path], optional): base directory to store checkpoints.
                Defaults to "trained_models".
            frequency (int, optional): number of epochs between each checkpoint. Defaults to 1.
            keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not.
                Defaults to False.

        Raises:
            AssertionError: if args.name contains the substring "task".
        """

        super().__init__()

        # checkpoint file names embed a "-task{idx}-" marker (see initial_setup)
        assert "task" not in args.name

        self.args = args
        self.logdir = Path(logdir)
        self.frequency = frequency
        self.keep_previous_checkpoints = keep_previous_checkpoints

    @staticmethod
    def add_checkpointer_args(parent_parser: ArgumentParser):
        """Adds user-required arguments to a parser.

        Args:
            parent_parser (ArgumentParser): parser to add new args to.

        Returns:
            ArgumentParser: the same parser, with a "checkpointer" argument group added.
        """

        parser = parent_parser.add_argument_group("checkpointer")
        parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path)
        parser.add_argument("--checkpoint_frequency", default=1, type=int)
        return parent_parser

    def initial_setup(self, trainer: pl.Trainer):
        """Creates the directories and does the initial setup needed.

        Sets self.path (checkpoint directory), self.ckpt_placeholder (file name
        template) and self.last_ckpt.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        task_idx = getattr(self.args, "task_idx", "_all")

        if trainer.logger is None:
            if os.path.exists(self.logdir):
                existing_versions = set(os.listdir(self.logdir))
            else:
                existing_versions = set()
            version = "offline-" + random_string()
            # compare the full directory name: entries of logdir are "task<idx>-<version>",
            # so comparing the bare version could never detect a collision
            while f"task{task_idx}-{version}" in existing_versions:
                version = "offline-" + random_string()
        else:
            version = str(trainer.logger.version)

        # version is always a str here, so no None handling is needed
        self.path = self.logdir / f"task{task_idx}-{version}"
        self.ckpt_placeholder = f"{self.args.name}" + "-task{}-ep={}" + f"-{version}.ckpt"
        self.last_ckpt: Optional[Path] = None

        # create logging dirs
        if trainer.is_global_zero:
            os.makedirs(self.path, exist_ok=True)

    def save_args(self, trainer: pl.Trainer):
        """Stores arguments into a json file.

        Only the global-zero process writes; it also sets self.json_path.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        if trainer.is_global_zero:
            args = vars(self.args)
            self.json_path = self.path / "args.json"
            # values json cannot encode are written as empty strings
            with open(self.json_path, "w") as f:
                json.dump(args, f, default=lambda o: "")

    def save(self, trainer: pl.Trainer):
        """Saves current checkpoint.

        Runs only on the global-zero process and never during sanity checking. Also
        rewrites <logdir>/last_checkpoint.txt with the checkpoint and args.json paths.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        if trainer.is_global_zero and not trainer.sanity_checking:
            epoch = trainer.current_epoch  # type: ignore
            task_idx = getattr(self.args, "task_idx", "_all")
            ckpt = self.path / self.ckpt_placeholder.format(task_idx, epoch)
            trainer.save_checkpoint(ckpt)

            # drop the previous checkpoint unless asked to keep the history
            if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints:
                if os.path.exists(self.last_ckpt):
                    os.remove(self.last_ckpt)

            with open(self.logdir / "last_checkpoint.txt", "w") as f:
                f.write(str(ckpt) + "\n" + str(self.json_path))

            self.last_ckpt = ckpt

    def on_train_start(self, trainer: pl.Trainer, _):
        """Executes initial setup and saves arguments.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        self.initial_setup(trainer)
        self.save_args(trainer)

    def on_train_epoch_end(self, trainer: pl.Trainer, _):
        """Tries to save current checkpoint at the end of each training epoch.

        A checkpoint is attempted when the current epoch is a multiple of self.frequency.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        epoch = trainer.current_epoch  # type: ignore
        if epoch % self.frequency == 0:
            self.save(trainer)
145 |
--------------------------------------------------------------------------------
/kaizen/methods/byol.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, Dict, List, Sequence, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from kaizen.losses.byol import byol_loss_func
8 | from kaizen.methods.base import BaseMomentumModel
9 | from kaizen.utils.momentum import initialize_momentum_params
10 |
11 |
class BYOL(BaseMomentumModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        pred_hidden_dim: int,
        **kwargs,
    ):
        """Implements BYOL (https://arxiv.org/abs/2006.07733).

        Args:
            output_dim (int): number of dimensions of projected features.
            proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
            pred_hidden_dim (int): number of neurons of the hidden layers of the predictor.
            **kwargs: forwarded unchanged to the parent constructor.
        """

        super().__init__(**kwargs)

        # projector
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

        # momentum projector (same architecture as the projector)
        self.momentum_projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.BatchNorm1d(proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )
        initialize_momentum_params(self.projector, self.momentum_projector)

        # predictor
        self.predictor = nn.Sequential(
            nn.Linear(output_dim, pred_hidden_dim),
            nn.BatchNorm1d(pred_hidden_dim),
            nn.ReLU(),
            nn.Linear(pred_hidden_dim, output_dim),
        )

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Registers the BYOL options on top of the parent's options.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the parser returned by the parent, with a "byol"
                argument group added.
        """

        parent_parser = super(BYOL, BYOL).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("byol")

        # projector
        parser.add_argument("--output_dim", type=int, default=256)
        parser.add_argument("--proj_hidden_dim", type=int, default=2048)

        # predictor
        parser.add_argument("--pred_hidden_dim", type=int, default=512)

        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector and predictor parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [
            {"params": self.projector.parameters()},
            {"params": self.predictor.parameters()},
        ]
        return super().learnable_params + extra_learnable_params

    @property
    def momentum_pairs(self) -> List[Tuple[Any, Any]]:
        """Adds (projector, momentum_projector) to the parent's momentum pairs.

        Returns:
            List[Tuple[Any, Any]]: list of momentum pairs.
        """

        extra_momentum_pairs = [(self.projector, self.momentum_projector)]
        return super().momentum_pairs + extra_momentum_pairs

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs forward pass of the online encoder (encoder, projector and predictor).

        Args:
            X (torch.Tensor): batch of images in tensor format.

        Returns:
            Dict[str, Any]: a dict containing the outputs of the parent plus the
                projected features ("z") and the predicted features ("p").
        """

        out = super().forward(X, *args, **kwargs)
        z = self.projector(out["feats"])
        p = self.predictor(z)
        return {**out, "z": z, "p": p}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for BYOL reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: the parent's output dict, with "loss" increased by the BYOL
                loss and "z" set to the list [z1, z2] of online projections.
        """

        out = super().training_step(batch, batch_idx)
        # both unpackings require exactly two tensors (one per view)
        feats1, feats2 = out["feats"]
        momentum_feats1, momentum_feats2 = out["momentum_feats"]

        z1 = self.projector(feats1)
        z2 = self.projector(feats2)
        p1 = self.predictor(z1)
        p2 = self.predictor(z2)

        # forward momentum encoder (no gradients flow through the targets)
        with torch.no_grad():
            z1_momentum = self.momentum_projector(momentum_feats1)
            z2_momentum = self.momentum_projector(momentum_feats2)

        # ------- contrastive loss -------
        # symmetric: each online prediction is compared with the other view's momentum target
        neg_cos_sim = byol_loss_func(p1, z2_momentum) + byol_loss_func(p2, z1_momentum)

        # calculate std of features (logged only, not part of the loss)
        z1_std = F.normalize(z1, dim=-1).std(dim=0).mean()
        z2_std = F.normalize(z2, dim=-1).std(dim=0).mean()
        z_std = (z1_std + z2_std) / 2

        metrics = {
            "train_neg_cos_sim": neg_cos_sim,
            "train_z_std": z_std,
        }
        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        out.update({"loss": out["loss"] + neg_cos_sim, "z": [z1, z2]})
        return out
151 |
--------------------------------------------------------------------------------
/evaluate_folder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | import glob
4 | import os
5 | import re
6 | import multiprocessing as mp
7 | import signal
8 | import time
9 |
# Command-line interface of the evaluation launcher.
parser = argparse.ArgumentParser()
# Each entry is (flag, keyword options for add_argument); registration order is kept.
for _flag, _options in (
    ("--folder", {"type": str, "required": True}),
    ("--script", {"type": str, "required": True}),
    ("--start_task", {"type": int, "default": 0}),
    ("--num_tasks", {"type": int, "default": 5}),
    ("--gpus", {"type": str, "default": "0,1", "help": "Comma-separated list of gpu indices"}),
    ("--tag", {"type": str, "default": ""}),
):
    parser.add_argument(_flag, **_options)
del _flag, _options
17 |
18 |
def consumer(consumer_id, job_queue, results_queue, setup_args=None):
    """Worker loop: take jobs from job_queue, run them and report on results_queue.

    Messages put on results_queue:
      * {'type': 'init', 'data': <pid>} once at start-up,
      * {'type': 'job_finished', 'data': <process_job result>} after each job,
      * None exactly once when the worker stops, whatever the reason.

    A None job is the stop signal. If the worker fails while a job is in flight, that
    job is put back on job_queue; a job that already finished is never re-queued.

    Args:
        consumer_id: identifier used only in the printed messages.
        job_queue: queue providing job strings, or None to stop.
        results_queue: queue receiving the messages described above.
        setup_args: value passed to process_job as its first argument.
    """
    job = None
    try:
        print(f'{consumer_id} {os.getpid()} Starting consumer')
        results_queue.put({
            'type': 'init',
            'data': os.getpid()
        })
        while True:
            # raises queue.Empty (handled below) if nothing arrives within 60 seconds
            job = job_queue.get(timeout=60)
            print(f"{consumer_id} {os.getpid()} get job")
            if job is None:
                break
            return_value = process_job(setup_args, job)
            results_queue.put({
                'type': 'job_finished',
                'data': return_value
            })
            # the job is done: forget it so a later failure (e.g. a get() timeout)
            # cannot put an already-completed job back on the queue
            job = None
        print(f"{consumer_id} {os.getpid()} exitting loop")
    except Exception as e:
        # top-level boundary of the worker: report and fall through to the clean-up
        print("DEBUG: There is some issue with the below arg settings. \n Copy the args and recreate the error by running contrastive_training.py for further debugging!")
        print(e)
        print(job)
    finally:
        if job is not None:
            # only an unfinished, in-flight job can still be referenced here
            job_queue.put(job)
        results_queue.put(None)
        print(f'Stopping consumer {consumer_id} {os.getpid()}')
49 |
def process_job(consumer_vars, job):
    """Run one job command, optionally prefixed by per-consumer settings.

    Args:
        consumer_vars: text placed in front of the job, separated by a space (the
            consumers pass an environment assignment such as
            "CUDA_VISIBLE_DEVICES=0"). None or "" runs the job unchanged.
        job: shell command to run.

    Returns:
        The value returned by run_command for the combined command.
    """
    new_job = f"{consumer_vars} {job}" if consumer_vars else job
    return run_command(new_job)


def run_command(command):
    """Run command through the shell and wait for it to finish.

    Args:
        command: shell command line. It is interpreted by the shell (the jobs rely
            on "VAR=value cmd" prefixes), so it must come from trusted input.

    Returns:
        int: exit status of the command.
    """
    p = subprocess.Popen(command, shell=True)
    return p.wait()
57 |
def main(args):
    """Queue one evaluation job per task checkpoint and run them on GPU workers.

    Sub-folders of args.folder named "task<idx>-<anything>" are scanned for *.ckpt
    files. For every task index in [args.start_task, args.num_tasks) that has a
    checkpoint, a shell job running "bash args.script" is queued. One consumer
    process is started per entry of the comma-separated args.gpus list and gets
    "CUDA_VISIBLE_DEVICES=<gpu>" as its command prefix.

    Args:
        args: parsed command-line namespace providing folder, script, start_task,
            num_tasks, gpus and tag.
    """
    gpus = args.gpus.split(",")

    task_folders = glob.glob(os.path.join(args.folder, "*"))
    print(task_folders)

    # task index -> checkpoint path; at least one digit is required so that
    # int() below can never receive an empty string
    task_pattern = re.compile(r"task(?P<task_idx>\d+)-.*")
    task_folder_lookup = {}
    for folder in task_folders:
        re_match = task_pattern.search(os.path.basename(folder))
        if re_match is None:
            continue
        task_idx = int(re_match.group("task_idx"))
        models = sorted(glob.glob(os.path.join(folder, '*.ckpt')))
        if models:
            # NOTE: the last path in lexicographic order is used
            task_folder_lookup[task_idx] = models[-1]

    all_jobs = []
    for task_idx in range(args.start_task, args.num_tasks):
        if task_idx in task_folder_lookup:
            # whitespace inside the command only separates shell words
            job = f'TAG="T{task_idx}-{args.tag}" TASK_IDX={task_idx} NUM_TASKS={args.num_tasks} \
                    PRETRAINED_PATH="{task_folder_lookup[task_idx]}" \
                    bash {args.script}'
            all_jobs.append(job)
        else:
            print("==============")
            print(f"[ERR] Cannot find model for task {task_idx}")
            print("==============")

    job_queue = mp.Queue()
    results_queue = mp.Queue()
    processes = [mp.Process(target=consumer, args=(i, job_queue, results_queue, f"CUDA_VISIBLE_DEVICES={gpus[i]}")) for i in range(len(gpus))]
    process_pids = []
    active_consumer_counter = len(processes)
    finished_job_counter = 0
    try:
        print("Putting jobs...")
        for job in all_jobs:
            job_queue.put(job)

        print(f"{os.getpid()} Server - starting consumers")
        for p in processes:
            p.start()
        # one stop signal per consumer, queued behind the real jobs
        for _ in range(len(processes)):
            job_queue.put(None)
        print(f"{os.getpid()} Server - finished putting jobs")

        while True:
            job_results = results_queue.get()
            print("Job results", job_results)
            if job_results is None:
                # a consumer announced that it stopped
                active_consumer_counter -= 1
                if active_consumer_counter == 0:
                    break
            elif job_results['type'] == 'init':
                process_pids.append(job_results['data'])
            elif job_results['type'] == 'job_finished':
                finished_job_counter += 1

        print('Closing workers')
        for p in processes:
            p.join(60)
    except KeyboardInterrupt:
        print("Interrupted from Keyboard")
    finally:
        # best-effort clean-up: failures are reported and never raised
        print("Terminating Processes", processes)
        for p in processes:
            try:
                p.terminate()
            except Exception as e:
                print(f"Unable to terminate process {p}, processes might still exist.", e)
        print("Killing Processes", process_pids)
        for pid in process_pids:
            try:
                os.kill(pid, signal.SIGKILL)
            except Exception as e:
                print(f"Unable to kill process {pid}, processes might still exist.", e)
        try:
            job_queue.close()
            results_queue.close()
        except Exception as e:
            print("Unable to close job queues, processes might still be open", e)

    print(f'Finished, processed {finished_job_counter} jobs')
143 |
if __name__ == "__main__":
    # parse_known_args tolerates extra command-line options; they are ignored here
    parsed = parser.parse_known_args()
    main(parsed[0])
147 |
148 |
--------------------------------------------------------------------------------
/kaizen/distiller_factories/knowledge.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, List, Sequence
3 |
4 | import torch
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from kaizen.args.utils import strtobool
8 |
def cross_entropy(preds, targets):
    """Cross-entropy between softmax(targets) and softmax(preds) over the last dim.

    Args:
        preds: unnormalised scores of the model being trained.
        targets: unnormalised scores used as soft targets.

    Returns:
        torch.Tensor: the cross-entropy averaged over all leading dimensions.
    """
    return -torch.mean(
        torch.sum(F.softmax(targets, dim=-1) * torch.log_softmax(preds, dim=-1), dim=-1)
    )


def knowledge_distill_factory(
    Method=object, class_tag="",
    distill_current_key="z", distill_frozen_key="frozen_z", output_dim=256
):
    """Build a subclass of Method that adds a prototype-based distillation loss.

    All options, kwargs and attributes of the wrapper are prefixed with
    "<class_tag>_" so that several wrappers can be stacked on one method.

    Args:
        Method: base class to wrap. The wrapper reads self.prototypes,
            self.current_task_idx and the dict returned by Method.training_step.
        class_tag (str): prefix of every option and attribute name.
        distill_current_key (str): key of the training-step output holding the two
            current feature tensors. Defaults to "z".
        distill_frozen_key (str): key of the training-step output holding the two
            frozen feature tensors. Defaults to "frozen_z".
        output_dim (int): feature dimension fed to the predictor and prototypes.

    Returns:
        type: the KnowledgeDistillWrapper class derived from Method.
    """
    distill_lamb_name = f"{class_tag}_distill_lamb"
    distill_proj_hidden_dim_name = f"{class_tag}_distill_proj_hidden_dim"
    distill_temperature_name = f"{class_tag}_distill_temperature"
    # "predictior" is misspelled, but the spelling is part of the option/kwarg name
    distill_no_predictior_name = f"{class_tag}_distill_no_predictior"

    frozen_prototypes_name = f"{class_tag}_frozen_prototypes"
    distill_predictor_name = f"{class_tag}_distill_predictor"
    distill_prototypes_name = f"{class_tag}_distill_prototypes"

    class KnowledgeDistillWrapper(Method):
        def __init__(self, **kwargs):
            """Pops the distiller options from kwargs and builds the distillation heads.

            The remaining kwargs, which must include "num_prototypes", are forwarded
            to Method.
            """
            distill_lamb: float = kwargs.pop(distill_lamb_name)
            distill_proj_hidden_dim: int = kwargs.pop(distill_proj_hidden_dim_name)
            distill_temperature: float = kwargs.pop(distill_temperature_name)
            distill_no_predictior = kwargs.pop(distill_no_predictior_name)
            super().__init__(**kwargs)

            setattr(self, distill_lamb_name, distill_lamb)
            setattr(self, distill_proj_hidden_dim_name, distill_proj_hidden_dim)
            setattr(self, distill_temperature_name, distill_temperature)
            setattr(self, distill_no_predictior_name, distill_no_predictior)
            # TODO: Allow different num_prototypes for different distillers
            # TODO: Verify that these prototypes can be used for classifier distillation
            num_prototypes = kwargs["num_prototypes"]

            # frozen copy of the method's prototypes, excluded from gradient updates
            setattr(self, frozen_prototypes_name, nn.utils.weight_norm(
                nn.Linear(output_dim, num_prototypes, bias=False)
            ))
            for frozen_pg, pg in zip(
                getattr(self, frozen_prototypes_name).parameters(), self.prototypes.parameters()  # TODO: Check this
            ):
                frozen_pg.data.copy_(pg.data)
                frozen_pg.requires_grad = False

            if distill_no_predictior:
                setattr(self, distill_predictor_name, nn.Identity())
            else:
                setattr(self, distill_predictor_name, nn.Sequential(
                    nn.Linear(output_dim, distill_proj_hidden_dim),
                    nn.BatchNorm1d(distill_proj_hidden_dim),
                    nn.ReLU(),
                    nn.Linear(distill_proj_hidden_dim, output_dim),
                ))

            setattr(self, distill_prototypes_name, nn.utils.weight_norm(
                nn.Linear(output_dim, num_prototypes, bias=False)
            ))

        @staticmethod
        def add_model_specific_args(
            parent_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Registers the distiller options; the parent's options are not added here.

            Args:
                parent_parser (argparse.ArgumentParser): parser to extend.

            Returns:
                argparse.ArgumentParser: the same parser.
            """
            parser = parent_parser.add_argument_group(f"knowledge_{class_tag}_distiller")

            parser.add_argument(f"--{distill_lamb_name}", type=float, default=1)
            parser.add_argument(f"--{distill_proj_hidden_dim_name}", type=int, default=2048)
            parser.add_argument(f"--{distill_temperature_name}", type=float, default=0.1)
            parser.add_argument(f"--{distill_no_predictior_name}", type=strtobool, default=True)

            return parent_parser

        @property
        def learnable_params(self) -> List[dict]:
            """Adds distill predictor and prototype parameters to the parent's learnable parameters.

            Returns:
                List[dict]: list of learnable parameters.
            """

            extra_learnable_params = [
                {"params": getattr(self, distill_predictor_name).parameters()},
                {"params": getattr(self, distill_prototypes_name).parameters()},
            ]
            return super().learnable_params + extra_learnable_params

        def on_train_start(self):
            """Refreshes the frozen prototypes from self.prototypes after the first task."""
            super().on_train_start()

            if self.current_task_idx > 0:
                for frozen_pg, pg in zip(
                    getattr(self, frozen_prototypes_name).parameters(), self.prototypes.parameters()
                ):
                    # TODO: logits and prototypes have different shape
                    frozen_pg.data.copy_(pg.data)
                    frozen_pg.requires_grad = False

        def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
            """Runs the parent's training step and adds the weighted distillation loss.

            Args:
                batch (Sequence[Any]): batch passed through to the parent.
                batch_idx (int): index of the batch.

            Returns:
                The parent's output with out["loss"] increased by
                distill_lamb * distill_loss. NOTE(review): the annotation says
                torch.Tensor, but the parent's output is indexed like a dict.
            """
            out = super().training_step(batch, batch_idx)
            z1, z2 = out[distill_current_key]
            frozen_z1, frozen_z2 = out[distill_frozen_key]

            temperature = getattr(self, distill_temperature_name)
            frozen_prototypes = getattr(self, frozen_prototypes_name)
            # the predictor is registered under its tagged name, so it has to be
            # looked up the same way as the other distiller modules
            distill_predictor = getattr(self, distill_predictor_name)
            distill_prototypes = getattr(self, distill_prototypes_name)

            # soft targets come from the frozen features and never receive gradients
            with torch.no_grad():
                frozen_z1 = F.normalize(frozen_z1)
                frozen_z2 = F.normalize(frozen_z2)
                frozen_p1 = frozen_prototypes(frozen_z1) / temperature
                frozen_p2 = frozen_prototypes(frozen_z2) / temperature

            distill_z1 = F.normalize(distill_predictor(z1))
            distill_z2 = F.normalize(distill_predictor(z2))
            distill_p1 = distill_prototypes(distill_z1) / temperature
            distill_p2 = distill_prototypes(distill_z2) / temperature

            distill_loss = (
                cross_entropy(distill_p1, frozen_p1) + cross_entropy(distill_p2, frozen_p2)
            ) / 2

            self.log(f"train_{class_tag}_knowledge_distill_loss", distill_loss, on_epoch=True, sync_dist=True)

            out["loss"] += getattr(self, distill_lamb_name) * distill_loss
            return out

    return KnowledgeDistillWrapper
130 |
--------------------------------------------------------------------------------
/kaizen/methods/ressl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, Dict, List, Sequence, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from kaizen.losses.ressl import ressl_loss_func
8 | from kaizen.methods.base import BaseMomentumModel
9 | from kaizen.utils.gather_layer import gather
10 | from kaizen.utils.momentum import initialize_momentum_params
11 |
12 |
13 | class ReSSL(BaseMomentumModel):
def __init__(
    self,
    output_dim: int,
    proj_hidden_dim: int,
    temperature_q: float,
    temperature_k: float,
    queue_size: int,
    **kwargs,
):
    """Implements ReSSL (https://arxiv.org/abs/2107.09282v1).

    Args:
        output_dim (int): number of dimensions of projected features.
        proj_hidden_dim (int): number of neurons of the hidden layer of the projector.
        temperature_q (float): temperature for the contrastive augmentations.
        temperature_k (float): temperature for the weak augmentation.
        queue_size (int): number of rows of the feature queue buffer.
        **kwargs: forwarded unchanged to the parent constructor.
    """

    super().__init__(**kwargs)

    def build_projector() -> nn.Sequential:
        # two-layer MLP: features_dim -> proj_hidden_dim -> output_dim
        return nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

    # online projector first, then its momentum counterpart
    self.projector = build_projector()
    self.momentum_projector = build_projector()
    initialize_momentum_params(self.projector, self.momentum_projector)

    self.temperature_q = temperature_q
    self.temperature_k = temperature_k
    self.queue_size = queue_size

    # queue: random rows normalised along dim 1, plus a write pointer
    self.register_buffer(
        "queue", F.normalize(torch.randn(self.queue_size, output_dim), dim=1)
    )
    self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
58 |
@staticmethod
def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Registers the ReSSL options on top of the parent's options.

    Args:
        parent_parser (argparse.ArgumentParser): parser to extend.

    Returns:
        argparse.ArgumentParser: the parser returned by the parent, with a "ressl"
            argument group added.
    """
    parent_parser = super(ReSSL, ReSSL).add_model_specific_args(parent_parser)
    ressl_group = parent_parser.add_argument_group("ressl")

    # (flag, type, default): projector sizes, queue size, then the two temperatures
    option_specs = (
        ("--output_dim", int, 256),
        ("--proj_hidden_dim", int, 2048),
        ("--queue_size", int, 65536),
        ("--temperature_q", float, 0.1),
        ("--temperature_k", float, 0.04),
    )
    for flag, value_type, default_value in option_specs:
        ressl_group.add_argument(flag, type=value_type, default=default_value)

    return parent_parser
76 |
77 | @property
78 | def learnable_params(self) -> List[dict]:
79 | """Adds projector parameters to the parent's learnable parameters.
80 |
81 | Returns:
82 | List[dict]: list of learnable parameters.
83 | """
84 |
85 | extra_learnable_params = [
86 | {"params": self.projector.parameters()},
87 | ]
88 | return super().learnable_params + extra_learnable_params
89 |
90 | @property
91 | def momentum_pairs(self) -> List[Tuple[Any, Any]]:
92 | """Adds (projector, momentum_projector) to the parent's momentum pairs.
93 |
94 | Returns:
95 | List[Tuple[Any, Any]]: list of momentum pairs.
96 | """
97 |
98 | extra_momentum_pairs = [(self.projector, self.momentum_projector)]
99 | return super().momentum_pairs + extra_momentum_pairs
100 |
101 | @torch.no_grad()
102 | def dequeue_and_enqueue(self, k: torch.Tensor):
103 | """Adds new samples and removes old samples from the queue in a fifo manner.
104 |
105 | Args:
106 | z (torch.Tensor): batch of projected features.
107 | """
108 |
109 | k = gather(k)
110 |
111 | batch_size = k.shape[0]
112 |
113 | ptr = int(self.queue_ptr) # type: ignore
114 | assert self.queue_size % batch_size == 0
115 |
116 | self.queue[ptr : ptr + batch_size, :] = k
117 | ptr = (ptr + batch_size) % self.queue_size
118 |
119 | self.queue_ptr[0] = ptr # type: ignore
120 |
121 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
122 | """Performs forward pass of the online encoder (encoder, projector and predictor).
123 |
124 | Args:
125 | X (torch.Tensor): batch of images in tensor format.
126 |
127 | Returns:
128 | Dict[str, Any]: a dict containing the outputs of the parent and the logits of the head.
129 | """
130 |
131 | out = super().forward(X, *args, **kwargs)
132 | q = F.normalize(self.projector(out["feats"]), dim=-1)
133 | return {**out, "q": q}
134 |
135 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
136 | """Training step for BYOL reusing BaseModel training step.
137 |
138 | Args:
139 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
140 | [X] is a list of size self.num_crops containing batches of images.
141 | batch_idx (int): index of the batch.
142 |
143 | Returns:
144 | torch.Tensor: total loss composed of BYOL and classification loss.
145 | """
146 |
147 | out = super().training_step(batch, batch_idx)
148 | class_loss = out["loss"]
149 | feats1, _ = out["feats"]
150 | _, momentum_feats2 = out["momentum_feats"]
151 |
152 | q = self.projector(feats1)
153 |
154 | # forward momentum encoder
155 | with torch.no_grad():
156 | k = self.momentum_projector(momentum_feats2)
157 |
158 | q = F.normalize(q, dim=-1)
159 | k = F.normalize(k, dim=-1)
160 |
161 | # ------- contrastive loss -------
162 | queue = self.queue.clone().detach()
163 | ressl_loss = ressl_loss_func(q, k, queue, self.temperature_q, self.temperature_k)
164 |
165 | self.log("ressl_loss", ressl_loss, on_epoch=True, sync_dist=True)
166 |
167 | # dequeue and enqueue
168 | self.dequeue_and_enqueue(k)
169 |
170 | return ressl_loss + class_loss
171 |
--------------------------------------------------------------------------------
/kaizen/methods/simclr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, Dict, List, Sequence
3 |
4 | import torch
5 | import torch.nn as nn
6 | from einops import repeat
7 | from kaizen.losses.simclr import manual_simclr_loss_func, simclr_loss_func
8 | from kaizen.methods.base import BaseModel
9 |
10 |
class SimCLR(BaseModel):
    def __init__(
        self,
        output_dim: int,
        proj_hidden_dim: int,
        temperature: float,
        supervised: bool = False,
        **kwargs,
    ):
        """Implements SimCLR (https://arxiv.org/abs/2002.05709).

        Args:
            output_dim (int): number of dimensions of the projected features.
            proj_hidden_dim (int): number of neurons in the hidden layers of the projector.
            temperature (float): temperature for the softmax in the contrastive loss.
            supervised (bool): whether or not to use supervised contrastive loss. Defaults to False.
        """

        super().__init__(**kwargs)

        self.temperature = temperature
        self.supervised = supervised

        # projector
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Adds the SimCLR options to the parser after the parent class added its own.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the extended parser.
        """

        parent_parser = super(SimCLR, SimCLR).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("simclr")

        # projector
        parser.add_argument("--output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=2048)

        # parameters
        parser.add_argument("--temperature", type=float, default=0.1)

        # supervised-simclr
        parser.add_argument("--supervised", action="store_true")
        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters to the parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs the forward pass of the parent and of the projector.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]:
                a dict containing the outputs of the parent
                and the projected features under the key "z".
        """

        out = super().forward(X, *args, **kwargs)
        z = self.projector(out["feats"])
        return {**out, "z": z}

    @torch.no_grad()
    def gen_extra_positives_gt(self, Y: torch.Tensor) -> torch.Tensor:
        """Generates extra positives for supervised contrastive learning.

        Args:
            Y (torch.Tensor): labels of the samples of the batch.

        Returns:
            torch.Tensor: matrix with extra positives generated using the labels.
        """

        if self.multicrop:
            n_augs = self.num_crops + self.num_small_crops
        else:
            n_augs = 2
        # tile the labels once per augmentation, then mark equal-label pairs (minus the diagonal)
        labels_matrix = repeat(Y, "b -> c (d b)", c=n_augs * Y.size(0), d=n_augs)
        labels_matrix = (labels_matrix == labels_matrix.t()).fill_diagonal_(False)
        return labels_matrix

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> Dict[str, Any]:
        """Training step for SimCLR and supervised SimCLR reusing BaseModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where
                [X] is a list of size self.num_crops containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: the parent's output dict where "loss" additionally includes the
                SimCLR loss and "z" holds the list of per-view projected features.
        """

        indexes, *_, target = batch[f"task{self.current_task_idx}"]

        out = super().training_step(batch, batch_idx)

        if self.multicrop:
            n_augs = self.num_crops + self.num_small_crops

            # keep the per-crop projections: they are returned in the output dict below
            # (previously only the two-view branch defined them, so multicrop crashed)
            z = [self.projector(f) for f in out["feats"]]
            z_all = torch.cat(z)

            # ------- contrastive loss -------
            if self.supervised:
                pos_mask = self.gen_extra_positives_gt(target)
            else:
                # views of the same image (same index) are positives of each other
                index_matrix = repeat(indexes, "b -> c (d b)", c=n_augs * indexes.size(0), d=n_augs)
                pos_mask = (index_matrix == index_matrix.t()).fill_diagonal_(False)
            neg_mask = (~pos_mask).fill_diagonal_(False)

            nce_loss = manual_simclr_loss_func(
                z_all,
                pos_mask=pos_mask,
                neg_mask=neg_mask,
                temperature=self.temperature,
            )
        else:
            feats1, feats2 = out["feats"]

            z1 = self.projector(feats1)
            z2 = self.projector(feats2)
            z = [z1, z2]

            # ------- contrastive loss -------
            if self.supervised:
                pos_mask = self.gen_extra_positives_gt(target)
                nce_loss = simclr_loss_func(
                    z1, z2, extra_pos_mask=pos_mask, temperature=self.temperature
                )
            else:
                nce_loss = simclr_loss_func(z1, z2, temperature=self.temperature)

        # compute number of extra positives
        n_positives = (
            (pos_mask != 0).sum().float()
            if self.supervised
            else torch.tensor(0.0, device=self.device)
        )

        metrics = {
            "train_nce_loss": nce_loss,
            "train_n_positives": n_positives,
        }
        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        out.update({"loss": out["loss"] + nce_loss, "z": z})
        return out
170 |
--------------------------------------------------------------------------------
/kaizen/utils/kmeans.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Sequence
2 |
3 | import numpy as np
4 | import torch
5 | import torch.distributed as dist
6 | import torch.nn.functional as F
7 | from scipy.sparse import csr_matrix
8 |
9 |
class KMeans:
    def __init__(
        self,
        world_size: int,
        rank: int,
        num_crops: int,
        dataset_size: int,
        proj_features_dim: int,
        num_prototypes: Sequence[int],
        kmeans_iters: int = 10,
    ):
        """Class that performs K-Means on the hypersphere.

        Args:
            world_size (int): world size.
            rank (int): rank of the current process.
            num_crops (int): number of crops.
            dataset_size (int): total size of the dataset (number of samples).
            proj_features_dim (int): number of dimensions of the projected features.
            num_prototypes (Sequence[int]): number of prototypes of each clustering; one
                k-means run is performed per entry.
            kmeans_iters (int, optional): number of iterations for the k-means clustering.
                Defaults to 10.
        """
        self.world_size = world_size
        self.rank = rank
        self.num_crops = num_crops
        self.dataset_size = dataset_size
        self.proj_features_dim = proj_features_dim
        self.num_prototypes = num_prototypes
        self.kmeans_iters = kmeans_iters

    @staticmethod
    def get_indices_sparse(data: np.ndarray):
        """Groups the positions of `data` by the value found at each position.

        Args:
            data (np.ndarray): array of non-negative integers (cluster ids).

        Returns:
            list: one entry per value in [0, data.max()], each being the result of
                np.unravel_index for the positions stored in that value's sparse row.
        """
        cols = np.arange(data.size)
        # row r of the sparse matrix stores the flat positions whose value is r
        M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size))
        return [np.unravel_index(row.data, data.shape) for row in M]

    @staticmethod
    def _is_distributed() -> bool:
        """Tells whether torch.distributed is usable and has been initialized."""
        return dist.is_available() and dist.is_initialized()

    def cluster_memory(
        self,
        local_memory_index: torch.Tensor,
        local_memory_embeddings: torch.Tensor,
    ) -> Sequence[Any]:
        """Performs K-Means clustering on the hypersphere and returns centroids and
        assignments for each sample.

        Args:
            local_memory_index (torch.Tensor): memory bank containing indices of the
                samples.
            local_memory_embeddings (torch.Tensor): memory bank containing embeddings
                of the samples; it is indexed by crop first.

        Returns:
            Sequence[Any]: assignments (a CPU long tensor with one row per entry of
                num_prototypes, -1 where no assignment was written) and the list of centroids.
        """
        j = 0
        device = local_memory_embeddings.device
        assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long()
        centroids_list = []
        distributed = self._is_distributed()
        with torch.no_grad():
            for i_K, K in enumerate(self.num_prototypes):
                # run distributed k-means

                # init centroids with elements from memory bank of rank 0
                centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True)
                if self.rank == 0:
                    random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K]
                    assert len(random_idx) >= K, "please reduce the number of centroids"
                    centroids = local_memory_embeddings[j][random_idx]
                if distributed:
                    dist.broadcast(centroids, 0)

                for n_iter in range(self.kmeans_iters + 1):

                    # E step: assign every embedding to the centroid with the largest dot product
                    dot_products = torch.mm(local_memory_embeddings[j], centroids.t())
                    _, local_assignments = dot_products.max(dim=1)

                    # finish: the extra iteration only refreshes the assignments
                    if n_iter == self.kmeans_iters:
                        break

                    # M step: accumulate per-cluster sums and counts
                    where_helper = self.get_indices_sparse(local_assignments.cpu().numpy())
                    counts = torch.zeros(K).to(device, non_blocking=True).int()
                    emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True)
                    for k, members in enumerate(where_helper):
                        if len(members[0]) > 0:
                            emb_sums[k] = torch.sum(
                                local_memory_embeddings[j][members[0]],
                                dim=0,
                            )
                            counts[k] = len(members[0])
                    if distributed:
                        dist.all_reduce(counts)
                        dist.all_reduce(emb_sums)
                    # empty clusters keep their previous centroid
                    mask = counts > 0
                    centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)

                    # normalize centroids
                    centroids = F.normalize(centroids, dim=1, p=2)

                centroids_list.append(centroids)

                if distributed:
                    # gather the assignments
                    assignments_all = torch.empty(
                        self.world_size,
                        local_assignments.size(0),
                        dtype=local_assignments.dtype,
                        device=local_assignments.device,
                    )
                    assignments_all = list(assignments_all.unbind(0))

                    dist_process = dist.all_gather(
                        assignments_all, local_assignments, async_op=True
                    )
                    dist_process.wait()
                    assignments_all = torch.cat(assignments_all).cpu()

                    # gather the indexes
                    indexes_all = torch.empty(
                        self.world_size,
                        local_memory_index.size(0),
                        dtype=local_memory_index.dtype,
                        device=local_memory_index.device,
                    )
                    indexes_all = list(indexes_all.unbind(0))
                    dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True)
                    dist_process.wait()
                    indexes_all = torch.cat(indexes_all).cpu()

                else:
                    # `assignments` lives on the CPU, so bring both tensors there as the
                    # distributed branch does (no-op when they already are CPU tensors)
                    assignments_all = local_assignments.cpu()
                    indexes_all = local_memory_index.cpu()

                # log assignments
                assignments[i_K][indexes_all] = assignments_all

                # next memory bank to use
                j = (j + 1) % self.num_crops

        return assignments, centroids_list
152 |
--------------------------------------------------------------------------------
/kaizen/utils/auto_umap.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from argparse import ArgumentParser, Namespace
4 | from pathlib import Path
5 | from typing import Optional, Union
6 |
7 | import pandas as pd
8 | import pytorch_lightning as pl
9 | import seaborn as sns
10 | import torch
11 | import umap
12 | import wandb
13 | from matplotlib import pyplot as plt
14 | from pytorch_lightning.callbacks import Callback
15 |
16 | from .gather_layer import gather
17 |
18 |
class AutoUMAP(Callback):
    def __init__(
        self,
        args: Namespace,
        logdir: Union[str, Path] = Path("auto_umap"),
        frequency: int = 1,
        keep_previous: bool = False,
        color_palette: str = "hls",
    ):
        """Callback that runs UMAP on the features of the validation data, saves the figure
        locally and uploads it to wandb when a wandb logger is in use.

        Args:
            args (Namespace): namespace object containing at least an attribute name.
            logdir (Union[str, Path], optional): base directory where the plots are stored.
                Defaults to Path("auto_umap").
            frequency (int, optional): number of epochs between each UMAP. Defaults to 1.
            keep_previous (bool, optional): whether to keep previous plots or not.
                Defaults to False.
            color_palette (str, optional): color scheme for the classes. Defaults to "hls".
        """

        super().__init__()

        self.args = args
        self.logdir = Path(logdir)
        self.frequency = frequency
        self.color_palette = color_palette
        self.keep_previous = keep_previous

    @staticmethod
    def add_auto_umap_args(parent_parser: ArgumentParser):
        """Adds the options of this callback to a parser, in an "auto_umap" group.

        Args:
            parent_parser (ArgumentParser): parser to add new args to.
        """

        umap_group = parent_parser.add_argument_group("auto_umap")
        umap_group.add_argument("--auto_umap_dir", default=Path("auto_umap"), type=Path)
        umap_group.add_argument("--auto_umap_frequency", default=1, type=int)
        return parent_parser

    def initial_setup(self, trainer: pl.Trainer):
        """Derives the output directory and file name template, then creates the directory.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        logger = trainer.logger
        version = None if logger is None else str(logger.version)

        # with a logger, plots go in a per-version sub-directory and carry the version in
        # their name; "{}" in the template is later filled with the epoch
        if version is None:
            self.path = self.logdir
            self.umap_placeholder = f"{self.args.name}" + "-ep={}.pdf"
        else:
            self.path = self.logdir / version
            self.umap_placeholder = f"{self.args.name}-{version}" + "-ep={}.pdf"
        self.last_ckpt: Optional[str] = None

        # only the global-zero process creates the directory
        if trainer.is_global_zero:
            os.makedirs(self.path, exist_ok=True)

    def on_train_start(self, trainer: pl.Trainer, _):
        """Runs the initial setup when training starts.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        self.initial_setup(trainer)

    def plot(self, trainer: pl.Trainer, module: pl.LightningModule):
        """Produces a UMAP visualization by forwarding all data of the
        first validation dataloader through the module.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
            module (pl.LightningModule): current module object.
        """

        device = module.device
        feature_batches = []
        label_batches = []

        # collect the "feats" output of the module for every validation batch, in eval mode
        module.eval()
        with torch.no_grad():
            for x, y in trainer.val_dataloaders[0]:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                feats = gather(module(x)["feats"])
                y = gather(y)

                feature_batches.append(feats.cpu())
                label_batches.append(y.cpu())
        module.train()

        # plotting happens on the global-zero process only, and only if data was collected
        if not (trainer.is_global_zero and len(feature_batches)):
            return

        features = torch.cat(feature_batches, dim=0).numpy()
        labels = torch.cat(label_batches, dim=0)
        num_classes = len(torch.unique(labels))
        labels = labels.numpy()

        embedding = umap.UMAP(n_components=2).fit_transform(features)

        # passing to dataframe
        df = pd.DataFrame(
            {
                "feat_1": embedding[:, 0],
                "feat_2": embedding[:, 1],
                "Y": labels,
            }
        )
        plt.figure(figsize=(9, 9))
        ax = sns.scatterplot(
            x="feat_1",
            y="feat_2",
            hue="Y",
            palette=sns.color_palette(self.color_palette, num_classes),
            data=df,
            legend="full",
            alpha=0.3,
        )
        ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[])
        ax.tick_params(left=False, right=False, bottom=False, top=False)

        # manually improve quality of imagenet umaps: the legend is placed higher
        # when there are many classes
        anchor = (0.5, 1.8) if num_classes > 100 else (0.5, 1.35)

        plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10))
        plt.tight_layout()

        if isinstance(trainer.logger, pl.loggers.WandbLogger):
            wandb.log(
                {"validation_umap": wandb.Image(ax)},
                commit=False,
            )

        # save plot locally as well
        epoch = trainer.current_epoch  # type: ignore
        plt.savefig(self.path / self.umap_placeholder.format(epoch))
        plt.close()

    def on_validation_end(self, trainer: pl.Trainer, module: pl.LightningModule):
        """Generates a UMAP visualization at the end of a validation epoch when the epoch
        is a multiple of the configured frequency and the trainer is not sanity checking.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
            module (pl.LightningModule): current module object.
        """

        epoch = trainer.current_epoch  # type: ignore
        if epoch % self.frequency != 0 or trainer.sanity_checking:
            return
        self.plot(trainer, module)
180 |
--------------------------------------------------------------------------------
/kaizen/utils/knn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchmetrics.metric import Metric
3 |
4 |
5 | from typing import Sequence
6 |
7 | import torch
8 | from torchmetrics.metric import Metric
9 |
10 |
class WeightedKNNClassifier(Metric):
    def __init__(
        self,
        k: int = 20,
        T: float = 0.07,
        max_distance_matrix_size: int = int(5e6),
        distance_fx: str = "cosine",
        epsilon: float = 0.00001,
        dist_sync_on_step: bool = False,
    ):
        """Implements the weighted k-NN classifier used for evaluation.
        Args:
            k (int, optional): number of neighbors. Defaults to 20.
            T (float, optional): temperature for the exponential. Only used with cosine
                distance. Defaults to 0.07.
            max_distance_matrix_size (int, optional): maximum number of elements in the
                distance matrix. Defaults to 5e6.
            distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
                "euclidean". Defaults to "cosine".
            epsilon (float, optional): Small value for numerical stability. Only used with
                euclidean distance. Defaults to 0.00001.
            dist_sync_on_step (bool, optional): whether to sync distributed values at every
                step. Defaults to False.
        """

        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)

        self.k = k
        self.T = T
        self.max_distance_matrix_size = max_distance_matrix_size
        self.distance_fx = distance_fx
        self.epsilon = epsilon

        # memory banks, filled batch by batch through update()
        self.add_state("train_features", default=[], persistent=False)
        self.add_state("train_targets", default=[], persistent=False)
        self.add_state("test_features", default=[], persistent=False)
        self.add_state("test_targets", default=[], persistent=False)

    def update(
        self,
        train_features: torch.Tensor = None,
        train_targets: torch.Tensor = None,
        test_features: torch.Tensor = None,
        test_targets: torch.Tensor = None,
    ):
        """Updates the memory banks. If train (test) features are passed as input, the
        corresponding train (test) targets must be passed as well.
        Args:
            train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
            train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
            test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
            test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
        """
        assert (train_features is None) == (train_targets is None)
        assert (test_features is None) == (test_targets is None)

        if train_features is not None:
            assert train_features.size(0) == train_targets.size(0)
            self.train_features.append(train_features.detach())
            self.train_targets.append(train_targets.detach())

        if test_features is not None:
            assert test_features.size(0) == test_targets.size(0)
            self.test_features.append(test_features.detach())
            self.test_targets.append(test_targets.detach())

    @torch.no_grad()
    def compute(self) -> Sequence[float]:
        """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
        the weight is computed using the exponential of the temperature scaled cosine
        distance of the samples. If euclidean distance is selected, the weight corresponds
        to the inverse of the euclidean distance.

        The memory banks are reset once the accuracies have been computed.

        Returns:
            Sequence[float]: k-NN accuracy @1 and @5.
        """

        train_features = torch.cat(self.train_features)
        train_targets = torch.cat(self.train_targets)
        test_features = torch.cat(self.test_features)
        test_targets = torch.cat(self.test_targets)

        # targets are used as column indices of the one-hot votes below, so the number of
        # columns must cover the largest label seen; counting unique test labels is too
        # small when labels are not contiguous or some classes only appear in the train set
        num_classes = max(int(train_targets.max()), int(test_targets.max())) + 1
        num_train_images = train_targets.size(0)
        num_test_images = test_targets.size(0)
        # process test samples in chunks so that a chunk's similarity matrix stays bounded
        chunk_size = min(
            max(1, self.max_distance_matrix_size // num_train_images),
            num_test_images,
        )
        k = min(self.k, num_train_images)

        top1, top5, total = 0.0, 0.0, 0
        retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
        for idx in range(0, num_test_images, chunk_size):
            # get the features for test images
            features = test_features[idx : min((idx + chunk_size), num_test_images), :]
            targets = test_targets[idx : min((idx + chunk_size), num_test_images)]
            batch_size = targets.size(0)

            # calculate the dot product and compute top-k neighbors
            if self.distance_fx == "cosine":
                similarity = torch.mm(features, train_features.t())
            elif self.distance_fx == "euclidean":
                similarity = 1 / (torch.cdist(features, train_features) + self.epsilon)
            else:
                raise NotImplementedError

            distances, indices = similarity.topk(k, largest=True, sorted=True)
            candidates = train_targets.view(1, -1).expand(batch_size, -1)
            retrieved_neighbors = torch.gather(candidates, 1, indices)

            # one-hot encode the label of every retrieved neighbor
            retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
            retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)

            if self.distance_fx == "cosine":
                distances = distances.clone().div_(self.T).exp_()

            # weighted vote: sum the neighbor weights per class
            probs = torch.sum(
                torch.mul(
                    retrieval_one_hot.view(batch_size, -1, num_classes),
                    distances.view(batch_size, -1, 1),
                ),
                1,
            )
            _, predictions = probs.sort(1, True)

            # find the predictions that match the target
            correct = predictions.eq(targets.data.view(-1, 1))
            top1 = top1 + correct.narrow(1, 0, 1).sum().item()
            top5 = (
                top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
            )  # top5 does not make sense if k < 5
            total += targets.size(0)

        top1 = top1 * 100.0 / total
        top5 = top5 * 100.0 / total

        self.reset()

        return top1, top5
151 |
--------------------------------------------------------------------------------
/main_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import types
3 |
4 | import torch
5 | import torch.nn as nn
6 | from pytorch_lightning import Trainer, seed_everything
7 | from pytorch_lightning.callbacks import LearningRateMonitor
8 | from pytorch_lightning.loggers import WandbLogger
9 | from pytorch_lightning.plugins import DDPPlugin
10 | from torchvision.models import resnet18, resnet50
11 |
12 | from kaizen.args.setup import parse_args_eval
13 | from kaizen.methods.full_model import FullModel
14 |
15 | try:
16 | from kaizen.methods.dali import ClassificationABC
17 | except ImportError:
18 | _dali_avaliable = False
19 | else:
20 | _dali_avaliable = True
21 | from kaizen.methods.linear import LinearModel
22 | from kaizen.methods.multi_layer_classifier import MultiLayerClassifier
23 | from kaizen.utils.classification_dataloader import prepare_data
24 | from kaizen.utils.checkpointer import Checkpointer
25 |
26 |
27 | def main():
28 | args = parse_args_eval()
29 |
30 | # split classes into tasks
31 | tasks = None
32 | if args.split_strategy == "class":
33 | assert args.num_classes % args.num_tasks == 0
34 | torch.manual_seed(args.split_seed)
35 | tasks = torch.randperm(args.num_classes).chunk(args.num_tasks)
36 |
37 | seed_everything(args.global_seed)
38 |
39 |
40 | # Build backbone
41 | if args.encoder == "resnet18":
42 | backbone = resnet18()
43 | elif args.encoder == "resnet50":
44 | backbone = resnet50()
45 | else:
46 | raise ValueError("Only [resnet18, resnet50] are currently supported.")
47 |
48 | if args.cifar:
49 | backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False)
50 | backbone.maxpool = nn.Identity()
51 | backbone.fc = nn.Identity()
52 |
53 | assert (
54 | args.pretrained_model.endswith(".ckpt")
55 | or args.pretrained_model.endswith(".pth")
56 | or args.pretrained_model.endswith(".pt")
57 | )
58 | ckpt_path = args.pretrained_model
59 |
60 |
61 | state = torch.load(ckpt_path, map_location=None if torch.cuda.is_available() else "cpu")["state_dict"]
62 | extracted_state = {}
63 | for k in list(state.keys()):
64 | if "encoder" in k:
65 | extracted_state[k.replace("encoder.", "")] = state[k]
66 | missing_keys_backbone, unexpected_keys_backbone = backbone.load_state_dict(extracted_state, strict=False)
67 | print("Missing keys - Backbone:", missing_keys_backbone)
68 |
69 | # Build Classifier
70 | classifier = MultiLayerClassifier(backbone.inplanes, args.num_classes, args.classifier_layers)
71 | # "--evaluation_mode", choices=["linear_eval", "classifier_eval", "online_classifier_eval"]
72 | if args.evaluation_mode == "linear_eval":
73 | is_model_training = True
74 | elif args.evaluation_mode == "classifier_eval":
75 | extracted_state = {}
76 | for k in state:
77 | if k.startswith("classifier."):
78 | extracted_state[k.replace("classifier.", "")] = state[k]
79 | missing_keys_classifier, unexpected_keys_classifier = classifier.load_state_dict(extracted_state, strict=False)
80 | is_model_training = False
81 | print("Missing keys - Classifier:", missing_keys_classifier)
82 | print("Unexpected keys - Classifier:", unexpected_keys_classifier)
83 | elif args.evaluation_mode == "online_classifier_eval":
84 | extracted_state = {}
85 | for k in state:
86 | if k.startswith("online_eval_classifier."):
87 | extracted_state[k.replace("online_eval_classifier.", "")] = state[k]
88 | missing_keys_classifier, unexpected_keys_classifier = classifier.load_state_dict(extracted_state, strict=False)
89 | is_model_training = False
90 | print("Missing keys - Classifier:", missing_keys_classifier)
91 | print("Unexpected keys - Classifier:", unexpected_keys_classifier)
92 |
93 | print(f"Loaded {ckpt_path}")
94 |
95 | if args.dali:
96 | assert _dali_avaliable, "Dali is not currently avaiable, please install it first."
97 | raise NotImplementedError("Dali is not supported")
98 | MethodClass = types.new_class(
99 | f"Dali{LinearModel.__name__}", (ClassificationABC, LinearModel)
100 | )
101 | # else:
102 | # MethodClass = LinearModel
103 |
104 | model = FullModel(backbone, classifier=classifier, **args.__dict__, tasks=tasks)
105 |
106 | if is_model_training:
107 | train_loader, val_loader = prepare_data(
108 | args.dataset,
109 | data_dir=args.data_dir,
110 | train_dir=args.train_dir,
111 | val_dir=args.val_dir,
112 | batch_size=args.batch_size,
113 | num_workers=args.num_workers,
114 | semi_supervised=args.semi_supervised,
115 | training_data_source=args.linear_classifier_training_data_source,
116 | training_num_tasks=args.num_tasks,
117 | training_tasks=tasks,
118 | training_task_idx=args.task_idx,
119 | training_split_strategy=args.split_strategy,
120 | training_split_seed=args.split_seed,
121 | replay=args.replay,
122 | replay_proportion=args.replay_proportion,
123 | replay_memory_bank_size=args.replay_memory_bank_size
124 | )
125 | else:
126 | _, val_loader = prepare_data(
127 | args.dataset,
128 | data_dir=args.data_dir,
129 | train_dir=args.train_dir,
130 | val_dir=args.val_dir,
131 | batch_size=args.batch_size,
132 | num_workers=args.num_workers,
133 | semi_supervised=args.semi_supervised,
134 | )
135 |
136 | callbacks = []
137 |
138 | # wandb logging
139 | if args.wandb:
140 | wandb_logger = WandbLogger(
141 | name=args.name, project=args.project, entity=args.entity, offline=args.offline
142 | )
143 | wandb_logger.watch(model, log="gradients", log_freq=100)
144 | wandb_logger.log_hyperparams(args)
145 |
146 | # lr logging
147 | lr_monitor = LearningRateMonitor(logging_interval="epoch")
148 | callbacks.append(lr_monitor)
149 |
150 | # save checkpoint on last epoch only
151 | ckpt = Checkpointer(
152 | args,
153 | logdir=os.path.join(args.checkpoint_dir, "linear"),
154 | frequency=args.checkpoint_frequency,
155 | )
156 | callbacks.append(ckpt)
157 |
158 | trainer = Trainer.from_argparse_args(
159 | args,
160 | logger=wandb_logger if args.wandb else None,
161 | callbacks=callbacks,
162 | plugins=DDPPlugin(find_unused_parameters=True),
163 | checkpoint_callback=False,
164 | terminate_on_nan=True,
165 | accelerator="ddp" if torch.cuda.is_available() else "cpu",
166 | )
167 | if is_model_training:
168 | if args.dali:
169 | trainer.fit(model, val_dataloaders=val_loader)
170 | else:
171 | trainer.fit(model, train_loader, val_loader)
172 | else:
173 | trainer.validate(model, val_loader)
174 |
# Invoke main() only when this file is executed as a script, so importing
# the module does not start a run.
if __name__ == "__main__":
    main()
177 |
--------------------------------------------------------------------------------
/kaizen/methods/mocov2plus.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Any, Dict, List, Sequence, Tuple
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from kaizen.losses.moco import moco_loss_func
8 | from kaizen.methods.base import BaseMomentumModel
9 | from kaizen.utils.gather_layer import gather
10 | from kaizen.utils.momentum import initialize_momentum_params
11 |
12 |
class MoCoV2Plus(BaseMomentumModel):
    """MoCo V2+ method: online/momentum projectors plus one key queue per view."""

    queue: torch.Tensor

    def __init__(
        self, output_dim: int, proj_hidden_dim: int, temperature: float, queue_size: int, **kwargs
    ):
        """Implements MoCo V2+ (https://arxiv.org/abs/2011.10566).

        Args:
            output_dim (int): number of dimensions of projected features.
            proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
            temperature (float): temperature for the softmax in the contrastive loss.
            queue_size (int): number of samples to keep in the queue.
        """

        super().__init__(**kwargs)

        self.temperature = temperature
        self.queue_size = queue_size

        # projector
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )

        # momentum projector (same architecture as the online projector)
        self.momentum_projector = nn.Sequential(
            nn.Linear(self.features_dim, proj_hidden_dim),
            nn.ReLU(),
            nn.Linear(proj_hidden_dim, output_dim),
        )
        initialize_momentum_params(self.projector, self.momentum_projector)

        # create the queue: one (output_dim, queue_size) bank per view, filled with
        # random vectors normalized along the feature dimension
        self.register_buffer("queue", torch.randn(2, output_dim, queue_size))
        self.queue = F.normalize(self.queue, dim=1)
        # write position of the next key inside the queue
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Registers the MoCo V2+ arguments on top of the parent's arguments.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the parser returned by the parent, with a
                "mocov2plus" argument group added.
        """

        parent_parser = super(MoCoV2Plus, MoCoV2Plus).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("mocov2plus")

        # projector
        parser.add_argument("--output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=2048)

        # parameters
        parser.add_argument("--temperature", type=float, default=0.1)

        # queue settings
        parser.add_argument("--queue_size", default=65536, type=int)

        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters together with parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """

        extra_learnable_params = [{"params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    @property
    def momentum_pairs(self) -> List[Tuple[Any, Any]]:
        """Adds (projector, momentum_projector) to the parent's momentum pairs.

        Returns:
            List[Tuple[Any, Any]]: list of momentum pairs.
        """

        extra_momentum_pairs = [(self.projector, self.momentum_projector)]
        return super().momentum_pairs + extra_momentum_pairs

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys: torch.Tensor):
        """Adds new samples and removes old samples from the queue in a fifo manner.

        The queue size does not have to be divisible by the batch size: a batch
        that does not fit in the remaining slots wraps around to the start of
        the queue. A batch larger than the whole queue is truncated to its
        last ``queue_size`` keys, since earlier ones would be overwritten
        within the same call anyway.

        Args:
            keys (torch.Tensor): output features of the momentum encoder, indexed
                as (view, sample, feature).
        """

        # a batch larger than the queue used to fail with a shape mismatch in the
        # wrap-around write; only its newest queue_size keys can be stored
        if keys.shape[1] > self.queue_size:
            keys = keys[:, -self.queue_size :]

        batch_size = keys.shape[1]
        ptr = int(self.queue_ptr)  # type: ignore
        # reorder to (view, feature, sample) to match the queue layout
        keys = keys.permute(0, 2, 1)

        end = ptr + batch_size
        if end > self.queue_size:
            # fill the slots left until the end of the queue, then wrap around
            remaining_slots = self.queue_size - ptr
            overflow = batch_size - remaining_slots
            self.queue[:, :, ptr:] = keys[:, :, :remaining_slots]
            self.queue[:, :, :overflow] = keys[:, :, remaining_slots:]
            ptr = overflow
        else:
            # replace the keys at ptr (dequeue and enqueue)
            self.queue[:, :, ptr:end] = keys
            # ptr may end up equal to queue_size; the next call then takes the
            # wrap-around branch with zero remaining slots and restarts at 0
            ptr = end
        self.queue_ptr[0] = ptr  # type: ignore

    def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
        """Performs the forward pass of the online encoder and the online projection.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
        """

        out = super().forward(X, *args, **kwargs)
        q = F.normalize(self.projector(out["feats"]), dim=-1)
        return {**out, "q": q}

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> Dict[str, Any]:
        """
        Training step for MoCo reusing BaseMomentumModel training step.

        Args:
            batch (Sequence[Any]): a batch of data in the
                format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops
                containing batches of images.
            batch_idx (int): index of the batch.

        Returns:
            Dict[str, Any]: the parent's output dict, with "loss" increased by the
                MoCo contrastive loss and "z" set to the two normalized online
                projections.

        """

        out = super().training_step(batch, batch_idx)
        feats1, feats2 = out["feats"]
        momentum_feats1, momentum_feats2 = out["momentum_feats"]

        # online projections (queries)
        q1 = F.normalize(self.projector(feats1), dim=-1)
        q2 = F.normalize(self.projector(feats2), dim=-1)

        # momentum projections (keys) never receive gradients
        with torch.no_grad():
            k1 = F.normalize(self.momentum_projector(momentum_feats1), dim=-1)
            k2 = F.normalize(self.momentum_projector(momentum_feats2), dim=-1)

        # ------- contrastive loss -------
        # symmetric: each query is contrasted against the other view's key and
        # the queue slice that stores that view's past keys
        queue = self.queue.clone().detach()
        nce_loss = (
            moco_loss_func(q1, k2, queue[1], self.temperature)
            + moco_loss_func(q2, k1, queue[0], self.temperature)
        ) / 2

        # ------- update queue -------
        keys = torch.stack((gather(k1), gather(k2)))
        self._dequeue_and_enqueue(keys)

        self.log("train_nce_loss", nce_loss, on_epoch=True, sync_dist=True)

        out.update({"loss": out["loss"] + nce_loss, "z": [q1, q2]})
        return out
178 |
--------------------------------------------------------------------------------
/kaizen/args/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import Namespace
3 | import distutils
4 |
# Number of classes for each dataset name known to this module; the setup
# helpers in this file look up args.dataset here to fill in args.num_classes.
N_CLASSES_PER_DATASET = {
    "cifar10": 10,
    "cifar100": 100,
    "stl10": 10,
    "imagenet": 1000,
    "imagenet100": 100,
    "domainnet": 345,
}
13 |
def strtobool(v):
    """Converts a textual truth value into a bool.

    Accepts the same spellings as ``distutils.util.strtobool`` (compared
    case-insensitively) but is implemented locally: a bare ``import distutils``
    does not load the ``distutils.util`` submodule, and distutils itself is
    removed from the standard library in Python 3.12.

    Args:
        v (str): one of "y", "yes", "t", "true", "on", "1" (true) or
            "n", "no", "f", "false", "off", "0" (false).

    Returns:
        bool: the parsed value.

    Raises:
        ValueError: if v is not one of the accepted spellings.
    """

    value = v.lower()
    if value in ("y", "yes", "t", "true", "on", "1"):
        return True
    if value in ("n", "no", "f", "false", "off", "0"):
        return False
    raise ValueError(f"invalid truth value {value!r}")
16 |
def additional_setup_pretrain(args: Namespace):
    """Provides final setup for pretraining to non-user given parameters by changing args.

    Parses arguments to extract the number of classes of a dataset, create
    transformations kwargs, correctly parse gpus, identify if a cifar dataset
    is being used and adjust the lr. All results are written back onto args
    (num_classes, unique_augs, transform_kwargs, cifar, extra_optimizer_args,
    gpus, lr).

    Args:
        args (Namespace): object that needs to contain, at least:
            - dataset: dataset name.
            - data_dir, train_dir: only read for datasets missing from
                N_CLASSES_PER_DATASET; joined with "/" and scanned for class folders.
            - brightness, contrast, saturation, hue, gaussian_prob,
                solarization_prob, min_scale, size: augmentation settings, each a
                sequence holding either one value or one value per pipeline.
            - num_crops: number of crops.
            - multicrop: flag to use multicrop.
            - dali: flag to use dali.
            - optimizer: optimizer name being used.
            - gpus: list of gpus to use (an int or a comma-separated string is
                converted to a list).
            - lr: learning rate.
            - batch_size: batch size used to scale the learning rate.
            - mean, std: only read when dataset is "custom".
    """

    # every augmentation setting that can be given per pipeline, in the order
    # the keys appear in the generated transform kwargs
    aug_param_names = (
        "brightness",
        "contrast",
        "saturation",
        "hue",
        "gaussian_prob",
        "solarization_prob",
        "min_scale",
        "size",
    )

    args.transform_kwargs = {}

    if args.dataset in N_CLASSES_PER_DATASET:
        args.num_classes = N_CLASSES_PER_DATASET[args.dataset]
    else:
        # hack to maintain the current pipeline
        # even if the custom dataset doesn't have any labels
        dir_path = args.data_dir / args.train_dir
        # is_dir must be *called*: the bare method object is always truthy,
        # which would count regular files as classes too
        with os.scandir(dir_path) as entries:
            n_class_dirs = sum(1 for entry in entries if entry.is_dir())
        args.num_classes = max(1, n_class_dirs)

    unique_augs = max(len(getattr(args, name)) for name in aug_param_names)
    assert unique_augs == args.num_crops or unique_augs == 1

    # assert that either all unique augmentation pipelines have a unique
    # parameter or that a single parameter is replicated to all pipelines
    for name in aug_param_names:
        values = getattr(args, name)
        n = len(values)
        assert n == unique_augs or n == 1

        if n == 1:
            setattr(args, name, values * unique_augs)

    args.unique_augs = unique_augs

    if unique_augs > 1:
        # one kwargs dict per augmentation pipeline
        per_pipeline_values = zip(*(getattr(args, name) for name in aug_param_names))
        args.transform_kwargs = [
            dict(zip(aug_param_names, values)) for values in per_pipeline_values
        ]
    elif not args.multicrop:
        args.transform_kwargs = {name: getattr(args, name)[0] for name in aug_param_names}
    else:
        # with multicrop the kwargs carry neither min_scale nor size
        args.transform_kwargs = {
            name: getattr(args, name)[0]
            for name in aug_param_names
            if name not in ("min_scale", "size")
        }

    # add support for custom mean and std
    if args.dataset == "custom":
        if isinstance(args.transform_kwargs, dict):
            args.transform_kwargs["mean"] = args.mean
            args.transform_kwargs["std"] = args.std
        else:
            for kwargs in args.transform_kwargs:
                kwargs["mean"] = args.mean
                kwargs["std"] = args.std

    if args.dataset in ["cifar10", "cifar100", "stl10"]:
        # pop instead of del: the multicrop kwargs never contain "size", so
        # del raised KeyError for these datasets when multicrop was enabled
        if isinstance(args.transform_kwargs, dict):
            args.transform_kwargs.pop("size", None)
        else:
            for kwargs in args.transform_kwargs:
                kwargs.pop("size", None)

    args.cifar = args.dataset in ["cifar10", "cifar100"]

    if args.dali:
        assert args.dataset in ["imagenet100", "imagenet", "domainnet", "custom"]

    args.extra_optimizer_args = {}
    if args.optimizer == "sgd":
        args.extra_optimizer_args["momentum"] = 0.9

    if isinstance(args.gpus, int):
        args.gpus = [args.gpus]
    elif isinstance(args.gpus, str):
        args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu]

    # adjust lr according to batch size
    args.lr = args.lr * args.batch_size * len(args.gpus) / 256
175 |
176 |
def additional_setup_linear(args: Namespace):
    """Completes the linear-evaluation arguments in place.

    Fills in the number of classes of the dataset, flags whether a cifar dataset
    is being used, prepares the extra optimizer arguments and converts gpus
    into a list of ints.

    Args:
        args: Namespace object that needs to contain, at least:
            - dataset: dataset name; must be a key of N_CLASSES_PER_DATASET.
            - dali: flag to use dali.
            - optimizer: optimizer name being used.
            - gpus: gpus to use, given as an int, a comma-separated string
                or an already parsed value (left untouched).
    """

    dataset_name = args.dataset

    assert dataset_name in N_CLASSES_PER_DATASET
    args.num_classes = N_CLASSES_PER_DATASET[dataset_name]

    args.cifar = dataset_name in ("cifar10", "cifar100")

    # dali is only accepted for this subset of datasets
    if args.dali:
        assert dataset_name in ["imagenet100", "imagenet", "domainnet"]

    # sgd is the only optimizer that receives an extra argument here
    args.extra_optimizer_args = {}
    if args.optimizer == "sgd":
        args.extra_optimizer_args.update(momentum=0.9)

    requested_gpus = args.gpus
    if isinstance(requested_gpus, int):
        args.gpus = [requested_gpus]
    elif isinstance(requested_gpus, str):
        # empty chunks (e.g. from a trailing comma) are dropped
        args.gpus = [int(chunk) for chunk in requested_gpus.split(",") if chunk]
207 |
--------------------------------------------------------------------------------