├── assets ├── method.png └── results.png ├── cassle ├── __init__.py ├── args │ ├── __init__.py │ ├── continual.py │ ├── dataset.py │ └── setup.py ├── losses │ ├── wmse.py │ ├── byol.py │ ├── simsiam.py │ ├── nnclr.py │ ├── __init__.py │ ├── moco.py │ ├── swav.py │ ├── deepclusterv2.py │ ├── ressl.py │ ├── barlow.py │ ├── vicreg.py │ ├── dino.py │ └── simclr.py ├── utils │ ├── __init__.py │ ├── gather_layer.py │ ├── whitening.py │ ├── metrics.py │ ├── trunc_normal.py │ ├── datasets.py │ ├── sinkhorn_knopp.py │ ├── momentum.py │ └── lars.py ├── distillers │ ├── __init__.py │ ├── base.py │ ├── predictive_mse.py │ ├── predictive.py │ ├── contrastive.py │ ├── decorrelative.py │ └── knowledge.py └── methods │ ├── __init__.py │ ├── barlow_twins.py │ ├── vicreg.py │ ├── wmse.py │ └── simsiam.py ├── bash_files ├── linear │ ├── imagenet-100 │ │ ├── class │ │ │ ├── byol_linear.sh │ │ │ ├── barlow_linear.sh │ │ │ ├── simclr_linear.sh │ │ │ ├── supcon_linear.sh │ │ │ ├── swav_linear.sh │ │ │ ├── vicreg_linear.sh │ │ │ ├── simsiam_linear.sh │ │ │ └── mocov2plus_linear.sh │ │ └── data │ │ │ ├── barlow_linear.sh │ │ │ ├── byol_linear.sh │ │ │ ├── swav_linear.sh │ │ │ ├── simclr_linear.sh │ │ │ ├── supcon_linear.sh │ │ │ ├── vicreg_linear.sh │ │ │ ├── simsiam_linear.sh │ │ │ └── mocov2plus_linear.sh │ └── domainnet │ │ └── domain │ │ ├── byol_linear.sh │ │ ├── barlow_linear.sh │ │ ├── swav_linear.sh │ │ ├── vicreg_linear.sh │ │ ├── supcon_linear.sh │ │ ├── simclr_linear.sh │ │ └── mocov2plus_linear.sh └── continual │ ├── cifar │ ├── simclr.sh │ ├── barlow.sh │ ├── swav.sh │ ├── vicreg.sh │ ├── simclr_distill.sh │ ├── byol.sh │ ├── barlow_distill.sh │ ├── swav_distill.sh │ ├── vicreg_distill.sh │ └── byol_distill.sh │ ├── domainnet │ ├── simclr.sh │ ├── supcon.sh │ ├── mocov2plus.sh │ ├── simclr_distill.sh │ ├── barlow.sh │ ├── supcon_distill.sh │ ├── vicreg.sh │ ├── mocov2plus_distill.sh │ ├── swav.sh │ ├── barlow_distill.sh │ ├── byol.sh │ ├── vicreg_distill.sh │ ├── swav_distill.sh │ └── byol_distill.sh │ └── imagenet-100 │ ├── class │ ├── simclr.sh │ ├── supcon.sh │ ├── mocov2plus.sh │ ├── simclr_distill.sh │ ├── barlow.sh │ ├── supcon_distill.sh │ ├── vicreg.sh │ ├── mocov2plus_distill.sh │ ├── swav.sh │ ├── byol.sh │ ├── barlow_distill.sh │ ├── vicreg_distill.sh │ ├── swav_distill.sh │ └── byol_distill.sh │ └── data │ ├── simclr.sh │ ├── supcon.sh │ ├── mocov2plus.sh │ ├── simclr_distill.sh │ ├── barlow.sh │ ├── supcon_distill.sh │ ├── vicreg.sh │ ├── mocov2plus_distill.sh │ ├── swav.sh │ ├── barlow_distill.sh │ ├── byol.sh │ ├── vicreg_distill.sh │ ├── swav_distill.sh │ └── byol_distill.sh ├── LICENSE ├── .gitignore ├── main_continual.py ├── job_launcher.py └── main_linear.py /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/cassle/HEAD/assets/method.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DonkeyShot21/cassle/HEAD/assets/results.png -------------------------------------------------------------------------------- /cassle/__init__.py: -------------------------------------------------------------------------------- 1 | from cassle import args, losses, methods, utils 2 | 3 | __all__ = ["args", "losses", "methods", "utils"] 4 | -------------------------------------------------------------------------------- 
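The tree above shows the layout used throughout this dump: self-supervised objectives under cassle/losses, distillation wrappers under cassle/distillers, launch scripts under bash_files, and the two entry points main_continual.py (task-wise pretraining) and main_linear.py (linear evaluation), while the package __init__ simply re-exports the four subpackages. As a quick orientation, here is a small smoke-test sketch, not part of the repository, that calls two of the loss functions defined later in this dump on random features; it assumes torch is installed and the repository root is on PYTHONPATH.

# Illustrative smoke test, not repository code: exercise two of the loss
# functions exported by cassle/losses on random feature batches.
import torch

from cassle.losses import byol_loss_func, simsiam_loss_func

p = torch.randn(256, 128)  # NxD predicted features from view 1
z = torch.randn(256, 128)  # NxD projected features from view 2

print(byol_loss_func(p, z))     # roughly 2 for uncorrelated random features
print(simsiam_loss_func(p, z))  # roughly 0 for uncorrelated random features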
/cassle/args/__init__.py: -------------------------------------------------------------------------------- 1 | from cassle.args import dataset, setup, utils, continual 2 | 3 | __all__ = ["dataset", "setup", "utils", "continual"] 4 | -------------------------------------------------------------------------------- /cassle/args/continual.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def continual_args(parser: ArgumentParser): 5 | """Adds continual learning arguments to a parser. 6 | 7 | Args: 8 | parser (ArgumentParser): parser to add dataset args to. 9 | """ 10 | # base continual learning args 11 | parser.add_argument("--num_tasks", type=int, default=2) 12 | parser.add_argument("--task_idx", type=int, required=True) 13 | 14 | SPLIT_STRATEGIES = ["class", "data", "domain"] 15 | parser.add_argument("--split_strategy", choices=SPLIT_STRATEGIES, type=str, required=True) 16 | 17 | # distillation args 18 | parser.add_argument("--distiller", type=str, default=None) 19 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/byol_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 3.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 8 \ 19 | --dali \ 20 | --name byol-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/barlow_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.1 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 8 \ 19 | --dali \ 20 | --name barlow-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/simclr_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 1.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name 
simclr-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/supcon_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 1.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name supcon-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/swav_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.15 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name swav-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/vicreg_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.3 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name vicreg-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/barlow_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.1 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 8 \ 19 | --dali \ 20 | --name 
barlow-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/byol_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 3.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 8 \ 19 | --dali \ 20 | --name byol-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/swav_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.15 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name swav-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/simsiam_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 30.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 10 \ 19 | --dali \ 20 | --name simsiam-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/simclr_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 1.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name 
simclr-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/supcon_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 1.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name supcon-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/vicreg_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 0.3 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 7 \ 19 | --dali \ 20 | --name vicreg-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/class/mocov2plus_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 3.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 10 \ 19 | --dali \ 20 | --name mocov2plus-imagenet100-5T-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/simsiam_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 30.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 10 \ 19 | --dali 
\ 20 | --name simsiam-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /bash_files/linear/imagenet-100/data/mocov2plus_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --num_tasks 5 \ 9 | --max_epochs 100 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler step \ 14 | --lr 3.0 \ 15 | --lr_decay_steps 60 80 \ 16 | --weight_decay 0 \ 17 | --batch_size 256 \ 18 | --num_workers 10 \ 19 | --dali \ 20 | --name mocov2plus-imagenet100-5T_data-linear-eval \ 21 | --pretrained_feature_extractor $PRETRAINED_PATH \ 22 | --project ever-learn \ 23 | --entity unitn-mhug \ 24 | --wandb \ 25 | --save_checkpoint 26 | -------------------------------------------------------------------------------- /cassle/losses/wmse.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def wmse_loss_func(z1: torch.Tensor, z2: torch.Tensor, simplified: bool = True) -> torch.Tensor: 6 | """Computes W-MSE's loss given two batches of whitened features z1 and z2. 7 | 8 | Args: 9 | z1 (torch.Tensor): NxD Tensor containing whitened features from view 1. 10 | z2 (torch.Tensor): NxD Tensor containing whitened features from view 2. 11 | simplified (bool): faster computation, but with same result. 12 | 13 | Returns: 14 | torch.Tensor: W-MSE loss. 15 | """ 16 | 17 | if simplified: 18 | return 2 - 2 * F.cosine_similarity(z1, z2.detach(), dim=-1).mean() 19 | else: 20 | z1 = F.normalize(z1, dim=-1) 21 | z2 = F.normalize(z2, dim=-1) 22 | 23 | return 2 - 2 * (z1 * z2).sum(dim=-1).mean() 24 | -------------------------------------------------------------------------------- /cassle/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from cassle.utils import ( 2 | checkpointer, 3 | classification_dataloader, 4 | datasets, 5 | gather_layer, 6 | knn, 7 | lars, 8 | metrics, 9 | momentum, 10 | pretrain_dataloader, 11 | sinkhorn_knopp, 12 | ) 13 | 14 | __all__ = [ 15 | "classification_dataloader", 16 | "pretrain_dataloader", 17 | "checkpointer", 18 | "datasets", 19 | "gather_layer", 20 | "knn", 21 | "lars", 22 | "metrics", 23 | "momentum", 24 | "sinkhorn_knopp", 25 | ] 26 | 27 | try: 28 | from cassle.utils import dali_dataloader # noqa: F401 29 | except ImportError: 30 | pass 31 | else: 32 | __all__.append("dali_dataloader") 33 | 34 | try: 35 | from cassle.utils import auto_umap # noqa: F401 36 | except ImportError: 37 | pass 38 | else: 39 | __all__.append("auto_umap") 40 | -------------------------------------------------------------------------------- /cassle/losses/byol.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor: 6 | """Computes BYOL's loss given batch of predicted features p and projected momentum features z. 
7 | 8 | Args: 9 | p (torch.Tensor): NxD Tensor containing predicted features from view 1 10 | z (torch.Tensor): NxD Tensor containing projected momentum features from view 2 11 | simplified (bool): faster computation, but with same result. Defaults to True. 12 | 13 | Returns: 14 | torch.Tensor: BYOL's loss. 15 | """ 16 | 17 | if simplified: 18 | return 2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1).mean() 19 | else: 20 | p = F.normalize(p, dim=-1) 21 | z = F.normalize(z, dim=-1) 22 | 23 | return 2 - 2 * (p * z.detach()).sum(dim=1).mean() 24 | -------------------------------------------------------------------------------- /cassle/losses/simsiam.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def simsiam_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor: 6 | """Computes SimSiam's loss given batch of predicted features p from view 1 and 7 | a batch of projected features z from view 2. 8 | 9 | Args: 10 | p (torch.Tensor): Tensor containing predicted features from view 1. 11 | z (torch.Tensor): Tensor containing projected features from view 2. 12 | simplified (bool): faster computation, but with same result. 13 | 14 | Returns: 15 | torch.Tensor: SimSiam loss. 16 | """ 17 | 18 | if simplified: 19 | return -F.cosine_similarity(p, z.detach(), dim=-1).mean() 20 | else: 21 | p = F.normalize(p, dim=-1) 22 | z = F.normalize(z, dim=-1) 23 | 24 | return -(p * z.detach()).sum(dim=1).mean() 25 | -------------------------------------------------------------------------------- /cassle/losses/nnclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def nnclr_loss_func(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor: 6 | """Computes NNCLR's loss given batch of nearest-neighbors nn from view 1 and 7 | predicted features p from view 2. 8 | 9 | Args: 10 | nn (torch.Tensor): NxD Tensor containing nearest neighbors' features from view 1. 11 | p (torch.Tensor): NxD Tensor containing predicted features from view 2 12 | temperature (float, optional): temperature of the softmax in the contrastive loss. Defaults 13 | to 0.1. 14 | 15 | Returns: 16 | torch.Tensor: NNCLR loss. 
17 | """ 18 | 19 | nn = F.normalize(nn, dim=-1) 20 | p = F.normalize(p, dim=-1) 21 | 22 | logits = nn @ p.T / temperature 23 | 24 | n = p.size(0) 25 | labels = torch.arange(n, device=p.device) 26 | 27 | loss = F.cross_entropy(logits, labels) 28 | return loss 29 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/simclr.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --task_idx 0 \ 7 | --max_epochs 500 \ 8 | --num_tasks 2 \ 9 | --max_epochs 500 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --lars \ 14 | --grad_clip_lars \ 15 | --eta_lars 0.02 \ 16 | --exclude_bias_n_norm \ 17 | --scheduler warmup_cosine \ 18 | --lr 0.4 \ 19 | --classifier_lr 0.1 \ 20 | --weight_decay 1e-5 \ 21 | --batch_size 256 \ 22 | --num_workers 5 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --gaussian_prob 0.0 0.0 \ 28 | --name simclr-cifar100 \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method simclr \ 34 | --temperature 0.2 \ 35 | --proj_hidden_dim 2048 \ 36 | --output_dim 256 37 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 5 \ 8 | --task_idx 0 \ 9 | --gpus 0 \ 10 | --num_workers 4 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --lars \ 14 | --grad_clip_lars \ 15 | --eta_lars 0.02 \ 16 | --exclude_bias_n_norm \ 17 | --scheduler warmup_cosine \ 18 | --lr 0.3 \ 19 | --classifier_lr 0.1 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 256 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --name barlow-cifar100 \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method barlow_twins \ 34 | --proj_hidden_dim 2048 \ 35 | --output_dim 2048 \ 36 | --scale_loss 0.1 -------------------------------------------------------------------------------- /bash_files/continual/cifar/swav.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 2 \ 8 | --task_idx 0 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --scheduler warmup_cosine \ 16 | --lr 0.6 \ 17 | --min_lr 0.0006 \ 18 | --classifier_lr 0.1 \ 19 | --weight_decay 1e-6 \ 20 | --batch_size 256 \ 21 | --num_workers 3 \ 22 | --brightness 0.8 \ 23 | --contrast 0.8 \ 24 | --saturation 0.8 \ 25 | --hue 0.2 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --name swav-cifar100 \ 28 | --project ever-learn \ 29 | --entity unitn-mhug \ 30 | --wandb \ 31 | --method swav \ 32 | --proj_hidden_dim 2048 \ 33 | --queue_size 3840 \ 34 | --output_dim 128 \ 35 | --num_prototypes 3000 \ 36 | --epoch_queue_starts 50 \ 37 | --freeze_prototypes_epochs 2 
-------------------------------------------------------------------------------- /cassle/distillers/__init__.py: -------------------------------------------------------------------------------- 1 | from cassle.distillers.base import base_distill_wrapper 2 | from cassle.distillers.contrastive import contrastive_distill_wrapper 3 | from cassle.distillers.decorrelative import decorrelative_distill_wrapper 4 | from cassle.distillers.knowledge import knowledge_distill_wrapper 5 | from cassle.distillers.predictive import predictive_distill_wrapper 6 | from cassle.distillers.predictive_mse import predictive_mse_distill_wrapper 7 | 8 | 9 | __all__ = [ 10 | "base_distill_wrapper", 11 | "contrastive_distill_wrapper", 12 | "decorrelative_distill_wrapper", 13 | "knowledge_distill_wrapper", 14 | "predictive_distill_wrapper", 15 | "predictive_mse_distill_wrapper", 16 | ] 17 | 18 | DISTILLERS = { 19 | "base": base_distill_wrapper, 20 | "contrastive": contrastive_distill_wrapper, 21 | "decorrelative": decorrelative_distill_wrapper, 22 | "knowledge": knowledge_distill_wrapper, 23 | "predictive": predictive_distill_wrapper, 24 | "predictive_mse": predictive_mse_distill_wrapper, 25 | } 26 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/simclr.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --dali \ 28 | --name simclr-domainnet \ 29 | --wandb \ 30 | --save_checkpoint \ 31 | --entity unitn-mhug \ 32 | --project ever-learn \ 33 | --method simclr \ 34 | --temperature 0.2 \ 35 | --proj_hidden_dim 2048 \ 36 | --check_val_every_n_epoch 9999 \ 37 | --disable_knn_eval 38 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/vicreg.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --task_idx 1 \ 7 | --num_tasks 5 \ 8 | --max_epochs 500 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --exclude_bias_n_norm \ 16 | --scheduler warmup_cosine \ 17 | --lr 0.3 \ 18 | --weight_decay 1e-4 \ 19 | --batch_size 256 \ 20 | --num_workers 3 \ 21 | --min_scale 0.2 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --solarization_prob 0.1 \ 27 | --gaussian_prob 0.0 0.0 \ 28 | --name vicreg-cifar100-5T \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method vicreg \ 34 | --proj_hidden_dim 2048 \ 35 | --output_dim 2048 \ 36 | --sim_loss_weight 25.0 \ 37 | --var_loss_weight 25.0 \ 38 | --cov_loss_weight 1.0 --------------------------------------------------------------------------------
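The DISTILLERS registry defined in cassle/distillers/__init__.py above maps the names accepted by the --distiller flag (added in cassle/args/continual.py, default None) to one wrapper per distillation strategy, which the *_distill.sh scripts select at launch time. Below is a minimal sketch of how such a registry is typically consumed; the helper build_method_class is illustrative rather than cassle code, and it assumes, as the *_distill_wrapper naming suggests, that each entry takes a method class and returns a distillation-augmented subclass.

# Illustrative sketch, not part of the repository. Assumes the cassle package
# is importable and that each *_distill_wrapper maps a method class to a
# wrapped subclass.
from cassle.distillers import DISTILLERS


def build_method_class(base_class, distiller_name=None):
    """Return base_class unchanged, or wrapped by the selected distiller."""
    if distiller_name is None:  # mirrors the --distiller default in args/continual.py
        return base_class
    return DISTILLERS[distiller_name](base_class)

# For example, the simclr_distill.sh runs pass --distiller contrastive, which
# here would select contrastive_distill_wrapper.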
/bash_files/continual/domainnet/supcon.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --dali \ 28 | --name supcon-domainnet \ 29 | --entity unitn-mhug \ 30 | --project ever-learn \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method simclr \ 34 | --temperature 0.1 \ 35 | --proj_hidden_dim 2048 \ 36 | --supervised \ 37 | --check_val_every_n_epoch 9999 \ 38 | --disable_knn_eval 39 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/simclr.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name simclr-imagenet100-5T \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --method simclr \ 37 | --temperature 0.2 \ 38 | --proj_hidden_dim 2048 39 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/simclr_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --task_idx 1 \ 7 | --max_epochs 500 \ 8 | --num_tasks 5 \ 9 | --max_epochs 500 \ 10 | --gpus 0 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --lars \ 14 | --grad_clip_lars \ 15 | --eta_lars 0.02 \ 16 | --exclude_bias_n_norm \ 17 | --scheduler warmup_cosine \ 18 | --lr 0.4 \ 19 | --classifier_lr 0.1 \ 20 | --weight_decay 1e-5 \ 21 | --batch_size 256 \ 22 | --num_workers 5 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --gaussian_prob 0.0 0.0 \ 28 | --name simclr-cifar100-contrastive \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method simclr \ 34 | --temperature 0.2 \ 35 | --proj_hidden_dim 2048 \ 36 | --output_dim 256 \ 37 | --distiller contrastive \ 38 | --pretrained_model $PRETRAINED_PATH 39 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/simclr.sh: 
-------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name simclr-imagenet100-5T_data \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --method simclr \ 37 | --temperature 0.2 \ 38 | --proj_hidden_dim 2048 39 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/byol.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 2 \ 8 | --task_idx 0 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --exclude_bias_n_norm \ 16 | --scheduler warmup_cosine \ 17 | --lr 1.0 \ 18 | --classifier_lr 0.1 \ 19 | --weight_decay 1e-5 \ 20 | --batch_size 256 \ 21 | --num_workers 5 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --name byol-cifar100 \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method byol \ 34 | --output_dim 256 \ 35 | --proj_hidden_dim 4096 \ 36 | --pred_hidden_dim 4096 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 1.0 \ 39 | --momentum_classifier -------------------------------------------------------------------------------- /bash_files/continual/domainnet/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --scheduler cosine \ 16 | --lr 0.4 \ 17 | --classifier_lr 0.3 \ 18 | --weight_decay 1e-4 \ 19 | --batch_size 64 \ 20 | --brightness 0.4 \ 21 | --contrast 0.4 \ 22 | --saturation 0.4 \ 23 | --hue 0.1 \ 24 | --dali \ 25 | --name mocov2plus-domainnet \ 26 | --project ever-learn \ 27 | --entity unitn-mhug \ 28 | --wandb \ 29 | --save_checkpoint \ 30 | --method mocov2plus \ 31 | --proj_hidden_dim 2048 \ 32 | --queue_size 65536 \ 33 | --temperature 0.2 \ 34 | --base_tau_momentum 0.99 \ 35 | --final_tau_momentum 0.999 \ 36 | --momentum_classifier \ 37 | --check_val_every_n_epoch 9999 \ 38 | --disable_knn_eval 39 | -------------------------------------------------------------------------------- /cassle/losses/__init__.py: 
-------------------------------------------------------------------------------- 1 | from cassle.losses.barlow import barlow_loss_func 2 | from cassle.losses.byol import byol_loss_func 3 | from cassle.losses.deepclusterv2 import deepclusterv2_loss_func 4 | from cassle.losses.dino import DINOLoss 5 | from cassle.losses.moco import moco_loss_func 6 | from cassle.losses.nnclr import nnclr_loss_func 7 | from cassle.losses.ressl import ressl_loss_func 8 | from cassle.losses.simclr import manual_simclr_loss_func, simclr_loss_func, simclr_distill_loss_func 9 | from cassle.losses.simsiam import simsiam_loss_func 10 | from cassle.losses.swav import swav_loss_func 11 | from cassle.losses.vicreg import vicreg_loss_func 12 | from cassle.losses.wmse import wmse_loss_func 13 | 14 | __all__ = [ 15 | "barlow_loss_func", 16 | "byol_loss_func", 17 | "deepclusterv2_loss_func", 18 | "DINOLoss", 19 | "moco_loss_func", 20 | "nnclr_loss_func", 21 | "ressl_loss_func", 22 | "simclr_loss_func", 23 | "manual_simclr_loss_func", 24 | "simclr_distill_loss_func", 25 | "simsiam_loss_func", 26 | "swav_loss_func", 27 | "vicreg_loss_func", 28 | "wmse_loss_func", 29 | ] 30 | -------------------------------------------------------------------------------- /cassle/utils/gather_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class GatherLayer(torch.autograd.Function): 6 | """Gathers tensors from all processes, supporting backward propagation.""" 7 | 8 | @staticmethod 9 | def forward(ctx, input): 10 | ctx.save_for_backward(input) 11 | if dist.is_available() and dist.is_initialized(): 12 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 13 | dist.all_gather(output, input) 14 | else: 15 | output = [input] 16 | return tuple(output) 17 | 18 | @staticmethod 19 | def backward(ctx, *grads): 20 | (input,) = ctx.saved_tensors 21 | if dist.is_available() and dist.is_initialized(): 22 | grad_out = torch.zeros_like(input) 23 | grad_out[:] = grads[dist.get_rank()] 24 | else: 25 | grad_out = grads[0] 26 | return grad_out 27 | 28 | 29 | def gather(X, dim=0): 30 | """Gathers tensors from all processes, supporting backward propagation.""" 31 | return torch.cat(GatherLayer.apply(X), dim=dim) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/barlow_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 5 \ 8 | --task_idx 1 \ 9 | --gpus 0 \ 10 | --num_workers 4 \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --lars \ 14 | --grad_clip_lars \ 15 | --eta_lars 0.02 \ 16 | --exclude_bias_n_norm \ 17 | --scheduler warmup_cosine \ 18 | --lr 0.3 \ 19 | --classifier_lr 0.1 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 256 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --name barlow-cifar100-decorrelative \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method barlow_twins \ 34 | --proj_hidden_dim 2048 \ 35 | --output_dim 2048 \ 36 | --scale_loss 0.1 \ 37 | --distiller decorrelative \ 38 | --pretrained_model $PRETRAINED_PATH 39 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/supcon.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name supcon-imagenet100-5T \ 32 | --entity unitn-mhug \ 33 | --project ever-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method simclr \ 37 | --temperature 0.1 \ 38 | --proj_hidden_dim 2048 \ 39 | --supervised 40 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/supcon.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 
29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name supcon-imagenet100-5T_data \ 32 | --entity unitn-mhug \ 33 | --project ever-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method simclr \ 37 | --temperature 0.1 \ 38 | --proj_hidden_dim 2048 \ 39 | --supervised 40 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --scheduler cosine \ 18 | --lr 0.4 \ 19 | --classifier_lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.4 \ 25 | --hue 0.1 \ 26 | --dali \ 27 | --check_val_every_n_epoch 9999 \ 28 | --name mocov2plus-imagenet100-5T \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method mocov2plus \ 34 | --proj_hidden_dim 2048 \ 35 | --queue_size 65536 \ 36 | --temperature 0.2 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 0.999 39 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/swav_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 5 \ 8 | --task_idx 1 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --scheduler warmup_cosine \ 16 | --lr 0.6 \ 17 | --min_lr 0.0006 \ 18 | --classifier_lr 0.1 \ 19 | --weight_decay 1e-6 \ 20 | --batch_size 256 \ 21 | --num_workers 3 \ 22 | --brightness 0.8 \ 23 | --contrast 0.8 \ 24 | --saturation 0.8 \ 25 | --hue 0.2 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --name swav-cifar100-knowledge \ 28 | --project ever-learn \ 29 | --entity unitn-mhug \ 30 | --wandb \ 31 | --method swav \ 32 | --proj_hidden_dim 2048 \ 33 | --queue_size 3840 \ 34 | --output_dim 128 \ 35 | --num_prototypes 3000 \ 36 | --epoch_queue_starts 50 \ 37 | --freeze_prototypes_epochs 2 \ 38 | --distiller knowledge \ 39 | --pretrained_model $PRETRAINED_PATH 40 | -------------------------------------------------------------------------------- /cassle/losses/moco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def moco_loss_func( 6 | query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1 7 | ) -> torch.Tensor: 8 | """Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2 and a 9 | queue of past elements. 10 | 11 | Args: 12 | query (torch.Tensor): NxD Tensor containing the queries from view 1. 13 | key (torch.Tensor): NxD Tensor containing the queries from view 2. 14 | queue (torch.Tensor): a queue of negative samples for the contrastive loss. 15 | temperature (float, optional): [description]. temperature of the softmax in the contrastive 16 | loss. 
Defaults to 0.1. 17 | 18 | Returns: 19 | torch.Tensor: MoCo loss. 20 | """ 21 | 22 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1) 23 | neg = torch.einsum("nc,ck->nk", [query, queue]) 24 | logits = torch.cat([pos, neg], dim=1) 25 | logits /= temperature 26 | targets = torch.zeros(query.size(0), device=query.device, dtype=torch.long) 27 | return F.cross_entropy(logits, targets) 28 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/simclr_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --brightness 0.8 \ 24 | --contrast 0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --dali \ 28 | --name simclr-domainnet-contrastive \ 29 | --wandb \ 30 | --save_checkpoint \ 31 | --entity unitn-mhug \ 32 | --project ever-learn \ 33 | --method simclr \ 34 | --temperature 0.2 \ 35 | --proj_hidden_dim 2048 \ 36 | --check_val_every_n_epoch 9999 \ 37 | --disable_knn_eval \ 38 | --distiller contrastive \ 39 | --pretrained_model $PRETRAINED_PATH 40 | -------------------------------------------------------------------------------- /cassle/losses/swav.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def swav_loss_func( 8 | preds: List[torch.Tensor], assignments: List[torch.Tensor], temperature: float = 0.1 9 | ) -> torch.Tensor: 10 | """Computes SwAV's loss given list of batch predictions from multiple views 11 | and a list of cluster assignments from the same multiple views. 12 | 13 | Args: 14 | preds (torch.Tensor): list of NxC Tensors containing nearest neighbors' features from 15 | view 1. 16 | assignments (torch.Tensor): list of NxC Tensor containing predicted features from view 2. 17 | temperature (torch.Tensor): softmax temperature for the loss. Defaults to 0.1. 18 | 19 | Returns: 20 | torch.Tensor: SwAV loss. 
21 | """ 22 | 23 | losses = [] 24 | for v1 in range(len(preds)): 25 | for v2 in np.delete(np.arange(len(preds)), v1): 26 | a = assignments[v1] 27 | p = preds[v2] / temperature 28 | loss = -torch.mean(torch.sum(a * torch.log_softmax(p, dim=1), dim=1)) 29 | losses.append(loss) 30 | return sum(losses) / len(losses) 31 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --brightness 0.4 \ 24 | --contrast 0.4 \ 25 | --saturation 0.2 \ 26 | --hue 0.1 \ 27 | --gaussian_prob 1.0 0.1 \ 28 | --solarization_prob 0.0 0.2 \ 29 | --dali \ 30 | --name barlow-domainnet \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --entity unitn-mhug \ 34 | --project ever-learn \ 35 | --scale_loss 0.1 \ 36 | --method barlow_twins \ 37 | --proj_hidden_dim 2048 \ 38 | --output_dim 2048 \ 39 | --check_val_every_n_epoch 9999 \ 40 | --disable_knn_eval 41 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/mocov2plus.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --scheduler cosine \ 18 | --lr 0.4 \ 19 | --classifier_lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.4 \ 25 | --hue 0.1 \ 26 | --dali \ 27 | --check_val_every_n_epoch 9999 \ 28 | --name mocov2plus-imagenet100-5T_data \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method mocov2plus \ 34 | --proj_hidden_dim 2048 \ 35 | --queue_size 65536 \ 36 | --temperature 0.2 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 0.999 \ 39 | --momentum_classifier 40 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/supcon_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --brightness 0.8 \ 24 | --contrast 
0.8 \ 25 | --saturation 0.8 \ 26 | --hue 0.2 \ 27 | --dali \ 28 | --name supcon-domainnet-contrastive \ 29 | --entity unitn-mhug \ 30 | --project ever-learn \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method simclr \ 34 | --temperature 0.1 \ 35 | --proj_hidden_dim 2048 \ 36 | --supervised \ 37 | --check_val_every_n_epoch 9999 \ 38 | --disable_knn_eval \ 39 | --distiller contrastive \ 40 | --pretrained_model $PRETRAINED_PATH 41 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/vicreg_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --task_idx 1 \ 7 | --num_tasks 5 \ 8 | --max_epochs 500 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | --grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --exclude_bias_n_norm \ 16 | --scheduler warmup_cosine \ 17 | --lr 0.3 \ 18 | --weight_decay 1e-4 \ 19 | --batch_size 256 \ 20 | --num_workers 3 \ 21 | --min_scale 0.2 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --solarization_prob 0.1 \ 27 | --gaussian_prob 0.0 0.0 \ 28 | --name vicreg-cifar100-5T-predictive_mse \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method vicreg \ 34 | --proj_hidden_dim 2048 \ 35 | --output_dim 2048 \ 36 | --sim_loss_weight 25.0 \ 37 | --var_loss_weight 25.0 \ 38 | --cov_loss_weight 1.0 \ 39 | --distiller predictive_mse \ 40 | --pretrained_model $PRETRAINED_PATH -------------------------------------------------------------------------------- /cassle/losses/deepclusterv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def deepclusterv2_loss_func( 6 | outputs: torch.Tensor, assignments: torch.Tensor, temperature: float = 0.1 7 | ) -> torch.Tensor: 8 | """Computes DeepClusterV2's loss given a tensor containing logits from multiple views 9 | and a tensor containing cluster assignments from the same multiple views. 10 | 11 | Args: 12 | outputs (torch.Tensor): tensor of size PxVxNxC where P is the number of prototype 13 | layers and V is the number of views. 14 | assignments (torch.Tensor): tensor of size PxVxNxC containing the assignments 15 | generated using k-means. 16 | temperature (float, optional): softmax temperature for the loss. Defaults to 0.1. 17 | 18 | Returns: 19 | torch.Tensor: DeepClusterV2 loss. 
20 | """ 21 | loss = 0 22 | for h in range(outputs.size(0)): 23 | scores = outputs[h].view(-1, outputs.size(-1)) / temperature 24 | targets = assignments[h].repeat(outputs.size(1)).to(outputs.device, non_blocking=True) 25 | loss += F.cross_entropy(scores, targets, ignore_index=-1) 26 | return loss / outputs.size(0) 27 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/simclr_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name simclr-imagenet100-5T-contrastive \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --method simclr \ 37 | --temperature 0.2 \ 38 | --proj_hidden_dim 2048 \ 39 | --distiller contrastive \ 40 | --pretrained_model $PRETRAINED_PATH 41 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/simclr_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name simclr-imagenet100-5T_data-contrastive \ 32 | --wandb \ 33 | --save_checkpoint \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --method simclr \ 37 | --temperature 0.2 \ 38 | --proj_hidden_dim 2048 \ 39 | --distiller contrastive \ 40 | --pretrained_model $PRETRAINED_PATH 41 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/vicreg.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 
1e-4 \ 22 | --batch_size 64 \ 23 | --min_scale 0.2 \ 24 | --brightness 0.4 \ 25 | --contrast 0.4 \ 26 | --saturation 0.2 \ 27 | --hue 0.1 \ 28 | --solarization_prob 0.1 \ 29 | --dali \ 30 | --name vicreg-domainnet \ 31 | --entity unitn-mhug \ 32 | --project ever-learn \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --method vicreg \ 36 | --proj_hidden_dim 2048 \ 37 | --output_dim 2048 \ 38 | --sim_loss_weight 25.0 \ 39 | --var_loss_weight 25.0 \ 40 | --cov_loss_weight 1.0 \ 41 | --check_val_every_n_epoch 9999 \ 42 | --disable_knn_eval 43 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 1.0 0.1 \ 30 | --solarization_prob 0.0 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name barlow-imagenet100-5T \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --entity unitn-mhug \ 37 | --project ever-learn \ 38 | --scale_loss 0.1 \ 39 | --method barlow_twins \ 40 | --proj_hidden_dim 2048 \ 41 | --output_dim 2048 42 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/barlow.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 1.0 0.1 \ 30 | --solarization_prob 0.0 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name barlow-imagenet100-5T_data \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --entity unitn-mhug \ 37 | --project ever-learn \ 38 | --scale_loss 0.1 \ 39 | --method barlow_twins \ 40 | --proj_hidden_dim 2048 \ 41 | --output_dim 2048 42 | -------------------------------------------------------------------------------- /bash_files/continual/cifar/byol_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset cifar100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --split_strategy class \ 6 | --max_epochs 500 \ 7 | --num_tasks 5 \ 8 | --task_idx 1 \ 9 | --gpus 0 \ 10 | --precision 16 \ 11 | --optimizer sgd \ 12 | --lars \ 13 | 
--grad_clip_lars \ 14 | --eta_lars 0.02 \ 15 | --exclude_bias_n_norm \ 16 | --scheduler warmup_cosine \ 17 | --lr 1.0 \ 18 | --classifier_lr 0.1 \ 19 | --weight_decay 1e-5 \ 20 | --batch_size 256 \ 21 | --num_workers 5 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.2 \ 25 | --hue 0.1 \ 26 | --gaussian_prob 0.0 0.0 \ 27 | --solarization_prob 0.0 0.2 \ 28 | --name byol-cifar100-predictive \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method byol \ 34 | --output_dim 256 \ 35 | --proj_hidden_dim 4096 \ 36 | --pred_hidden_dim 4096 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 1.0 \ 39 | --momentum_classifier \ 40 | --distiller predictive \ 41 | --pretrained_model $PRETRAINED_PATH 42 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/mocov2plus_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --scheduler cosine \ 16 | --lr 0.4 \ 17 | --classifier_lr 0.3 \ 18 | --weight_decay 1e-4 \ 19 | --batch_size 64 \ 20 | --brightness 0.4 \ 21 | --contrast 0.4 \ 22 | --saturation 0.4 \ 23 | --hue 0.1 \ 24 | --dali \ 25 | --name mocov2plus-domainnet-contrastive \ 26 | --project ever-learn \ 27 | --entity unitn-mhug \ 28 | --wandb \ 29 | --save_checkpoint \ 30 | --method mocov2plus \ 31 | --proj_hidden_dim 2048 \ 32 | --queue_size 65536 \ 33 | --temperature 0.2 \ 34 | --base_tau_momentum 0.99 \ 35 | --final_tau_momentum 0.999 \ 36 | --momentum_classifier \ 37 | --check_val_every_n_epoch 9999 \ 38 | --disable_knn_eval \ 39 | --distiller contrastive \ 40 | --pretrained_model $PRETRAINED_PATH 41 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/supcon_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name supcon-imagenet100-5T-contrastive \ 32 | --entity unitn-mhug \ 33 | --project ever-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method simclr \ 37 | --temperature 0.1 \ 38 | --proj_hidden_dim 2048 \ 39 | --supervised \ 40 | --distiller contrastive \ 41 | --pretrained_model $PRETRAINED_PATH 42 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/supcon_distill.sh: -------------------------------------------------------------------------------- 
1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --check_val_every_n_epoch 9999 \ 31 | --name supcon-imagenet100-5T_data-contrastive \ 32 | --entity unitn-mhug \ 33 | --project ever-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method simclr \ 37 | --temperature 0.1 \ 38 | --proj_hidden_dim 2048 \ 39 | --supervised \ 40 | --distiller contrastive \ 41 | --pretrained_model $PRETRAINED_PATH 42 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/swav.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.8 \ 21 | --min_lr 0.0006 \ 22 | --classifier_lr 0.1 \ 23 | --weight_decay 1e-6 \ 24 | --batch_size 64 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --name swav-domainnet \ 31 | --entity unitn-mhug \ 32 | --project ever-learn \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --method swav \ 36 | --proj_hidden_dim 2048 \ 37 | --queue_size 3840 \ 38 | --output_dim 128 \ 39 | --num_prototypes 3000 \ 40 | --epoch_queue_starts 50 \ 41 | --freeze_prototypes_epochs 3 \ 42 | --check_val_every_n_epoch 9999 \ 43 | --disable_knn_eval 44 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/vicreg.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --min_scale 0.2 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --solarization_prob 0.1 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name vicreg-imagenet100-5T \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method vicreg \ 39 | --proj_hidden_dim 2048 \ 40 | --output_dim 2048 \ 
41 | --sim_loss_weight 25.0 \ 42 | --var_loss_weight 25.0 \ 43 | --cov_loss_weight 1.0 44 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/vicreg.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --min_scale 0.2 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --solarization_prob 0.1 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name vicreg-imagenet100-5T_data \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method vicreg \ 39 | --proj_hidden_dim 2048 \ 40 | --output_dim 2048 \ 41 | --sim_loss_weight 25.0 \ 42 | --var_loss_weight 25.0 \ 43 | --cov_loss_weight 1.0 44 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/mocov2plus_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --scheduler cosine \ 18 | --lr 0.4 \ 19 | --classifier_lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.4 \ 25 | --hue 0.1 \ 26 | --dali \ 27 | --check_val_every_n_epoch 9999 \ 28 | --name mocov2plus-imagenet100-5T-contrastive \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method mocov2plus \ 34 | --proj_hidden_dim 2048 \ 35 | --queue_size 65536 \ 36 | --temperature 0.2 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 0.999 \ 39 | --momentum_classifier \ 40 | --distiller contrastive \ 41 | --pretrained_model $PRETRAINED_PATH 42 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/mocov2plus_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --scheduler cosine \ 18 | --lr 0.4 \ 19 | --classifier_lr 0.3 \ 20 | --weight_decay 1e-4 \ 21 | --batch_size 128 \ 22 | --brightness 0.4 \ 23 | --contrast 0.4 \ 24 | --saturation 0.4 \ 25 | 
--hue 0.1 \ 26 | --dali \ 27 | --check_val_every_n_epoch 9999 \ 28 | --name mocov2plus-imagenet100-5T_data-contrastive \ 29 | --project ever-learn \ 30 | --entity unitn-mhug \ 31 | --wandb \ 32 | --save_checkpoint \ 33 | --method mocov2plus \ 34 | --proj_hidden_dim 2048 \ 35 | --queue_size 65536 \ 36 | --temperature 0.2 \ 37 | --base_tau_momentum 0.99 \ 38 | --final_tau_momentum 0.999 \ 39 | --momentum_classifier \ 40 | --distiller contrastive \ 41 | --pretrained_model $PRETRAINED_PATH 42 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/barlow_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --brightness 0.4 \ 24 | --contrast 0.4 \ 25 | --saturation 0.2 \ 26 | --hue 0.1 \ 27 | --gaussian_prob 1.0 0.1 \ 28 | --solarization_prob 0.0 0.2 \ 29 | --dali \ 30 | --name barlow-domainnet-decorrelative \ 31 | --wandb \ 32 | --offline \ 33 | --save_checkpoint \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --scale_loss 0.1 \ 37 | --method barlow_twins \ 38 | --proj_hidden_dim 2048 \ 39 | --output_dim 2048 \ 40 | --check_val_every_n_epoch 9999 \ 41 | --disable_knn_eval \ 42 | --distiller decorrelative \ 43 | --pretrained_model $PRETRAINED_PATH -------------------------------------------------------------------------------- /bash_files/continual/domainnet/byol.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 0 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.6 \ 21 | --classifier_lr 0.1 \ 22 | --weight_decay 1e-5 \ 23 | --batch_size 64 \ 24 | --brightness 0.4 \ 25 | --contrast 0.4 \ 26 | --saturation 0.2 \ 27 | --hue 0.1 \ 28 | --gaussian_prob 1.0 0.1 \ 29 | --solarization_prob 0.0 0.2 \ 30 | --dali \ 31 | --name byol-domainnet \ 32 | --entity unitn-mhug \ 33 | --project ever-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method byol \ 37 | --output_dim 256 \ 38 | --proj_hidden_dim 4096 \ 39 | --pred_hidden_dim 8192 \ 40 | --base_tau_momentum 0.99 \ 41 | --final_tau_momentum 1.0 \ 42 | --momentum_classifier \ 43 | --check_val_every_n_epoch 9999 \ 44 | --disable_knn_eval 45 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/swav.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | 
--split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.8 \ 23 | --min_lr 0.0006 \ 24 | --classifier_lr 0.1 \ 25 | --weight_decay 1e-6 \ 26 | --batch_size 128 \ 27 | --brightness 0.8 \ 28 | --contrast 0.8 \ 29 | --saturation 0.8 \ 30 | --hue 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name swav-imagenet100-5T \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method swav \ 39 | --proj_hidden_dim 2048 \ 40 | --queue_size 3840 \ 41 | --output_dim 128 \ 42 | --num_prototypes 3000 \ 43 | --epoch_queue_starts 100 \ 44 | --freeze_prototypes_epochs 5 45 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/swav.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.8 \ 23 | --min_lr 0.0006 \ 24 | --classifier_lr 0.1 \ 25 | --weight_decay 1e-6 \ 26 | --batch_size 128 \ 27 | --brightness 0.8 \ 28 | --contrast 0.8 \ 29 | --saturation 0.8 \ 30 | --hue 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name swav-imagenet100-5T_data \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method swav \ 39 | --proj_hidden_dim 2048 \ 40 | --queue_size 3840 \ 41 | --output_dim 128 \ 42 | --num_prototypes 3000 \ 43 | --epoch_queue_starts 100 \ 44 | --freeze_prototypes_epochs 5 45 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/byol.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.6 \ 23 | --classifier_lr 0.1 \ 24 | --weight_decay 1e-5 \ 25 | --batch_size 128 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --gaussian_prob 1.0 0.1 \ 31 | --solarization_prob 0.0 0.2 \ 32 | --dali \ 33 | --check_val_every_n_epoch 999 \ 34 | --name byol-imagenet100-5T \ 35 | --entity unitn-mhug \ 36 | --project ever-learn \ 37 | --wandb \ 38 | --save_checkpoint \ 39 | --method byol \ 40 | --output_dim 256 \ 41 | --proj_hidden_dim 4096 \ 42 | --pred_hidden_dim 8192 \ 43 | 
--base_tau_momentum 0.99 \ 44 | --final_tau_momentum 1.0 45 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/vicreg_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.4 \ 21 | --weight_decay 1e-4 \ 22 | --batch_size 64 \ 23 | --min_scale 0.2 \ 24 | --brightness 0.4 \ 25 | --contrast 0.4 \ 26 | --saturation 0.2 \ 27 | --hue 0.1 \ 28 | --solarization_prob 0.1 \ 29 | --dali \ 30 | --name vicreg-domainnet-predictive \ 31 | --entity unitn-mhug \ 32 | --project ever-learn \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --method vicreg \ 36 | --proj_hidden_dim 2048 \ 37 | --output_dim 2048 \ 38 | --sim_loss_weight 25.0 \ 39 | --var_loss_weight 25.0 \ 40 | --cov_loss_weight 1.0 \ 41 | --check_val_every_n_epoch 9999 \ 42 | --disable_knn_eval \ 43 | --distiller predictiveMSE \ 44 | --pretrained_model $PRETRAINED_PATH 45 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/barlow_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 1.0 0.1 \ 30 | --solarization_prob 0.0 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name barlow-imagenet100-5T-decorrelative \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --entity unitn-mhug \ 37 | --project ever-learn \ 38 | --scale_loss 0.1 \ 39 | --method barlow_twins \ 40 | --proj_hidden_dim 2048 \ 41 | --output_dim 2048 \ 42 | --distiller decorrelative \ 43 | --pretrained_model $PRETRAINED_PATH 44 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/barlow_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | 
--weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --brightness 0.4 \ 26 | --contrast 0.4 \ 27 | --saturation 0.2 \ 28 | --hue 0.1 \ 29 | --gaussian_prob 1.0 0.1 \ 30 | --solarization_prob 0.0 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name barlow-imagenet100-5T_data-decorrelative \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --entity unitn-mhug \ 37 | --project ever-learn \ 38 | --scale_loss 0.1 \ 39 | --method barlow_twins \ 40 | --proj_hidden_dim 2048 \ 41 | --output_dim 2048 \ 42 | --distiller decorrelative \ 43 | --pretrained_model $PRETRAINED_PATH 44 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/swav_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.8 \ 21 | --min_lr 0.0006 \ 22 | --classifier_lr 0.1 \ 23 | --weight_decay 1e-6 \ 24 | --batch_size 64 \ 25 | --brightness 0.8 \ 26 | --contrast 0.8 \ 27 | --saturation 0.8 \ 28 | --hue 0.2 \ 29 | --dali \ 30 | --name swav-domainnet-knowlegde \ 31 | --entity unitn-mhug \ 32 | --project ever-learn \ 33 | --wandb \ 34 | --save_checkpoint \ 35 | --method swav \ 36 | --proj_hidden_dim 2048 \ 37 | --queue_size 3840 \ 38 | --output_dim 128 \ 39 | --num_prototypes 3000 \ 40 | --epoch_queue_starts 100 \ 41 | --freeze_prototypes_epochs 5 \ 42 | --check_val_every_n_epoch 9999 \ 43 | --disable_knn_eval \ 44 | --distiller knowledge \ 45 | --pretrained_model $PRETRAINED_PATH 46 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/byol.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 0 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.6 \ 23 | --classifier_lr 0.1 \ 24 | --weight_decay 1e-5 \ 25 | --batch_size 128 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --gaussian_prob 1.0 0.1 \ 31 | --solarization_prob 0.0 0.2 \ 32 | --dali \ 33 | --check_val_every_n_epoch 9999 \ 34 | --name byol-imagenet100-5T_data \ 35 | --entity unitn-mhug \ 36 | --project ever-learn \ 37 | --wandb \ 38 | --save_checkpoint \ 39 | --method byol \ 40 | --output_dim 256 \ 41 | --proj_hidden_dim 4096 \ 42 | --pred_hidden_dim 8192 \ 43 | --base_tau_momentum 0.99 \ 44 | --final_tau_momentum 1.0 \ 45 | --momentum_classifier 46 | -------------------------------------------------------------------------------- /cassle/losses/ressl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn.functional as F 3 | 4 | 5 | def ressl_loss_func( 6 | q: torch.Tensor, 7 | k: torch.Tensor, 8 | queue: torch.Tensor, 9 | temperature_q: float = 0.1, 10 | temperature_k: float = 0.04, 11 | ) -> torch.Tensor: 12 | """Computes ReSSL's loss given a batch of queries from view 1, a batch of keys from view 2 and a 13 | queue of past elements. 14 | 15 | Args: 16 | query (torch.Tensor): NxD Tensor containing the queries from view 1. 17 | key (torch.Tensor): NxD Tensor containing the queries from view 2. 18 | queue (torch.Tensor): a queue of negative samples for the contrastive loss. 19 | temperature_q (float, optional): [description]. temperature of the softmax for the query. 20 | Defaults to 0.1. 21 | temperature_k (float, optional): [description]. temperature of the softmax for the key. 22 | Defaults to 0.04. 23 | 24 | Returns: 25 | torch.Tensor: ReSSL loss. 26 | """ 27 | 28 | logits_q = torch.einsum("nc,kc->nk", [q, queue]) 29 | logits_k = torch.einsum("nc,kc->nk", [k, queue]) 30 | 31 | loss = -torch.sum( 32 | F.softmax(logits_k.detach() / temperature_k, dim=1) 33 | * F.log_softmax(logits_q / temperature_q, dim=1), 34 | dim=1, 35 | ).mean() 36 | 37 | return loss 38 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/vicreg_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --min_scale 0.2 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --solarization_prob 0.1 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name vicreg-imagenet100-5T-predictive \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method vicreg \ 39 | --proj_hidden_dim 2048 \ 40 | --output_dim 2048 \ 41 | --sim_loss_weight 25.0 \ 42 | --var_loss_weight 25.0 \ 43 | --cov_loss_weight 1.0 \ 44 | --distiller predictive_mse \ 45 | --pretrained_model $PRETRAINED_PATH 46 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/vicreg_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.4 \ 23 | --weight_decay 1e-4 \ 24 | --batch_size 128 \ 25 | --min_scale 0.2 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --solarization_prob 0.1 \ 31 
| --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name vicreg-imagenet100-5T_data-predictive \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method vicreg \ 39 | --proj_hidden_dim 2048 \ 40 | --output_dim 2048 \ 41 | --sim_loss_weight 25.0 \ 42 | --var_loss_weight 25.0 \ 43 | --cov_loss_weight 1.0 \ 44 | --distiller predictive_mse \ 45 | --pretrained_model $PRETRAINED_PATH 46 | -------------------------------------------------------------------------------- /bash_files/continual/domainnet/byol_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset domainnet \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR/domainnet \ 5 | --split_strategy domain \ 6 | --max_epochs 200 \ 7 | --num_tasks 6 \ 8 | --task_idx 1 \ 9 | --gpus 0,1,2,3 \ 10 | --accelerator ddp \ 11 | --sync_batchnorm \ 12 | --num_workers 5 \ 13 | --precision 16 \ 14 | --optimizer sgd \ 15 | --lars \ 16 | --grad_clip_lars \ 17 | --eta_lars 0.02 \ 18 | --exclude_bias_n_norm \ 19 | --scheduler warmup_cosine \ 20 | --lr 0.6 \ 21 | --classifier_lr 0.1 \ 22 | --weight_decay 1e-5 \ 23 | --batch_size 256 \ 24 | --brightness 0.4 \ 25 | --contrast 0.4 \ 26 | --saturation 0.2 \ 27 | --hue 0.1 \ 28 | --gaussian_prob 1.0 0.1 \ 29 | --solarization_prob 0.0 0.2 \ 30 | --dali \ 31 | --name byol-domainnet-predictive \ 32 | --entity unitn-mhug \ 33 | --project ever-learn \ 34 | --wandb \ 35 | --save_checkpoint \ 36 | --method byol \ 37 | --output_dim 256 \ 38 | --proj_hidden_dim 4096 \ 39 | --pred_hidden_dim 8192 \ 40 | --base_tau_momentum 0.99 \ 41 | --final_tau_momentum 1.0 \ 42 | --momentum_classifier \ 43 | --check_val_every_n_epoch 9999 \ 44 | --disable_knn_eval \ 45 | --distiller predictive \ 46 | --pretrained_model $PRETRAINED_PATH 47 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/swav_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.8 \ 23 | --min_lr 0.0006 \ 24 | --classifier_lr 0.1 \ 25 | --weight_decay 1e-6 \ 26 | --batch_size 128 \ 27 | --brightness 0.8 \ 28 | --contrast 0.8 \ 29 | --saturation 0.8 \ 30 | --hue 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name swav-imagenet100-5T-knowlegde \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method swav \ 39 | --proj_hidden_dim 2048 \ 40 | --queue_size 3840 \ 41 | --output_dim 128 \ 42 | --num_prototypes 3000 \ 43 | --epoch_queue_starts 100 \ 44 | --freeze_prototypes_epochs 5 \ 45 | --distiller knowledge \ 46 | --pretrained_model $PRETRAINED_PATH 47 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/swav_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset 
imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.8 \ 23 | --min_lr 0.0006 \ 24 | --classifier_lr 0.1 \ 25 | --weight_decay 1e-6 \ 26 | --batch_size 128 \ 27 | --brightness 0.8 \ 28 | --contrast 0.8 \ 29 | --saturation 0.8 \ 30 | --hue 0.2 \ 31 | --dali \ 32 | --check_val_every_n_epoch 9999 \ 33 | --name swav-imagenet100-5T_data-knowlegde \ 34 | --entity unitn-mhug \ 35 | --project ever-learn \ 36 | --wandb \ 37 | --save_checkpoint \ 38 | --method swav \ 39 | --proj_hidden_dim 2048 \ 40 | --queue_size 3840 \ 41 | --output_dim 128 \ 42 | --num_prototypes 3000 \ 43 | --epoch_queue_starts 100 \ 44 | --freeze_prototypes_epochs 5 \ 45 | --distiller knowledge \ 46 | --pretrained_model $PRETRAINED_PATH 47 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/data/byol_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy data \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.6 \ 23 | --classifier_lr 0.1 \ 24 | --weight_decay 1e-5 \ 25 | --batch_size 128 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --gaussian_prob 1.0 0.1 \ 31 | --solarization_prob 0.0 0.2 \ 32 | --dali \ 33 | --check_val_every_n_epoch 9999 \ 34 | --name byol-imagenet100-5T_data-predictive \ 35 | --entity unitn-mhug \ 36 | --project ever-learn \ 37 | --wandb \ 38 | --save_checkpoint \ 39 | --method byol \ 40 | --output_dim 256 \ 41 | --proj_hidden_dim 4096 \ 42 | --pred_hidden_dim 8192 \ 43 | --base_tau_momentum 0.99 \ 44 | --final_tau_momentum 1.0 \ 45 | --distiller predictive \ 46 | --pretrained_model $PRETRAINED_PATH 47 | -------------------------------------------------------------------------------- /bash_files/continual/imagenet-100/class/byol_distill.sh: -------------------------------------------------------------------------------- 1 | python3 main_continual.py \ 2 | --dataset imagenet100 \ 3 | --encoder resnet18 \ 4 | --data_dir $DATA_DIR \ 5 | --train_dir imagenet-100/train \ 6 | --val_dir imagenet-100/val \ 7 | --split_strategy class \ 8 | --max_epochs 400 \ 9 | --num_tasks 5 \ 10 | --task_idx 1 \ 11 | --gpus 0,1 \ 12 | --accelerator ddp \ 13 | --sync_batchnorm \ 14 | --num_workers 5 \ 15 | --precision 16 \ 16 | --optimizer sgd \ 17 | --lars \ 18 | --grad_clip_lars \ 19 | --eta_lars 0.02 \ 20 | --exclude_bias_n_norm \ 21 | --scheduler warmup_cosine \ 22 | --lr 0.6 \ 23 | --classifier_lr 0.1 \ 24 | --weight_decay 1e-5 \ 25 | --batch_size 128 \ 26 | --brightness 0.4 \ 27 | --contrast 0.4 \ 28 | --saturation 0.2 \ 29 | --hue 0.1 \ 30 | --gaussian_prob 1.0 
0.1 \ 31 | --solarization_prob 0.0 0.2 \ 32 | --dali \ 33 | --check_val_every_n_epoch 9999 \ 34 | --name byol-imagenet100-5T-predictive \ 35 | --entity unitn-mhug \ 36 | --project ever-learn \ 37 | --wandb \ 38 | --save_checkpoint \ 39 | --method byol \ 40 | --output_dim 256 \ 41 | --proj_hidden_dim 4096 \ 42 | --pred_hidden_dim 8192 \ 43 | --base_tau_momentum 0.99 \ 44 | --final_tau_momentum 1.0 \ 45 | --momentum_classifier \ 46 | --distiller predictive \ 47 | --pretrained_model $PRETRAINED_PATH 48 | -------------------------------------------------------------------------------- /cassle/losses/barlow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torch.distributed as dist 4 | 5 | 6 | def barlow_loss_func( 7 | z1: torch.Tensor, z2: torch.Tensor, lamb: float = 5e-3, scale_loss: float = 0.025 8 | ) -> torch.Tensor: 9 | """Computes Barlow Twins' loss given batch of projected features z1 from view 1 and 10 | projected features z2 from view 2. 11 | 12 | Args: 13 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 14 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 15 | lamb (float, optional): off-diagonal scaling factor for the cross-covariance matrix. 16 | Defaults to 5e-3. 17 | scale_loss (float, optional): final scaling factor of the loss. Defaults to 0.025. 18 | 19 | Returns: 20 | torch.Tensor: Barlow Twins' loss. 21 | """ 22 | 23 | N, D = z1.size() 24 | 25 | # to match the original code 26 | bn = torch.nn.BatchNorm1d(D, affine=False).to(z1.device) 27 | z1 = bn(z1) 28 | z2 = bn(z2) 29 | 30 | corr = torch.einsum("bi, bj -> ij", z1, z2) / N 31 | 32 | if dist.is_available() and dist.is_initialized(): 33 | dist.all_reduce(corr) 34 | world_size = dist.get_world_size() 35 | corr /= world_size 36 | 37 | diag = torch.eye(D, device=corr.device) 38 | cdif = (corr - diag).pow(2) 39 | cdif[~diag.bool()] *= lamb 40 | loss = scale_loss * cdif.sum() 41 | return loss 42 | -------------------------------------------------------------------------------- /cassle/methods/__init__.py: -------------------------------------------------------------------------------- 1 | from cassle.methods.barlow_twins import BarlowTwins 2 | from cassle.methods.base import BaseModel 3 | from cassle.methods.byol import BYOL 4 | from cassle.methods.deepclusterv2 import DeepClusterV2 5 | from cassle.methods.dino import DINO 6 | from cassle.methods.linear import LinearModel 7 | from cassle.methods.mocov2plus import MoCoV2Plus 8 | from cassle.methods.nnclr import NNCLR 9 | from cassle.methods.ressl import ReSSL 10 | from cassle.methods.simclr import SimCLR 11 | from cassle.methods.simsiam import SimSiam 12 | from cassle.methods.swav import SwAV 13 | from cassle.methods.vicreg import VICReg 14 | from cassle.methods.wmse import WMSE 15 | 16 | METHODS = { 17 | # base classes 18 | "base": BaseModel, 19 | "linear": LinearModel, 20 | # methods 21 | "barlow_twins": BarlowTwins, 22 | "byol": BYOL, 23 | "deepclusterv2": DeepClusterV2, 24 | "dino": DINO, 25 | "mocov2plus": MoCoV2Plus, 26 | "nnclr": NNCLR, 27 | "ressl": ReSSL, 28 | "simclr": SimCLR, 29 | "simsiam": SimSiam, 30 | "swav": SwAV, 31 | "vicreg": VICReg, 32 | "wmse": WMSE, 33 | } 34 | __all__ = [ 35 | "BarlowTwins", 36 | "BYOL", 37 | "BaseModel", 38 | "DeepClusterV2", 39 | "DINO", 40 | "LinearModel", 41 | "MoCoV2Plus", 42 | "NNCLR", 43 | "ReSSL", 44 | "SimCLR", 45 | "SimSiam", 46 | "SwAV", 47 | "VICReg", 48 | "WMSE", 49 | ] 50 | 51 | try: 52 | from cassle.methods 
import dali # noqa: F401 53 | except ImportError: 54 | pass 55 | else: 56 | __all__.append("dali") 57 | -------------------------------------------------------------------------------- /cassle/distillers/base.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Sequence 3 | import torch 4 | 5 | 6 | def base_distill_wrapper(Method=object): 7 | class BaseDistillWrapper(Method): 8 | def __init__(self, **kwargs) -> None: 9 | super().__init__(**kwargs) 10 | 11 | self.output_dim = kwargs["output_dim"] 12 | 13 | self.frozen_encoder = deepcopy(self.encoder) 14 | self.frozen_projector = deepcopy(self.projector) 15 | 16 | def on_train_start(self): 17 | super().on_train_start() 18 | 19 | if self.current_task_idx > 0: 20 | 21 | self.frozen_encoder = deepcopy(self.encoder) 22 | self.frozen_projector = deepcopy(self.projector) 23 | 24 | for pg in self.frozen_encoder.parameters(): 25 | pg.requires_grad = False 26 | for pg in self.frozen_projector.parameters(): 27 | pg.requires_grad = False 28 | 29 | @torch.no_grad() 30 | def frozen_forward(self, X): 31 | feats = self.frozen_encoder(X) 32 | return feats, self.frozen_projector(feats) 33 | 34 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 35 | _, (X1, X2), _ = batch[f"task{self.current_task_idx}"] 36 | 37 | out = super().training_step(batch, batch_idx) 38 | 39 | frozen_feats1, frozen_z1 = self.frozen_forward(X1) 40 | frozen_feats2, frozen_z2 = self.frozen_forward(X2) 41 | 42 | out.update( 43 | {"frozen_feats": [frozen_feats1, frozen_feats2], "frozen_z": [frozen_z1, frozen_z2]} 44 | ) 45 | return out 46 | 47 | return BaseDistillWrapper 48 | -------------------------------------------------------------------------------- /cassle/utils/whitening.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.cuda.amp import custom_fwd 4 | from torch.nn.functional import conv2d 5 | 6 | 7 | class Whitening2d(nn.Module): 8 | def __init__(self, output_dim: int, eps: float = 0.0): 9 | """Layer that computes hard whitening for W-MSE using the Cholesky decomposition. 10 | 11 | Args: 12 | output_dim (int): number of dimension of projected features. 13 | eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults 14 | to 0.0. 15 | """ 16 | 17 | super(Whitening2d, self).__init__() 18 | self.output_dim = output_dim 19 | self.eps = eps 20 | 21 | @custom_fwd(cast_inputs=torch.float32) 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | """Performs whitening using the Cholesky decomposition. 24 | 25 | Args: 26 | x (torch.Tensor): a batch or slice of projected features. 27 | 28 | Returns: 29 | torch.Tensor: a batch or slice of whitened features. 
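
        Example:
            A minimal usage sketch; the batch size and feature dimensionality
            below are arbitrary placeholders, not values from this repository.

            >>> import torch
            >>> whitening = Whitening2d(output_dim=64)
            >>> z = torch.randn(256, 64)   # batch of projected features
            >>> z_whitened = whitening(z)  # decorrelated (whitened) features
            >>> z_whitened.shape
            torch.Size([256, 64])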
30 | """ 31 | 32 | x = x.unsqueeze(2).unsqueeze(3) 33 | m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1) 34 | xn = x - m 35 | 36 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1) 37 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) 38 | 39 | eye = torch.eye(self.output_dim).type(f_cov.type()) 40 | 41 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye 42 | 43 | inv_sqrt = torch.triangular_solve(eye, torch.cholesky(f_cov_shrinked), upper=False)[0] 44 | inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1) 45 | 46 | decorrelated = conv2d(xn, inv_sqrt) 47 | 48 | return decorrelated.squeeze(2).squeeze(2) 49 | -------------------------------------------------------------------------------- /cassle/utils/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Sequence 2 | 3 | import torch 4 | 5 | 6 | def accuracy_at_k( 7 | outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5) 8 | ) -> Sequence[int]: 9 | """Computes the accuracy over the k top predictions for the specified values of k. 10 | 11 | Args: 12 | outputs (torch.Tensor): output of a classifier (logits or probabilities). 13 | targets (torch.Tensor): ground truth labels. 14 | top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over. 15 | Defaults to (1, 5). 16 | 17 | Returns: 18 | Sequence[int]: accuracies at the desired k. 19 | """ 20 | 21 | with torch.no_grad(): 22 | maxk = max(top_k) 23 | batch_size = targets.size(0) 24 | 25 | _, pred = outputs.topk(maxk, 1, True, True) 26 | pred = pred.t() 27 | correct = pred.eq(targets.view(1, -1).expand_as(pred)) 28 | 29 | res = [] 30 | for k in top_k: 31 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 32 | res.append(correct_k.mul_(100.0 / batch_size)) 33 | return res 34 | 35 | 36 | def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float: 37 | """Computes the mean of the values of a key weighted by the batch size. 38 | 39 | Args: 40 | outputs (List[Dict]): list of dicts containing the outputs of a validation step. 41 | key (str): key of the metric of interest. 42 | batch_size_key (str): key of batch size values. 43 | 44 | Returns: 45 | float: weighted mean of the values of a key 46 | """ 47 | 48 | value = 0 49 | n = 0 50 | for out in outputs: 51 | value += out[batch_size_key] * out[key] 52 | n += out[batch_size_key] 53 | value = value / n 54 | return value.squeeze(0) 55 | -------------------------------------------------------------------------------- /cassle/utils/trunc_normal.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | 4 | import torch 5 | 6 | 7 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 8 | """Copy & paste from PyTorch official master until it's in a few official releases - RW 9 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 10 | """ 11 | 12 | def norm_cdf(x): 13 | """Computes standard normal cumulative distribution function""" 14 | 15 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 16 | 17 | if (mean < a - 2 * std) or (mean > b + 2 * std): 18 | warnings.warn( 19 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 20 | "The distribution of values may be incorrect.", 21 | stacklevel=2, 22 | ) 23 | 24 | with torch.no_grad(): 25 | # Values are generated by using a truncated uniform distribution and 26 | # then using the inverse CDF for the normal distribution. 27 | # Get upper and lower cdf values 28 | l = norm_cdf((a - mean) / std) 29 | u = norm_cdf((b - mean) / std) 30 | 31 | # Uniformly fill tensor with values from [l, u], then translate to 32 | # [2l-1, 2u-1]. 33 | tensor.uniform_(2 * l - 1, 2 * u - 1) 34 | 35 | # Use inverse cdf transform for normal distribution to get truncated 36 | # standard normal 37 | tensor.erfinv_() 38 | 39 | # Transform to proper mean, std 40 | tensor.mul_(std * math.sqrt(2.0)) 41 | tensor.add_(mean) 42 | 43 | # Clamp to ensure it's in the proper range 44 | tensor.clamp_(min=a, max=b) 45 | return tensor 46 | 47 | 48 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 49 | """Copy & paste from PyTorch official master until it's in a few official releases - RW 50 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 51 | """ 52 | 53 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 54 | -------------------------------------------------------------------------------- /cassle/utils/datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data.dataset import Dataset 4 | from PIL import Image 5 | 6 | 7 | class DomainNetDataset(Dataset): 8 | def __init__( 9 | self, 10 | data_root, 11 | image_list_root, 12 | domain_names, 13 | split="train", 14 | transform=None, 15 | return_domain=False, 16 | ): 17 | self.data_root = data_root 18 | self.transform = transform 19 | self.domain_names = domain_names 20 | self.return_domain = return_domain 21 | 22 | if domain_names is None: 23 | self.domain_names = [ 24 | "clipart", 25 | "infograph", 26 | "painting", 27 | "quickdraw", 28 | "real", 29 | "sketch", 30 | ] 31 | if not isinstance(domain_names, list): 32 | self.domain_name = [domain_names] 33 | 34 | image_list_paths = [ 35 | os.path.join(image_list_root, d + "_" + split + ".txt") for d in self.domain_names 36 | ] 37 | self.imgs = self._make_dataset(image_list_paths) 38 | 39 | def _make_dataset(self, image_list_paths): 40 | images = [] 41 | for image_list_path in image_list_paths: 42 | image_list = open(image_list_path).readlines() 43 | images += [(val.split()[0], int(val.split()[1])) for val in image_list] 44 | return images 45 | 46 | def _rgb_loader(self, path): 47 | with open(path, "rb") as f: 48 | with Image.open(f) as img: 49 | return img.convert("RGB") 50 | 51 | def __getitem__(self, index): 52 | path, target = self.imgs[index] 53 | img = self._rgb_loader(os.path.join(self.data_root, path)) 54 | 55 | if self.transform is not None: 56 | img = self.transform(img) 57 | 58 | domain = None 59 | if self.return_domain: 60 | domain = [d for d in self.domain_names if d in path] 61 | assert len(domain) == 1 62 | domain = domain[0] 63 | 64 | return domain if self.return_domain else index, img, target 65 | 66 | def __len__(self): 67 | return len(self.imgs) 68 | -------------------------------------------------------------------------------- /cassle/distillers/predictive_mse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | from cassle.distillers.base import base_distill_wrapper 7 | from 
cassle.losses.vicreg import invariance_loss 8 | 9 | 10 | def predictive_mse_distill_wrapper(Method=object): 11 | class PredictiveMSEDistillWrapper(base_distill_wrapper(Method)): 12 | def __init__(self, distill_lamb: float, distill_proj_hidden_dim, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | self.distill_lamb = distill_lamb 16 | output_dim = kwargs["output_dim"] 17 | 18 | self.distill_predictor = nn.Sequential( 19 | nn.Linear(output_dim, distill_proj_hidden_dim), 20 | nn.BatchNorm1d(distill_proj_hidden_dim), 21 | nn.ReLU(), 22 | nn.Linear(distill_proj_hidden_dim, output_dim), 23 | ) 24 | 25 | @staticmethod 26 | def add_model_specific_args( 27 | parent_parser: argparse.ArgumentParser, 28 | ) -> argparse.ArgumentParser: 29 | parser = parent_parser.add_argument_group("contrastive_distiller") 30 | 31 | parser.add_argument("--distill_lamb", type=float, default=25) 32 | parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048) 33 | 34 | return parent_parser 35 | 36 | @property 37 | def learnable_params(self) -> List[dict]: 38 | """Adds distill predictor parameters to the parent's learnable parameters. 39 | 40 | Returns: 41 | List[dict]: list of learnable parameters. 42 | """ 43 | 44 | extra_learnable_params = [ 45 | {"params": self.distill_predictor.parameters()}, 46 | ] 47 | return super().learnable_params + extra_learnable_params 48 | 49 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 50 | out = super().training_step(batch, batch_idx) 51 | z1, z2 = out["z"] 52 | frozen_z1, frozen_z2 = out["frozen_z"] 53 | 54 | p1 = self.distill_predictor(z1) 55 | p2 = self.distill_predictor(z2) 56 | 57 | distill_loss = (invariance_loss(p1, frozen_z1) + invariance_loss(p2, frozen_z2)) / 2 58 | 59 | self.log("train_predictive_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 60 | 61 | return out["loss"] + self.distill_lamb * distill_loss 62 | 63 | return PredictiveMSEDistillWrapper 64 | -------------------------------------------------------------------------------- /cassle/args/dataset.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from pathlib import Path 3 | 4 | 5 | def dataset_args(parser: ArgumentParser): 6 | """Adds dataset-related arguments to a parser. 7 | 8 | Args: 9 | parser (ArgumentParser): parser to add dataset args to. 10 | """ 11 | 12 | SUPPORTED_DATASETS = [ 13 | "cifar10", 14 | "cifar100", 15 | "stl10", 16 | "imagenet", 17 | "imagenet100", 18 | "domainnet", 19 | "custom", 20 | ] 21 | 22 | parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, type=str, required=True) 23 | 24 | # dataset path 25 | parser.add_argument("--data_dir", type=Path, required=True) 26 | parser.add_argument("--train_dir", type=Path, default=None) 27 | parser.add_argument("--val_dir", type=Path, default=None) 28 | 29 | # dali (imagenet-100/imagenet/custom only) 30 | parser.add_argument("--dali", action="store_true") 31 | parser.add_argument("--dali_device", type=str, default="gpu") 32 | 33 | # custom dataset only 34 | parser.add_argument("--no_labels", action="store_true") 35 | parser.add_argument("--semi_supervised", default=None, type=float) 36 | 37 | 38 | def augmentations_args(parser: ArgumentParser): 39 | """Adds augmentation-related arguments to a parser. 40 | 41 | Args: 42 | parser (ArgumentParser): parser to add augmentation args to. 
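
    Example:
        A minimal sketch of how this parser extension can be used; the flag
        values are arbitrary placeholders.

        >>> from argparse import ArgumentParser
        >>> parser = ArgumentParser()
        >>> augmentations_args(parser)
        >>> args = parser.parse_args(
        ...     ["--brightness", "0.4", "--contrast", "0.4",
        ...      "--saturation", "0.2", "--hue", "0.1"]
        ... )
        >>> args.brightness
        [0.4]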
43 | """ 44 | 45 | # cropping 46 | parser.add_argument("--multicrop", action="store_true") 47 | parser.add_argument("--num_crops", type=int, default=2) 48 | parser.add_argument("--num_small_crops", type=int, default=0) 49 | 50 | # augmentations 51 | parser.add_argument("--brightness", type=float, required=True, nargs="+") 52 | parser.add_argument("--contrast", type=float, required=True, nargs="+") 53 | parser.add_argument("--saturation", type=float, required=True, nargs="+") 54 | parser.add_argument("--hue", type=float, required=True, nargs="+") 55 | parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+") 56 | parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+") 57 | parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+") 58 | 59 | # for imagenet or custom dataset 60 | parser.add_argument("--size", type=int, default=[224], nargs="+") 61 | 62 | # for custom dataset 63 | parser.add_argument("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+") 64 | parser.add_argument("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+") 65 | 66 | # debug 67 | parser.add_argument("--debug_augmentations", action="store_true") 68 | -------------------------------------------------------------------------------- /cassle/utils/sinkhorn_knopp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | 4 | 5 | class SinkhornKnopp(torch.nn.Module): 6 | def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1): 7 | """Approximates optimal transport using the Sinkhorn-Knopp algorithm. 8 | 9 | A simple iterative method to approach the double stochastic matrix is to alternately rescale 10 | rows and columns of the matrix to sum to 1. 11 | 12 | Args: 13 | num_iters (int, optional): number of times to perform row and column normalization. 14 | Defaults to 3. 15 | epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05. 16 | world_size (int, optional): number of nodes for distributed training. Defaults to 1. 17 | """ 18 | 19 | super().__init__() 20 | self.num_iters = num_iters 21 | self.epsilon = epsilon 22 | self.world_size = world_size 23 | 24 | @torch.no_grad() 25 | def forward(self, Q: torch.Tensor) -> torch.Tensor: 26 | """Produces assignments using Sinkhorn-Knopp algorithm. 27 | 28 | Applies the entropy regularization, normalizes the Q matrix and then normalizes rows and 29 | columns in an alternating fashion for num_iter times. Before returning it normalizes again 30 | the columns in order for the output to be an assignment of samples to prototypes. 31 | 32 | Args: 33 | Q (torch.Tensor): cosine similarities between the features of the 34 | samples and the prototypes. 35 | 36 | Returns: 37 | torch.Tensor: assignment of samples to prototypes according to optimal transport. 
38 | """ 39 | 40 | Q = torch.exp(Q / self.epsilon).t() 41 | B = Q.shape[1] * self.world_size 42 | K = Q.shape[0] # num prototypes 43 | 44 | # make the matrix sums to 1 45 | sum_Q = torch.sum(Q) 46 | if dist.is_available() and dist.is_initialized(): 47 | dist.all_reduce(sum_Q) 48 | Q /= sum_Q 49 | 50 | for it in range(self.num_iters): 51 | # normalize each row: total weight per prototype must be 1/K 52 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 53 | if dist.is_available() and dist.is_initialized(): 54 | dist.all_reduce(sum_of_rows) 55 | Q /= sum_of_rows 56 | Q /= K 57 | 58 | # normalize each column: total weight per sample must be 1/B 59 | Q /= torch.sum(Q, dim=0, keepdim=True) 60 | Q /= B 61 | 62 | Q *= B # the colomns must sum to 1 so that Q is an assignment 63 | return Q.t() 64 | -------------------------------------------------------------------------------- /cassle/utils/momentum.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | @torch.no_grad() 8 | def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module): 9 | """Copies the parameters of the online network to the momentum network. 10 | 11 | Args: 12 | online_net (nn.Module): online network (e.g. online encoder, online projection, etc...). 13 | momentum_net (nn.Module): momentum network (e.g. momentum encoder, 14 | momentum projection, etc...). 15 | """ 16 | 17 | params_online = online_net.parameters() 18 | params_momentum = momentum_net.parameters() 19 | for po, pm in zip(params_online, params_momentum): 20 | pm.data.copy_(po.data) 21 | pm.requires_grad = False 22 | 23 | 24 | class MomentumUpdater: 25 | def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0): 26 | """Updates momentum parameters using exponential moving average. 27 | 28 | Args: 29 | base_tau (float, optional): base value of the weight decrease coefficient 30 | (should be in [0,1]). Defaults to 0.996. 31 | final_tau (float, optional): final value of the weight decrease coefficient 32 | (should be in [0,1]). Defaults to 1.0. 33 | """ 34 | 35 | super().__init__() 36 | 37 | assert 0 <= base_tau <= 1 38 | assert 0 <= final_tau <= 1 and base_tau <= final_tau 39 | 40 | self.base_tau = base_tau 41 | self.cur_tau = base_tau 42 | self.final_tau = final_tau 43 | 44 | @torch.no_grad() 45 | def update(self, online_net: nn.Module, momentum_net: nn.Module): 46 | """Performs the momentum update for each param group. 47 | 48 | Args: 49 | online_net (nn.Module): online network (e.g. online encoder, online projection, etc...). 50 | momentum_net (nn.Module): momentum network (e.g. momentum encoder, 51 | momentum projection, etc...). 52 | """ 53 | 54 | for op, mp in zip(online_net.parameters(), momentum_net.parameters()): 55 | mp.data = self.cur_tau * mp.data + (1 - self.cur_tau) * op.data 56 | 57 | def update_tau(self, cur_step: int, max_steps: int): 58 | """Computes the next value for the weighting decrease coefficient tau using cosine annealing. 59 | 60 | Args: 61 | cur_step (int): number of gradient steps so far. 62 | max_steps (int): overall number of gradient steps in the whole training. 
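        Example (illustrative sketch; step counts are assumed): tau starts at base_tau and
        is annealed towards final_tau following the cosine schedule below.

            >>> updater = MomentumUpdater(base_tau=0.996, final_tau=1.0)
            >>> updater.update_tau(cur_step=0, max_steps=100)
            >>> round(updater.cur_tau, 3)
            0.996
            >>> updater.update_tau(cur_step=100, max_steps=100)
            >>> round(updater.cur_tau, 3)
            1.0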
63 | """ 64 | 65 | self.cur_tau = ( 66 | self.final_tau 67 | - (self.final_tau - self.base_tau) * (math.cos(math.pi * cur_step / max_steps) + 1) / 2 68 | ) 69 | -------------------------------------------------------------------------------- /cassle/distillers/predictive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | from cassle.distillers.base import base_distill_wrapper 7 | from cassle.losses.byol import byol_loss_func 8 | 9 | 10 | def predictive_distill_wrapper(Method=object): 11 | class PredictiveDistillWrapper(base_distill_wrapper(Method)): 12 | def __init__(self, distill_lamb: float, distill_proj_hidden_dim, **kwargs): 13 | super().__init__(**kwargs) 14 | 15 | self.distill_lamb = distill_lamb 16 | output_dim = kwargs["output_dim"] 17 | 18 | self.distill_predictor = nn.Sequential( 19 | nn.Linear(output_dim, distill_proj_hidden_dim), 20 | nn.BatchNorm1d(distill_proj_hidden_dim), 21 | nn.ReLU(), 22 | nn.Linear(distill_proj_hidden_dim, output_dim), 23 | ) 24 | 25 | @staticmethod 26 | def add_model_specific_args( 27 | parent_parser: argparse.ArgumentParser, 28 | ) -> argparse.ArgumentParser: 29 | parser = parent_parser.add_argument_group("contrastive_distiller") 30 | 31 | parser.add_argument("--distill_lamb", type=float, default=1) 32 | parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048) 33 | 34 | return parent_parser 35 | 36 | @property 37 | def learnable_params(self) -> List[dict]: 38 | """Adds distill predictor parameters to the parent's learnable parameters. 39 | 40 | Returns: 41 | List[dict]: list of learnable parameters. 42 | """ 43 | 44 | extra_learnable_params = [ 45 | { 46 | "name": "distill_predictor", 47 | "params": self.distill_predictor.parameters(), 48 | "lr": self.lr if self.distill_lamb >= 1 else self.lr / self.distill_lamb, 49 | "weight_decay": self.weight_decay, 50 | }, 51 | ] 52 | return super().learnable_params + extra_learnable_params 53 | 54 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 55 | out = super().training_step(batch, batch_idx) 56 | z1, z2 = out["z"] 57 | frozen_z1, frozen_z2 = out["frozen_z"] 58 | 59 | p1 = self.distill_predictor(z1) 60 | p2 = self.distill_predictor(z2) 61 | 62 | distill_loss = (byol_loss_func(p1, frozen_z1) + byol_loss_func(p2, frozen_z2)) / 2 63 | 64 | self.log("train_predictive_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 65 | 66 | return out["loss"] + self.distill_lamb * distill_loss 67 | 68 | return PredictiveDistillWrapper 69 | -------------------------------------------------------------------------------- /cassle/distillers/contrastive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | from cassle.distillers.base import base_distill_wrapper 7 | from cassle.losses.simclr import simclr_distill_loss_func 8 | 9 | 10 | def contrastive_distill_wrapper(Method=object): 11 | class ContrastiveDistillWrapper(base_distill_wrapper(Method)): 12 | def __init__( 13 | self, 14 | distill_lamb: float, 15 | distill_proj_hidden_dim: int, 16 | distill_temperature: float, 17 | **kwargs 18 | ): 19 | super().__init__(**kwargs) 20 | 21 | self.distill_lamb = distill_lamb 22 | self.distill_temperature = distill_temperature 23 | output_dim = kwargs["output_dim"] 24 | 25 | self.distill_predictor = 
nn.Sequential( 26 | nn.Linear(output_dim, distill_proj_hidden_dim), 27 | nn.BatchNorm1d(distill_proj_hidden_dim), 28 | nn.ReLU(), 29 | nn.Linear(distill_proj_hidden_dim, output_dim), 30 | ) 31 | 32 | @staticmethod 33 | def add_model_specific_args( 34 | parent_parser: argparse.ArgumentParser, 35 | ) -> argparse.ArgumentParser: 36 | parser = parent_parser.add_argument_group("contrastive_distiller") 37 | 38 | parser.add_argument("--distill_lamb", type=float, default=1) 39 | parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048) 40 | parser.add_argument("--distill_temperature", type=float, default=0.2) 41 | 42 | return parent_parser 43 | 44 | @property 45 | def learnable_params(self) -> List[dict]: 46 | """Adds distill predictor parameters to the parent's learnable parameters. 47 | 48 | Returns: 49 | List[dict]: list of learnable parameters. 50 | """ 51 | 52 | extra_learnable_params = [ 53 | {"params": self.distill_predictor.parameters()}, 54 | ] 55 | return super().learnable_params + extra_learnable_params 56 | 57 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 58 | out = super().training_step(batch, batch_idx) 59 | z1, z2 = out["z"] 60 | frozen_z1, frozen_z2 = out["frozen_z"] 61 | 62 | p1 = self.distill_predictor(z1) 63 | p2 = self.distill_predictor(z2) 64 | 65 | distill_loss = ( 66 | simclr_distill_loss_func(p1, p2, frozen_z1, frozen_z2, self.distill_temperature) 67 | + simclr_distill_loss_func(frozen_z1, frozen_z2, p1, p2, self.distill_temperature) 68 | ) / 2 69 | 70 | self.log("train_contrastive_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 71 | 72 | return out["loss"] + self.distill_lamb * distill_loss 73 | 74 | return ContrastiveDistillWrapper 75 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # checkpoint 2 | last_checkpoint.txt 3 | 4 | # tensorboard dir 5 | runs/ 6 | 7 | # wandb dir 8 | wandb/ 9 | wandb*/ 10 | 11 | # umap dir 12 | auto_umap/ 13 | 14 | # datasets dir 15 | datasets/ 16 | # saved models 17 | *.pt 18 | *.pth 19 | *.ckpt 20 | *.tar 21 | 22 | *.png 23 | !assets/*.png 24 | *.jpg 25 | *.jpeg 26 | 27 | saved_models/ 28 | model_storage/ 29 | model_storage*/ 30 | lightning_logs/ 31 | 32 | *.json 33 | 34 | *logs*/ 35 | 36 | # Created by https://www.gitignore.io/api/python,visualstudiocode 37 | # Edit at https://www.gitignore.io/?templates=python,visualstudiocode 38 | 39 | ### Python ### 40 | # Byte-compiled / optimized / DLL files 41 | __pycache__/ 42 | *.py[cod] 43 | *$py.class 44 | 45 | # C extensions 46 | *.so 47 | 48 | # Distribution / packaging 49 | .Python 50 | build/ 51 | develop-eggs/ 52 | dist/ 53 | downloads/ 54 | eggs/ 55 | .eggs/ 56 | lib/ 57 | lib64/ 58 | parts/ 59 | sdist/ 60 | var/ 61 | wheels/ 62 | pip-wheel-metadata/ 63 | share/python-wheels/ 64 | *.egg-info/ 65 | .installed.cfg 66 | *.egg 67 | MANIFEST 68 | 69 | # PyInstaller 70 | # Usually these files are written by a python script from a template 71 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
72 | *.manifest 73 | *.spec 74 | 75 | # Installer logs 76 | pip-log.txt 77 | pip-delete-this-directory.txt 78 | 79 | # Unit test / coverage reports 80 | htmlcov/ 81 | .tox/ 82 | .nox/ 83 | .coverage 84 | .coverage.* 85 | .cache 86 | nosetests.xml 87 | coverage.xml 88 | *.cover 89 | .hypothesis/ 90 | .pytest_cache/ 91 | 92 | # Translations 93 | *.mo 94 | *.pot 95 | 96 | # Scrapy stuff: 97 | .scrapy 98 | 99 | # Sphinx documentation 100 | docs/_build/ 101 | 102 | # PyBuilder 103 | target/ 104 | 105 | # pyenv 106 | .python-version 107 | 108 | # pipenv 109 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 110 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 111 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 112 | # install all needed dependencies. 113 | #Pipfile.lock 114 | 115 | # celery beat schedule file 116 | celerybeat-schedule 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # Mr Developer 129 | .mr.developer.cfg 130 | .project 131 | .pydevproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | ### VisualStudioCode ### 145 | .vscode/* 146 | !.vscode/settings.json 147 | !.vscode/tasks.json 148 | !.vscode/launch.json 149 | !.vscode/extensions.json 150 | 151 | ### VisualStudioCode Patch ### 152 | # Ignore all local history of files 153 | .history 154 | 155 | # End of https://www.gitignore.io/api/python,visualstudiocode 156 | -------------------------------------------------------------------------------- /main_continual.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | import subprocess 4 | import sys 5 | import os 6 | import json 7 | 8 | 9 | def str_to_dict(command): 10 | d = {} 11 | for part, part_next in itertools.zip_longest(command[:-1], command[1:]): 12 | if part[:2] == "--": 13 | if part_next[:2] != "--": 14 | d[part] = part_next 15 | else: 16 | d[part] = part 17 | elif part[:2] != "--" and part_next[:2] != "--": 18 | part_prev = list(d.keys())[-1] 19 | if not isinstance(d[part_prev], list): 20 | d[part_prev] = [d[part_prev]] 21 | if not part_next[:2] == "--": 22 | d[part_prev].append(part_next) 23 | return d 24 | 25 | 26 | def dict_to_list(command): 27 | s = [] 28 | for k, v in command.items(): 29 | s.append(k) 30 | if k != v and v[:2] != "--": 31 | s.append(v) 32 | return s 33 | 34 | 35 | def run_bash_command(args): 36 | for i, a in enumerate(args): 37 | if isinstance(a, list): 38 | args[i] = " ".join(a) 39 | command = ("python3 main_pretrain.py", *args) 40 | command = " ".join(command) 41 | p = subprocess.Popen(command, shell=True) 42 | p.wait() 43 | 44 | 45 | if __name__ == "__main__": 46 | args = sys.argv[1:] 47 | args = str_to_dict(args) 48 | 49 | # parse args from the script 50 | num_tasks = int(args["--num_tasks"]) 51 | start_task_idx = int(args.get("--task_idx", 0)) 52 | distill_args = {k: v for k, v in args.items() if "distill" in k} 53 | 54 | # delete things that shouldn't be used for task_idx 0 55 | args.pop("--task_idx", None) 56 | for k in distill_args.keys(): 57 | args.pop(k, None) 58 | 59 | # check if this experiment is being resumed 60 | # look for the file last_checkpoint.txt 61 | 
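    # note for clarity (added comment): last_checkpoint.txt is expected to contain two lines,
    # the path of the last checkpoint and the path of its saved args json; they are used
    # below to recover the interrupted task index and to resume from that checkpoint.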
last_checkpoint_file = os.path.join(args["--checkpoint_dir"], "last_checkpoint.txt") 62 | if os.path.exists(last_checkpoint_file): 63 | with open(last_checkpoint_file) as f: 64 | ckpt_path, args_path = [line.rstrip() for line in f.readlines()] 65 | start_task_idx = json.load(open(args_path))["task_idx"] 66 | args["--resume_from_checkpoint"] = ckpt_path 67 | 68 | # main task loop 69 | for task_idx in range(start_task_idx, num_tasks): 70 | print(f"\n#### Starting Task {task_idx} ####") 71 | 72 | task_args = copy.deepcopy(args) 73 | 74 | # add pretrained model arg 75 | if task_idx != 0 and task_idx != start_task_idx: 76 | task_args.pop("--resume_from_checkpoint", None) 77 | task_args.pop("--pretrained_model", None) 78 | assert os.path.exists(last_checkpoint_file) 79 | ckpt_path = open(last_checkpoint_file).readlines()[0].rstrip() 80 | task_args["--pretrained_model"] = ckpt_path 81 | 82 | if task_idx != 0 and distill_args: 83 | task_args.update(distill_args) 84 | 85 | task_args["--task_idx"] = str(task_idx) 86 | task_args = dict_to_list(task_args) 87 | 88 | run_bash_command(task_args) 89 | -------------------------------------------------------------------------------- /cassle/losses/vicreg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def invariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: 6 | """Computes mse loss given batch of projected features z1 from view 1 and 7 | projected features z2 from view 2. 8 | 9 | Args: 10 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 11 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 12 | 13 | Returns: 14 | torch.Tensor: invariance loss (mean squared error). 15 | """ 16 | 17 | return F.mse_loss(z1, z2) 18 | 19 | 20 | def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: 21 | """Computes variance loss given batch of projected features z1 from view 1 and 22 | projected features z2 from view 2. 23 | 24 | Args: 25 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 26 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 27 | 28 | Returns: 29 | torch.Tensor: variance regularization loss. 30 | """ 31 | 32 | eps = 1e-4 33 | std_z1 = torch.sqrt(z1.var(dim=0) + eps) 34 | std_z2 = torch.sqrt(z2.var(dim=0) + eps) 35 | std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2)) 36 | return std_loss 37 | 38 | 39 | def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: 40 | """Computes covariance loss given batch of projected features z1 from view 1 and 41 | projected features z2 from view 2. 42 | 43 | Args: 44 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 45 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 46 | 47 | Returns: 48 | torch.Tensor: covariance regularization loss. 
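    Example (illustrative sketch; the tensor sizes are assumed):

        >>> z1, z2 = torch.randn(32, 8), torch.randn(32, 8)
        >>> loss = covariance_loss(z1, z2)  # penalizes off-diagonal covariance entries
        >>> loss.dim()
        0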
49 | """ 50 | 51 | N, D = z1.size() 52 | 53 | z1 = z1 - z1.mean(dim=0) 54 | z2 = z2 - z2.mean(dim=0) 55 | cov_z1 = (z1.T @ z1) / (N - 1) 56 | cov_z2 = (z2.T @ z2) / (N - 1) 57 | 58 | diag = torch.eye(D, device=z1.device) 59 | cov_loss = cov_z1[~diag.bool()].pow_(2).sum() / D + cov_z2[~diag.bool()].pow_(2).sum() / D 60 | return cov_loss 61 | 62 | 63 | def vicreg_loss_func( 64 | z1: torch.Tensor, 65 | z2: torch.Tensor, 66 | sim_loss_weight: float = 25.0, 67 | var_loss_weight: float = 25.0, 68 | cov_loss_weight: float = 1.0, 69 | ) -> torch.Tensor: 70 | """Computes VICReg's loss given batch of projected features z1 from view 1 and 71 | projected features z2 from view 2. 72 | 73 | Args: 74 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 75 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 76 | sim_loss_weight (float): invariance loss weight. 77 | var_loss_weight (float): variance loss weight. 78 | cov_loss_weight (float): covariance loss weight. 79 | 80 | Returns: 81 | torch.Tensor: VICReg loss. 82 | """ 83 | 84 | sim_loss = invariance_loss(z1, z2) 85 | var_loss = variance_loss(z1, z2) 86 | cov_loss = covariance_loss(z1, z2) 87 | 88 | loss = sim_loss_weight * sim_loss + var_loss_weight * var_loss + cov_loss_weight * cov_loss 89 | return loss 90 | -------------------------------------------------------------------------------- /cassle/distillers/decorrelative.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | from cassle.distillers.base import base_distill_wrapper 7 | from cassle.losses.barlow import barlow_loss_func 8 | 9 | 10 | def decorrelative_distill_wrapper(Method=object): 11 | class DecorrelativeDistillWrapper(base_distill_wrapper(Method)): 12 | def __init__( 13 | self, 14 | distill_lamb: float, 15 | distill_proj_hidden_dim: int, 16 | distill_barlow_lamb: float, 17 | distill_scale_loss: float, 18 | **kwargs 19 | ): 20 | super().__init__(**kwargs) 21 | 22 | output_dim = kwargs["output_dim"] 23 | self.distill_lamb = distill_lamb 24 | self.distill_barlow_lamb = distill_barlow_lamb 25 | self.distill_scale_loss = distill_scale_loss 26 | 27 | self.distill_predictor = nn.Sequential( 28 | nn.Linear(output_dim, distill_proj_hidden_dim), 29 | nn.BatchNorm1d(distill_proj_hidden_dim), 30 | nn.ReLU(), 31 | nn.Linear(distill_proj_hidden_dim, output_dim), 32 | ) 33 | 34 | @staticmethod 35 | def add_model_specific_args( 36 | parent_parser: argparse.ArgumentParser, 37 | ) -> argparse.ArgumentParser: 38 | parser = parent_parser.add_argument_group("contrastive_distiller") 39 | 40 | parser.add_argument("--distill_lamb", type=float, default=1) 41 | parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048) 42 | parser.add_argument("--distill_barlow_lamb", type=float, default=5e-3) 43 | parser.add_argument("--distill_scale_loss", type=float, default=0.1) 44 | 45 | return parent_parser 46 | 47 | @property 48 | def learnable_params(self) -> List[dict]: 49 | """Adds distill predictor parameters to the parent's learnable parameters. 50 | 51 | Returns: 52 | List[dict]: list of learnable parameters. 
53 | """ 54 | 55 | extra_learnable_params = [ 56 | {"params": self.distill_predictor.parameters()}, 57 | ] 58 | return super().learnable_params + extra_learnable_params 59 | 60 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 61 | out = super().training_step(batch, batch_idx) 62 | z1, z2 = out["z"] 63 | frozen_z1, frozen_z2 = out["frozen_z"] 64 | 65 | p1 = self.distill_predictor(z1) 66 | p2 = self.distill_predictor(z2) 67 | 68 | distill_loss = ( 69 | barlow_loss_func( 70 | p1, 71 | frozen_z1, 72 | lamb=self.distill_barlow_lamb, 73 | scale_loss=self.distill_scale_loss, 74 | ) 75 | + barlow_loss_func( 76 | p2, 77 | frozen_z2, 78 | lamb=self.distill_barlow_lamb, 79 | scale_loss=self.distill_scale_loss, 80 | ) 81 | ) / 2 82 | 83 | self.log( 84 | "train_decorrelative_distill_loss", distill_loss, on_epoch=True, sync_dist=True 85 | ) 86 | 87 | return out["loss"] + self.distill_lamb * distill_loss 88 | 89 | return DecorrelativeDistillWrapper 90 | -------------------------------------------------------------------------------- /job_launcher.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import subprocess 4 | import argparse 5 | from datetime import datetime 6 | import inspect 7 | 8 | from main_continual import str_to_dict 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--script", type=str, required=True) 12 | parser.add_argument("--mode", type=str, default="normal") 13 | parser.add_argument("--experiment_dir", type=str, default=None) 14 | parser.add_argument("--base_experiment_dir", type=str, default="./experiments") 15 | parser.add_argument("--gpu", type=str, default="v100-16g") 16 | parser.add_argument("--num_gpus", type=int, default=2) 17 | parser.add_argument("--hours", type=int, default=20) 18 | parser.add_argument("--requeue", type=int, default=0) 19 | 20 | args = parser.parse_args() 21 | 22 | # load file 23 | if os.path.exists(args.script): 24 | with open(args.script) as f: 25 | command = [line.strip().strip("\\").strip() for line in f.readlines()] 26 | else: 27 | print(f"{args.script} does not exist.") 28 | exit() 29 | 30 | assert ( 31 | "--checkpoint_dir" not in command 32 | ), "Please remove the --checkpoint_dir argument, it will be added automatically" 33 | 34 | # collect args 35 | command_args = str_to_dict(" ".join(command).split(" ")[2:]) 36 | 37 | # create experiment directory 38 | if args.experiment_dir is None: 39 | args.experiment_dir = datetime.now().strftime("%Y_%m_%d_%H_%M_%S") 40 | args.experiment_dir += f"-{command_args['--name']}" 41 | full_experiment_dir = os.path.join(args.base_experiment_dir, args.experiment_dir) 42 | os.makedirs(full_experiment_dir, exist_ok=True) 43 | print(f"Experiment directory: {full_experiment_dir}") 44 | 45 | # add experiment directory to the command 46 | command.extend(["--checkpoint_dir", full_experiment_dir]) 47 | command = " ".join(command) 48 | 49 | # run command 50 | if args.mode == "normal": 51 | p = subprocess.Popen(command, shell=True, stdout=sys.stdout, stderr=sys.stdout) 52 | p.wait() 53 | 54 | elif args.mode == "slurm": 55 | # infer qos 56 | if 0 <= args.hours <= 2: 57 | qos = "qos_gpu-dev" 58 | elif args.hours <= 20: 59 | qos = "qos_gpu-t3" 60 | elif args.hours <= 100: 61 | qos = "qos_gpu-t4" 62 | 63 | # build slurm command 64 | command = inspect.cleandoc( 65 | f""" 66 | #!/bin/bash 67 | #SBATCH --job-name {command_args['--name']} 68 | #SBATCH -C {args.gpu} 69 | #SBATCH --qos {qos} 70 | #SBATCH --nodes=1 71 | 
#SBATCH --gres gpu:{args.num_gpus} 72 | #SBATCH --cpus-per-task {int(int(command_args['--num_workers']) * 2 * args.num_gpus)} 73 | #SBATCH --hint nomultithread 74 | #SBATCH --time {args.hours}:00:00 75 | #SBATCH --output outs/{command_args['--name']}.out 76 | #SBATCH --error outs/{command_args['--name']}.err 77 | #SBATCH -a 0-{args.requeue}%1 78 | 79 | # cleans out modules loaded in interactive and inherited by default 80 | module purge 81 | 82 | # loading conda env 83 | source ~/.bashrc 84 | conda activate cassle 85 | 86 | # echo of launched commands 87 | set -x 88 | 89 | cd $WORK/cassle 90 | 91 | # code execution 92 | {command} 93 | """ 94 | ) 95 | 96 | # write command 97 | command_path = os.path.join(full_experiment_dir, "command.sh") 98 | with open(command_path, "w") as f: 99 | f.write(command) 100 | 101 | # run command 102 | p = subprocess.Popen(f"sbatch {command_path}", shell=True, stdout=sys.stdout, stderr=sys.stdout) 103 | p.wait() 104 | -------------------------------------------------------------------------------- /cassle/methods/barlow_twins.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | import torch.nn as nn 6 | from cassle.losses.barlow import barlow_loss_func 7 | from cassle.methods.base import BaseModel 8 | 9 | 10 | class BarlowTwins(BaseModel): 11 | def __init__( 12 | self, proj_hidden_dim: int, output_dim: int, lamb: float, scale_loss: float, **kwargs 13 | ): 14 | """Implements Barlow Twins (https://arxiv.org/abs/2103.03230) 15 | 16 | Args: 17 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 18 | output_dim (int): number of dimensions of projected features. 19 | lamb (float): off-diagonal scaling factor for the cross-covariance matrix. 20 | scale_loss (float): scaling factor of the loss. 21 | """ 22 | 23 | super().__init__(**kwargs) 24 | 25 | self.lamb = lamb 26 | self.scale_loss = scale_loss 27 | 28 | # projector 29 | self.projector = nn.Sequential( 30 | nn.Linear(self.features_dim, proj_hidden_dim), 31 | nn.BatchNorm1d(proj_hidden_dim), 32 | nn.ReLU(), 33 | nn.Linear(proj_hidden_dim, proj_hidden_dim), 34 | nn.BatchNorm1d(proj_hidden_dim), 35 | nn.ReLU(), 36 | nn.Linear(proj_hidden_dim, output_dim), 37 | ) 38 | 39 | @staticmethod 40 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 41 | parent_parser = super(BarlowTwins, BarlowTwins).add_model_specific_args(parent_parser) 42 | parser = parent_parser.add_argument_group("barlow_twins") 43 | 44 | # projector 45 | parser.add_argument("--output_dim", type=int, default=2048) 46 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 47 | 48 | # parameters 49 | parser.add_argument("--lamb", type=float, default=5e-3) 50 | parser.add_argument("--scale_loss", type=float, default=0.025) 51 | return parent_parser 52 | 53 | @property 54 | def learnable_params(self) -> List[dict]: 55 | """Adds projector parameters to parent's learnable parameters. 56 | 57 | Returns: 58 | List[dict]: list of learnable parameters. 
59 | """ 60 | 61 | extra_learnable_params = [{"params": self.projector.parameters()}] 62 | return super().learnable_params + extra_learnable_params 63 | 64 | def forward(self, X, *args, **kwargs): 65 | out = super().forward(X, *args, **kwargs) 66 | z = self.projector(out["feats"]) 67 | return {**out, "z": z} 68 | 69 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 70 | """Training step for Barlow Twins reusing BaseModel training step. 71 | 72 | Args: 73 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 74 | [X] is a list of size self.num_crops containing batches of images. 75 | batch_idx (int): index of the batch. 76 | 77 | Returns: 78 | torch.Tensor: total loss composed of Barlow loss and classification loss. 79 | """ 80 | 81 | out = super().training_step(batch, batch_idx) 82 | 83 | feats1, feats2 = out["feats"] 84 | 85 | z1 = self.projector(feats1) 86 | z2 = self.projector(feats2) 87 | 88 | # ------- barlow twins loss ------- 89 | barlow_loss = barlow_loss_func(z1, z2, lamb=self.lamb, scale_loss=self.scale_loss) 90 | 91 | self.log("train_barlow_loss", barlow_loss, on_epoch=True, sync_dist=True) 92 | 93 | out.update({"loss": out["loss"] + barlow_loss, "z": [z1, z2]}) 94 | return out 95 | -------------------------------------------------------------------------------- /cassle/utils/lars.py: -------------------------------------------------------------------------------- 1 | """ 2 | References: 3 | - https://github.com/NVIDIA/apex/blob/master/apex/parallel/LARC.py 4 | - https://arxiv.org/pdf/1708.03888.pdf 5 | - https://github.com/noahgolmant/pytorch-lars/blob/master/lars.py 6 | """ 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | 11 | class LARSWrapper: 12 | def __init__( 13 | self, 14 | optimizer: Optimizer, 15 | eta: float = 1e-3, 16 | clip: bool = False, 17 | eps: float = 1e-8, 18 | exclude_bias_n_norm: bool = False, 19 | ): 20 | """Wrapper that adds LARS scheduling to any optimizer. 21 | This helps stability with huge batch sizes. 22 | 23 | Args: 24 | optimizer (Optimizer): torch optimizer. 25 | eta (float, optional): trust coefficient. Defaults to 1e-3. 26 | clip (bool, optional): clip gradient values. Defaults to False. 27 | eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8. 28 | exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars. 29 | Defaults to False. 
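        Example (illustrative usage sketch; the model and hyper-parameters are assumed):

            >>> model = torch.nn.Linear(32, 8)
            >>> base_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
            >>> optimizer = LARSWrapper(base_optimizer, eta=1e-3, exclude_bias_n_norm=True)
            >>> loss = model(torch.randn(4, 32)).sum()
            >>> loss.backward()
            >>> optimizer.step()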
30 | """ 31 | 32 | self.optim = optimizer 33 | self.eta = eta 34 | self.eps = eps 35 | self.clip = clip 36 | self.exclude_bias_n_norm = exclude_bias_n_norm 37 | 38 | # transfer optim methods 39 | self.state_dict = self.optim.state_dict 40 | self.load_state_dict = self.optim.load_state_dict 41 | self.zero_grad = self.optim.zero_grad 42 | self.add_param_group = self.optim.add_param_group 43 | 44 | self.__setstate__ = self.optim.__setstate__ # type: ignore 45 | self.__getstate__ = self.optim.__getstate__ # type: ignore 46 | self.__repr__ = self.optim.__repr__ # type: ignore 47 | 48 | @property 49 | def defaults(self): 50 | return self.optim.defaults 51 | 52 | @defaults.setter 53 | def defaults(self, defaults): 54 | self.optim.defaults = defaults 55 | 56 | @property # type: ignore 57 | def __class__(self): 58 | return Optimizer 59 | 60 | @property 61 | def state(self): 62 | return self.optim.state 63 | 64 | @state.setter 65 | def state(self, state): 66 | self.optim.state = state 67 | 68 | @property 69 | def param_groups(self): 70 | return self.optim.param_groups 71 | 72 | @param_groups.setter 73 | def param_groups(self, value): 74 | self.optim.param_groups = value 75 | 76 | @torch.no_grad() 77 | def step(self, closure=None): 78 | weight_decays = [] 79 | 80 | for group in self.optim.param_groups: 81 | weight_decay = group.get("weight_decay", 0) 82 | weight_decays.append(weight_decay) 83 | 84 | # reset weight decay 85 | group["weight_decay"] = 0 86 | 87 | # update the parameters 88 | for p in group["params"]: 89 | if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm): 90 | self.update_p(p, group, weight_decay) 91 | 92 | # update the optimizer 93 | self.optim.step(closure=closure) 94 | 95 | # return weight decay control to optimizer 96 | for group_idx, group in enumerate(self.optim.param_groups): 97 | group["weight_decay"] = weight_decays[group_idx] 98 | 99 | def update_p(self, p, group, weight_decay): 100 | # calculate new norms 101 | p_norm = torch.norm(p.data) 102 | g_norm = torch.norm(p.grad.data) 103 | 104 | if p_norm != 0 and g_norm != 0: 105 | # calculate new lr 106 | new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps) 107 | 108 | # clip lr 109 | if self.clip: 110 | new_lr = min(new_lr / group["lr"], 1) 111 | 112 | # update params with clipped lr 113 | p.grad.data += weight_decay * p.data 114 | p.grad.data *= new_lr 115 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | import os 2 | import types 3 | 4 | import torch 5 | import torch.nn as nn 6 | from pytorch_lightning import Trainer, seed_everything 7 | from pytorch_lightning.callbacks import LearningRateMonitor 8 | from pytorch_lightning.loggers import WandbLogger 9 | from pytorch_lightning.plugins import DDPPlugin 10 | from torchvision.models import resnet18, resnet50 11 | 12 | from cassle.args.setup import parse_args_linear 13 | 14 | try: 15 | from cassle.methods.dali import ClassificationABC 16 | except ImportError: 17 | _dali_avaliable = False 18 | else: 19 | _dali_avaliable = True 20 | from cassle.methods.linear import LinearModel 21 | from cassle.utils.classification_dataloader import prepare_data 22 | from cassle.utils.checkpointer import Checkpointer 23 | 24 | 25 | def main(): 26 | seed_everything(5) 27 | 28 | args = parse_args_linear() 29 | 30 | # split classes into tasks 31 | tasks = None 32 | if args.split_strategy == "class": 33 | assert args.num_classes 
% args.num_tasks == 0 34 | tasks = torch.randperm(args.num_classes).chunk(args.num_tasks) 35 | 36 | if args.encoder == "resnet18": 37 | backbone = resnet18() 38 | elif args.encoder == "resnet50": 39 | backbone = resnet50() 40 | else: 41 | raise ValueError("Only [resnet18, resnet50] are currently supported.") 42 | 43 | if args.cifar: 44 | backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 45 | backbone.maxpool = nn.Identity() 46 | backbone.fc = nn.Identity() 47 | 48 | assert ( 49 | args.pretrained_feature_extractor.endswith(".ckpt") 50 | or args.pretrained_feature_extractor.endswith(".pth") 51 | or args.pretrained_feature_extractor.endswith(".pt") 52 | ) 53 | ckpt_path = args.pretrained_feature_extractor 54 | 55 | state = torch.load(ckpt_path)["state_dict"] 56 | for k in list(state.keys()): 57 | if "encoder" in k: 58 | state[k.replace("encoder.", "")] = state[k] 59 | del state[k] 60 | backbone.load_state_dict(state, strict=False) 61 | 62 | print(f"Loaded {ckpt_path}") 63 | 64 | if args.dali: 65 | assert _dali_avaliable, "Dali is not currently avaiable, please install it first." 66 | MethodClass = types.new_class( 67 | f"Dali{LinearModel.__name__}", (ClassificationABC, LinearModel) 68 | ) 69 | else: 70 | MethodClass = LinearModel 71 | 72 | model = MethodClass(backbone, **args.__dict__, tasks=tasks) 73 | 74 | train_loader, val_loader = prepare_data( 75 | args.dataset, 76 | data_dir=args.data_dir, 77 | train_dir=args.train_dir, 78 | val_dir=args.val_dir, 79 | batch_size=args.batch_size, 80 | num_workers=args.num_workers, 81 | semi_supervised=args.semi_supervised, 82 | ) 83 | 84 | callbacks = [] 85 | 86 | # wandb logging 87 | if args.wandb: 88 | wandb_logger = WandbLogger( 89 | name=args.name, project=args.project, entity=args.entity, offline=args.offline 90 | ) 91 | wandb_logger.watch(model, log="gradients", log_freq=100) 92 | wandb_logger.log_hyperparams(args) 93 | 94 | # lr logging 95 | lr_monitor = LearningRateMonitor(logging_interval="epoch") 96 | callbacks.append(lr_monitor) 97 | 98 | # save checkpoint on last epoch only 99 | ckpt = Checkpointer( 100 | args, 101 | logdir=os.path.join(args.checkpoint_dir, "linear"), 102 | frequency=args.checkpoint_frequency, 103 | ) 104 | callbacks.append(ckpt) 105 | 106 | trainer = Trainer.from_argparse_args( 107 | args, 108 | logger=wandb_logger if args.wandb else None, 109 | callbacks=callbacks, 110 | plugins=DDPPlugin(find_unused_parameters=True), 111 | checkpoint_callback=False, 112 | terminate_on_nan=True, 113 | accelerator="ddp", 114 | ) 115 | if args.dali: 116 | trainer.fit(model, val_dataloaders=val_loader) 117 | else: 118 | trainer.fit(model, train_loader, val_loader) 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /cassle/losses/dino.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | import numpy as np 6 | 7 | 8 | class DINOLoss(nn.Module): 9 | def __init__( 10 | self, 11 | num_prototypes: int, 12 | warmup_teacher_temp: float, 13 | teacher_temp: float, 14 | warmup_teacher_temp_epochs: float, 15 | num_epochs: int, 16 | student_temp: float = 0.1, 17 | num_crops: int = 2, 18 | center_momentum: float = 0.9, 19 | ): 20 | """Auxiliary module to compute DINO's loss. 21 | 22 | Args: 23 | num_prototypes (int): number of prototypes. 
24 | warmup_teacher_temp (float): base temperature for the temperature schedule 25 | of the teacher. 26 | teacher_temp (float): final temperature for the teacher. 27 | warmup_teacher_temp_epochs (float): number of epochs for the cosine annealing schedule. 28 | num_epochs (int): total number of epochs. 29 | student_temp (float, optional): temperature for the student. Defaults to 0.1. 30 | num_crops (int, optional): number of crops/views. Defaults to 2. 31 | center_momentum (float, optional): momentum for the EMA update of the center of 32 | mass of the teacher. Defaults to 0.9. 33 | """ 34 | 35 | super().__init__() 36 | self.epoch = 0 37 | self.student_temp = student_temp 38 | self.center_momentum = center_momentum 39 | self.num_crops = num_crops 40 | self.register_buffer("center", torch.zeros(1, num_prototypes)) 41 | # we apply a warm up for the teacher temperature because 42 | # a too high temperature makes the training unstable at the beginning 43 | self.teacher_temp_schedule = np.concatenate( 44 | ( 45 | np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs), 46 | np.ones(num_epochs - warmup_teacher_temp_epochs) * teacher_temp, 47 | ) 48 | ) 49 | 50 | def forward(self, student_output: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor: 51 | """Computes DINO's loss given a batch of logits of the student and a batch of logits of the 52 | teacher. 53 | 54 | Args: 55 | student_output (torch.Tensor): NxP Tensor containing student logits for all views. 56 | teacher_output (torch.Tensor): NxP Tensor containing teacher logits for all views. 57 | 58 | Returns: 59 | torch.Tensor: DINO loss. 60 | """ 61 | 62 | student_out = student_output / self.student_temp 63 | student_out = student_out.chunk(self.num_crops) 64 | 65 | # teacher centering and sharpening 66 | temp = self.teacher_temp_schedule[self.epoch] 67 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 68 | teacher_out = teacher_out.detach().chunk(2) 69 | 70 | total_loss = 0 71 | n_loss_terms = 0 72 | for iq, q in enumerate(teacher_out): 73 | for v in range(len(student_out)): 74 | if v == iq: 75 | # we skip cases where student and teacher operate on the same view 76 | continue 77 | loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1) 78 | total_loss += loss.mean() 79 | n_loss_terms += 1 80 | total_loss /= n_loss_terms 81 | self.update_center(teacher_output) 82 | return total_loss 83 | 84 | @torch.no_grad() 85 | def update_center(self, teacher_output: torch.Tensor): 86 | """Updates the center for DINO's loss using exponential moving average. 87 | 88 | Args: 89 | teacher_output (torch.Tensor): NxP Tensor containing teacher logits of all views. 
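        Example (illustrative sketch; the number of prototypes and batch size are assumed):

            >>> loss_fn = DINOLoss(num_prototypes=16, warmup_teacher_temp=0.04,
            ...                    teacher_temp=0.07, warmup_teacher_temp_epochs=2,
            ...                    num_epochs=10)
            >>> loss_fn.update_center(torch.randn(8, 16))
            >>> loss_fn.center.shape
            torch.Size([1, 16])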
90 | """ 91 | 92 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 93 | if dist.is_available() and dist.is_initialized(): 94 | dist.all_reduce(batch_center) 95 | batch_center = batch_center / dist.get_world_size() 96 | batch_center = batch_center / len(teacher_output) 97 | 98 | # ema update 99 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum) 100 | -------------------------------------------------------------------------------- /cassle/args/setup.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import pytorch_lightning as pl 4 | from cassle.args.dataset import augmentations_args, dataset_args 5 | from cassle.args.utils import additional_setup_linear, additional_setup_pretrain 6 | from cassle.args.continual import continual_args 7 | from cassle.methods import METHODS 8 | from cassle.utils.checkpointer import Checkpointer 9 | from cassle.distillers import DISTILLERS 10 | 11 | try: 12 | from cassle.utils.auto_umap import AutoUMAP 13 | except ImportError: 14 | _umap_available = False 15 | else: 16 | _umap_available = True 17 | 18 | 19 | def parse_args_pretrain() -> argparse.Namespace: 20 | """Parses dataset, augmentation, pytorch lightning, model specific and additional args. 21 | 22 | First adds shared args such as dataset, augmentation and pytorch lightning args, then pulls the 23 | model name from the command and proceeds to add model specific args from the desired class. If 24 | wandb is enabled, it adds checkpointer args. Finally, adds additional non-user given parameters. 25 | 26 | Returns: 27 | argparse.Namespace: a namespace containing all args needed for pretraining. 28 | """ 29 | 30 | parser = argparse.ArgumentParser() 31 | 32 | # add shared arguments 33 | dataset_args(parser) 34 | augmentations_args(parser) 35 | continual_args(parser) 36 | 37 | # add pytorch lightning trainer args 38 | parser = pl.Trainer.add_argparse_args(parser) 39 | 40 | # add method-specific arguments 41 | parser.add_argument("--method", type=str) 42 | 43 | # THIS LINE IS KEY TO PULL THE MODEL NAME 44 | temp_args, _ = parser.parse_known_args() 45 | 46 | # add model specific args 47 | parser = METHODS[temp_args.method].add_model_specific_args(parser) 48 | 49 | # add distiller args 50 | if temp_args.distiller: 51 | parser = DISTILLERS[temp_args.distiller]().add_model_specific_args(parser) 52 | 53 | # add checkpoint and auto umap args 54 | parser.add_argument("--pretrained_model", type=str, default=None) 55 | parser.add_argument("--save_checkpoint", action="store_true") 56 | parser.add_argument("--auto_umap", action="store_true") 57 | temp_args, _ = parser.parse_known_args() 58 | 59 | # optionally add checkpointer and AutoUMAP args 60 | if temp_args.save_checkpoint: 61 | parser = Checkpointer.add_checkpointer_args(parser) 62 | 63 | if _umap_available and temp_args.auto_umap: 64 | parser = AutoUMAP.add_auto_umap_args(parser) 65 | 66 | # parse args 67 | args = parser.parse_args() 68 | 69 | # prepare arguments with additional setup 70 | additional_setup_pretrain(args) 71 | 72 | return args 73 | 74 | 75 | def parse_args_linear() -> argparse.Namespace: 76 | """Parses feature extractor, dataset, pytorch lightning, linear eval specific and additional args. 77 | 78 | First adds and arg for the pretrained feature extractor, then adds dataset, pytorch lightning 79 | and linear eval specific args. If wandb is enabled, it adds checkpointer args. Finally, adds 80 | additional non-user given parameters. 
81 | 82 | Returns: 83 | argparse.Namespace: a namespace containing all args needed for pretraining. 84 | """ 85 | 86 | parser = argparse.ArgumentParser() 87 | 88 | parser.add_argument("--pretrained_feature_extractor", type=str) 89 | 90 | # add shared arguments 91 | dataset_args(parser) 92 | 93 | # add pytorch lightning trainer args 94 | parser = pl.Trainer.add_argparse_args(parser) 95 | 96 | # linear model 97 | parser = METHODS["linear"].add_model_specific_args(parser) 98 | 99 | # THIS LINE IS KEY TO PULL WANDB 100 | temp_args, _ = parser.parse_known_args() 101 | 102 | parser.add_argument("--save_checkpoint", action="store_true") 103 | parser.add_argument("--num_tasks", type=int, default=2) 104 | SPLIT_STRATEGIES = ["class", "data", "domain"] 105 | parser.add_argument("--split_strategy", choices=SPLIT_STRATEGIES, type=str, required=True) 106 | parser.add_argument("--domain", type=str, default=None) 107 | 108 | # add checkpointer args (only if logging is enabled) 109 | if temp_args.wandb: 110 | parser = Checkpointer.add_checkpointer_args(parser) 111 | 112 | # parse args 113 | args = parser.parse_args() 114 | additional_setup_linear(args) 115 | 116 | return args 117 | -------------------------------------------------------------------------------- /cassle/methods/vicreg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, Dict, List, Sequence 3 | 4 | import torch 5 | import torch.nn as nn 6 | from cassle.losses.vicreg import vicreg_loss_func 7 | from cassle.methods.base import BaseModel 8 | 9 | 10 | class VICReg(BaseModel): 11 | def __init__( 12 | self, 13 | output_dim: int, 14 | proj_hidden_dim: int, 15 | sim_loss_weight: float, 16 | var_loss_weight: float, 17 | cov_loss_weight: float, 18 | **kwargs 19 | ): 20 | """Implements VICReg (https://arxiv.org/abs/2105.04906) 21 | 22 | Args: 23 | output_dim (int): number of dimensions of the projected features. 24 | proj_hidden_dim (int): number of neurons in the hidden layers of the projector. 25 | sim_loss_weight (float): weight of the invariance term. 26 | var_loss_weight (float): weight of the variance term. 27 | cov_loss_weight (float): weight of the covariance term. 28 | """ 29 | 30 | super().__init__(**kwargs) 31 | 32 | self.sim_loss_weight = sim_loss_weight 33 | self.var_loss_weight = var_loss_weight 34 | self.cov_loss_weight = cov_loss_weight 35 | 36 | # projector 37 | self.projector = nn.Sequential( 38 | nn.Linear(self.features_dim, proj_hidden_dim), 39 | nn.BatchNorm1d(proj_hidden_dim), 40 | nn.ReLU(), 41 | nn.Linear(proj_hidden_dim, proj_hidden_dim), 42 | nn.BatchNorm1d(proj_hidden_dim), 43 | nn.ReLU(), 44 | nn.Linear(proj_hidden_dim, output_dim), 45 | ) 46 | 47 | @staticmethod 48 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 49 | parent_parser = super(VICReg, VICReg).add_model_specific_args(parent_parser) 50 | parser = parent_parser.add_argument_group("vicreg") 51 | 52 | # projector 53 | parser.add_argument("--output_dim", type=int, default=2048) 54 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 55 | 56 | # parameters 57 | parser.add_argument("--sim_loss_weight", default=25, type=float) 58 | parser.add_argument("--var_loss_weight", default=25, type=float) 59 | parser.add_argument("--cov_loss_weight", default=1.0, type=float) 60 | return parent_parser 61 | 62 | @property 63 | def learnable_params(self) -> List[dict]: 64 | """Adds projector parameters to the parent's learnable parameters. 
65 | 66 | Returns: 67 | List[dict]: list of learnable parameters. 68 | """ 69 | 70 | extra_learnable_params = [{"params": self.projector.parameters()}] 71 | return super().learnable_params + extra_learnable_params 72 | 73 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 74 | """Performs the forward pass of the encoder and the projector. 75 | 76 | Args: 77 | X (torch.Tensor): a batch of images in the tensor format. 78 | 79 | Returns: 80 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features. 81 | """ 82 | 83 | out = super().forward(X, *args, **kwargs) 84 | z = self.projector(out["feats"]) 85 | return {**out, "z": z} 86 | 87 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 88 | """Training step for VICReg reusing BaseModel training step. 89 | 90 | Args: 91 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 92 | [X] is a list of size self.num_crops containing batches of images. 93 | batch_idx (int): index of the batch. 94 | 95 | Returns: 96 | torch.Tensor: total loss composed of VICReg loss and classification loss. 97 | """ 98 | 99 | out = super().training_step(batch, batch_idx) 100 | feats1, feats2 = out["feats"] 101 | 102 | z1 = self.projector(feats1) 103 | z2 = self.projector(feats2) 104 | 105 | # ------- barlow twins loss ------- 106 | vicreg_loss = vicreg_loss_func( 107 | z1, 108 | z2, 109 | sim_loss_weight=self.sim_loss_weight, 110 | var_loss_weight=self.var_loss_weight, 111 | cov_loss_weight=self.cov_loss_weight, 112 | ) 113 | 114 | self.log("train_vicreg_loss", vicreg_loss, on_epoch=True, sync_dist=True) 115 | 116 | out.update({"loss": out["loss"] + vicreg_loss, "z": [z1, z2]}) 117 | return out 118 | -------------------------------------------------------------------------------- /cassle/distillers/knowledge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, List, Sequence 3 | 4 | import torch 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from cassle.distillers.base import base_distill_wrapper 8 | 9 | 10 | def cross_entropy(preds, targets): 11 | return -torch.mean( 12 | torch.sum(F.softmax(targets, dim=-1) * torch.log_softmax(preds, dim=-1), dim=-1) 13 | ) 14 | 15 | 16 | def knowledge_distill_wrapper(Method=object): 17 | class KnowledgeDistillWrapper(base_distill_wrapper(Method)): 18 | def __init__( 19 | self, 20 | distill_lamb: float, 21 | distill_proj_hidden_dim: int, 22 | distill_temperature: float, 23 | **kwargs 24 | ): 25 | super().__init__(**kwargs) 26 | 27 | self.distill_lamb = distill_lamb 28 | self.distill_temperature = distill_temperature 29 | output_dim = kwargs["output_dim"] 30 | num_prototypes = kwargs["num_prototypes"] 31 | 32 | self.frozen_prototypes = nn.utils.weight_norm( 33 | nn.Linear(output_dim, num_prototypes, bias=False) 34 | ) 35 | for frozen_pg, pg in zip( 36 | self.frozen_prototypes.parameters(), self.prototypes.parameters() 37 | ): 38 | frozen_pg.data.copy_(pg.data) 39 | frozen_pg.requires_grad = False 40 | 41 | self.distill_predictor = nn.Sequential( 42 | nn.Linear(output_dim, distill_proj_hidden_dim), 43 | nn.BatchNorm1d(distill_proj_hidden_dim), 44 | nn.ReLU(), 45 | nn.Linear(distill_proj_hidden_dim, output_dim), 46 | ) 47 | 48 | self.distill_prototypes = nn.utils.weight_norm( 49 | nn.Linear(output_dim, num_prototypes, bias=False) 50 | ) 51 | 52 | @staticmethod 53 | def add_model_specific_args( 54 | parent_parser: 
argparse.ArgumentParser, 55 | ) -> argparse.ArgumentParser: 56 | parser = parent_parser.add_argument_group("knowledge_distiller") 57 | 58 | parser.add_argument("--distill_lamb", type=float, default=1) 59 | parser.add_argument("--distill_proj_hidden_dim", type=int, default=2048) 60 | parser.add_argument("--distill_temperature", type=float, default=0.1) 61 | 62 | return parent_parser 63 | 64 | @property 65 | def learnable_params(self) -> List[dict]: 66 | """Adds distill predictor parameters to the parent's learnable parameters. 67 | 68 | Returns: 69 | List[dict]: list of learnable parameters. 70 | """ 71 | 72 | extra_learnable_params = [ 73 | {"params": self.distill_predictor.parameters()}, 74 | {"params": self.distill_prototypes.parameters()}, 75 | ] 76 | return super().learnable_params + extra_learnable_params 77 | 78 | def on_train_start(self): 79 | super().on_train_start() 80 | 81 | if self.current_task_idx > 0: 82 | for frozen_pg, pg in zip( 83 | self.frozen_prototypes.parameters(), self.prototypes.parameters() 84 | ): 85 | frozen_pg.data.copy_(pg.data) 86 | frozen_pg.requires_grad = False 87 | 88 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 89 | out = super().training_step(batch, batch_idx) 90 | z1, z2 = out["z"] 91 | frozen_z1, frozen_z2 = out["frozen_z"] 92 | 93 | with torch.no_grad(): 94 | frozen_z1 = F.normalize(frozen_z1) 95 | frozen_z2 = F.normalize(frozen_z2) 96 | frozen_p1 = self.frozen_prototypes(frozen_z1) / self.distill_temperature 97 | frozen_p2 = self.frozen_prototypes(frozen_z2) / self.distill_temperature 98 | 99 | distill_z1 = F.normalize(self.distill_predictor(z1)) 100 | distill_z2 = F.normalize(self.distill_predictor(z2)) 101 | distill_p1 = self.distill_prototypes(distill_z1) / self.distill_temperature 102 | distill_p2 = self.distill_prototypes(distill_z2) / self.distill_temperature 103 | 104 | distill_loss = ( 105 | cross_entropy(distill_p1, frozen_p1) + cross_entropy(distill_p2, frozen_p2) 106 | ) / 2 107 | 108 | self.log("train_knowledge_distill_loss", distill_loss, on_epoch=True, sync_dist=True) 109 | 110 | return out["loss"] + self.distill_lamb * distill_loss 111 | 112 | return KnowledgeDistillWrapper 113 | -------------------------------------------------------------------------------- /cassle/methods/wmse.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Sequence 2 | 3 | import torch 4 | import torch.nn as nn 5 | from cassle.losses.wmse import wmse_loss_func 6 | from cassle.methods.base import BaseModel 7 | from cassle.utils.whitening import Whitening2d 8 | 9 | 10 | class WMSE(BaseModel): 11 | def __init__( 12 | self, 13 | output_dim: int, 14 | proj_hidden_dim: int, 15 | whitening_iters: int, 16 | whitening_size: int, 17 | whitening_eps: float, 18 | **kwargs 19 | ): 20 | """Implements W-MSE (https://arxiv.org/abs/2007.06346) 21 | 22 | Args: 23 | output_dim (int): number of dimensions of the projected features. 24 | proj_hidden_dim (int): number of neurons in the hidden layers of the projector. 25 | whitening_iters (int): number of times to perform whitening. 26 | whitening_size (int): size of the batch slice for whitening. 27 | whitening_eps (float): epsilon for numerical stability in whitening. 
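        Note (added for clarity): whitening_size must not exceed the training batch size,
        since whitening is applied to random slices of the batch; this is enforced by an
        assert a few lines below.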
28 | """ 29 | 30 | super().__init__(**kwargs) 31 | 32 | self.whitening_iters = whitening_iters 33 | self.whitening_size = whitening_size 34 | 35 | assert self.whitening_size <= self.batch_size 36 | 37 | # projector 38 | self.projector = nn.Sequential( 39 | nn.Linear(self.features_dim, proj_hidden_dim), 40 | nn.BatchNorm1d(proj_hidden_dim), 41 | nn.ReLU(), 42 | nn.Linear(proj_hidden_dim, output_dim), 43 | ) 44 | 45 | self.whitening = Whitening2d(output_dim, eps=whitening_eps) 46 | 47 | @staticmethod 48 | def add_model_specific_args(parent_parser): 49 | parent_parser = super(WMSE, WMSE).add_model_specific_args(parent_parser) 50 | parser = parent_parser.add_argument_group("simclr") 51 | 52 | # projector 53 | parser.add_argument("--output_dim", type=int, default=128) 54 | parser.add_argument("--proj_hidden_dim", type=int, default=1024) 55 | 56 | # wmse 57 | parser.add_argument("--whitening_iters", type=int, default=1) 58 | parser.add_argument("--whitening_size", type=int, default=256) 59 | parser.add_argument("--whitening_eps", type=float, default=0) 60 | 61 | return parent_parser 62 | 63 | @property 64 | def learnable_params(self) -> List[Dict]: 65 | """Adds projector parameters to the parent's learnable parameters. 66 | 67 | Returns: 68 | List[dict]: list of learnable parameters. 69 | """ 70 | 71 | extra_learnable_params = [{"params": self.projector.parameters()}] 72 | return super().learnable_params + extra_learnable_params 73 | 74 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 75 | """Performs the forward pass of the encoder and the projector. 76 | 77 | Args: 78 | X (torch.Tensor): a batch of images in the tensor format. 79 | 80 | Returns: 81 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features. 82 | """ 83 | 84 | out = super().forward(X, *args, **kwargs) 85 | v = self.projector(out["feats"]) 86 | return {**out, "v": v} 87 | 88 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 89 | """Training step for W-MSE reusing BaseModel training step. 
90 | 91 | Args: 92 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 93 | [X] is a list of size self.num_crops containing batches of images 94 | batch_idx (int): index of the batch 95 | 96 | Returns: 97 | torch.Tensor: total loss composed of W-MSE loss and classification loss 98 | """ 99 | 100 | out = super().training_step(batch, batch_idx) 101 | class_loss = out["loss"] 102 | feats = out["feats"] 103 | 104 | v = torch.cat([self.projector(f) for f in feats]) 105 | 106 | # ------- wmse loss ------- 107 | bs = self.batch_size 108 | num_losses, wmse_loss = 0, 0 109 | for _ in range(self.whitening_iters): 110 | z = torch.empty_like(v) 111 | perm = torch.randperm(bs).view(-1, self.whitening_size) 112 | for idx in perm: 113 | for i in range(self.num_crops): 114 | z[idx + i * bs] = self.whitening(v[idx + i * bs]).type_as(z) 115 | for i in range(self.num_crops - 1): 116 | for j in range(i + 1, self.num_crops): 117 | x0 = z[i * bs : (i + 1) * bs] 118 | x1 = z[j * bs : (j + 1) * bs] 119 | wmse_loss += wmse_loss_func(x0, x1) 120 | num_losses += 1 121 | wmse_loss /= num_losses 122 | 123 | self.log("train_neg_cos_sim", wmse_loss, on_epoch=True, sync_dist=True) 124 | 125 | return wmse_loss + class_loss 126 | -------------------------------------------------------------------------------- /bash_files/linear/domainnet/domain/byol_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 3.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 8 \ 17 | --dali \ 18 | --name byol-domainnet_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 3.0 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 40 | --batch_size 256 \ 41 | --num_workers 8 \ 42 | --dali \ 43 | --name byol-domainnet_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 3.0 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 8 \ 67 | --dali \ 68 | --name byol-domainnet_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | 
--gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 3.0 \ 88 | --lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 90 | --batch_size 256 \ 91 | --num_workers 8 \ 92 | --dali \ 93 | --name byol-domainnet_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 3.0 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 8 \ 117 | --dali \ 118 | --name byol-domainnet_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | --precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 3.0 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 8 \ 142 | --dali \ 143 | --name byol-domainnet_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 3.0 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 8 \ 167 | --dali \ 168 | --name byol-domainnet_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint -------------------------------------------------------------------------------- /bash_files/linear/domainnet/domain/barlow_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.1 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 8 \ 17 | --dali \ 18 | --name barlow-domain_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 0.1 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 
40 | --batch_size 256 \ 41 | --num_workers 8 \ 42 | --dali \ 43 | --name barlow-domain_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 0.1 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 8 \ 67 | --dali \ 68 | --name barlow-domain_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | --gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 0.1 \ 88 | --lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 90 | --batch_size 256 \ 91 | --num_workers 8 \ 92 | --dali \ 93 | --name barlow-domain_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 0.1 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 8 \ 117 | --dali \ 118 | --name barlow-domain_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | --precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 0.1 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 8 \ 142 | --dali \ 143 | --name barlow-domain_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 0.1 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 8 \ 167 | --dali \ 168 | --name barlow-domain_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint 174 | 175 | 
-------------------------------------------------------------------------------- /bash_files/linear/domainnet/domain/swav_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.15 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 7 \ 17 | --dali \ 18 | --name swav-domainnet_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 0.15 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 40 | --batch_size 256 \ 41 | --num_workers 7 \ 42 | --dali \ 43 | --name swav-domainnet_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 0.15 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 7 \ 67 | --dali \ 68 | --name swav-domainnet_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | --gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 0.15 \ 88 | --lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 90 | --batch_size 256 \ 91 | --num_workers 7 \ 92 | --dali \ 93 | --name swav-domainnet_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 0.15 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 7 \ 117 | --dali \ 118 | --name swav-domainnet_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 
131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | --precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 0.15 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 7 \ 142 | --dali \ 143 | --name swav-domainnet_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 0.15 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 7 \ 167 | --dali \ 168 | --name swav-domainnet_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint 174 | -------------------------------------------------------------------------------- /bash_files/linear/domainnet/domain/vicreg_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 0.3 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 7 \ 17 | --dali \ 18 | --name vicreg-domainnet_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 0.3 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 40 | --batch_size 256 \ 41 | --num_workers 7 \ 42 | --dali \ 43 | --name vicreg-domainnet_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 0.3 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 7 \ 67 | --dali \ 68 | --name vicreg-domainnet_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | --gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 0.3 \ 88 | 
--lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 90 | --batch_size 256 \ 91 | --num_workers 7 \ 92 | --dali \ 93 | --name vicreg-domainnet_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 0.3 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 7 \ 117 | --dali \ 118 | --name vicreg-domainnet_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | --precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 0.3 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 7 \ 142 | --dali \ 143 | --name vicreg-domainnet_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 0.3 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 7 \ 167 | --dali \ 168 | --name vicreg-domainnet_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint -------------------------------------------------------------------------------- /bash_files/linear/domainnet/domain/supcon_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 1.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 7 \ 17 | --dali \ 18 | --name supcon-domainnet_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 1.0 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 40 | --batch_size 256 \ 41 | --num_workers 7 \ 42 | --dali \ 43 | --name 
supcon-domainnet_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 1.0 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 7 \ 67 | --dali \ 68 | --name supcon-domainnet_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | --gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 1.0 \ 88 | --lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 90 | --batch_size 256 \ 91 | --num_workers 7 \ 92 | --dali \ 93 | --name supcon-domainnet_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 1.0 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 7 \ 117 | --dali \ 118 | --name supcon-domainnet_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | --precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 1.0 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 7 \ 142 | --dali \ 143 | --name supcon-domainnet_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 1.0 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 7 \ 167 | --dali \ 168 | --name supcon-domainnet_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint 174 | -------------------------------------------------------------------------------- 
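Note: the seven per-domain invocations in each of these DomainNet linear-eval scripts differ only in the --domain flag and the run --name. The following is a minimal Python sketch of the same launch loop and is not part of the repository: the launch_linear_eval helper, its method/lr parameters, and the use of subprocess are illustrative assumptions, while every flag value is copied from supcon_linear.sh above (DATA_DIR and PRETRAINED_PATH are expected as environment variables, exactly as in the scripts; worker counts vary slightly between methods in the original scripts).

import os
import subprocess

# None stands for the "all domains" run, which omits the --domain flag.
DOMAINS = [None, "quickdraw", "clipart", "infograph", "painting", "real", "sketch"]


def launch_linear_eval(method: str = "supcon", lr: float = 1.0) -> None:
    """Hypothetical helper mirroring the per-domain blocks of the linear-eval scripts."""
    for domain in DOMAINS:
        tag = domain if domain is not None else "all"
        cmd = [
            "python3", "main_linear.py",
            "--dataset", "domainnet",
            "--encoder", "resnet18",
            "--data_dir", os.path.join(os.environ["DATA_DIR"], "domainnet"),
            "--split_strategy", "domain",
            "--max_epochs", "100",
            "--gpus", "0",
            "--precision", "16",
            "--optimizer", "sgd",
            "--scheduler", "step",
            "--lr", str(lr),
            "--lr_decay_steps", "60", "80",
            "--weight_decay", "0",
            "--batch_size", "256",
            "--num_workers", "7",
            "--dali",
            "--name", f"{method}-domainnet_{tag}-linear-eval",
            "--pretrained_feature_extractor", os.environ["PRETRAINED_PATH"],
            "--project", "ever-learn",
            "--entity", "unitn-mhug",
            "--wandb",
            "--save_checkpoint",
        ]
        if domain is not None:
            cmd += ["--domain", domain]
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Same settings as supcon_linear.sh above.
    launch_linear_eval("supcon", 1.0)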
/bash_files/linear/domainnet/domain/simclr_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 1.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 7 \ 17 | --dali \ 18 | --name simclr-domainnet_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 1.0 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 40 | --batch_size 256 \ 41 | --num_workers 7 \ 42 | --dali \ 43 | --name simclr-domainnet_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 1.0 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 7 \ 67 | --dali \ 68 | --name simclr-domainnet_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | --gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 1.0 \ 88 | --lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 90 | --batch_size 256 \ 91 | --num_workers 7 \ 92 | --dali \ 93 | --name simclr-domainnet_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 1.0 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 7 \ 117 | --dali \ 118 | --name simclr-domainnet_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | 
--precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 1.0 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 7 \ 142 | --dali \ 143 | --name simclr-domainnet_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 1.0 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 7 \ 167 | --dali \ 168 | --name simclr-domainnet_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint 174 | 175 | -------------------------------------------------------------------------------- /bash_files/linear/domainnet/domain/mocov2plus_linear.sh: -------------------------------------------------------------------------------- 1 | # all 2 | python3 main_linear.py \ 3 | --dataset domainnet \ 4 | --encoder resnet18 \ 5 | --data_dir $DATA_DIR/domainnet \ 6 | --split_strategy domain \ 7 | --max_epochs 100 \ 8 | --gpus 0 \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 3.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --dali \ 18 | --name mocov2plus-domainnet_all-linear-eval \ 19 | --pretrained_feature_extractor $PRETRAINED_PATH \ 20 | --project ever-learn \ 21 | --entity unitn-mhug \ 22 | --wandb \ 23 | --save_checkpoint 24 | 25 | # quickdraw 26 | python3 main_linear.py \ 27 | --dataset domainnet \ 28 | --encoder resnet18 \ 29 | --data_dir $DATA_DIR/domainnet \ 30 | --split_strategy domain \ 31 | --domain quickdraw \ 32 | --max_epochs 100 \ 33 | --gpus 0 \ 34 | --precision 16 \ 35 | --optimizer sgd \ 36 | --scheduler step \ 37 | --lr 3.0 \ 38 | --lr_decay_steps 60 80 \ 39 | --weight_decay 0 \ 40 | --batch_size 256 \ 41 | --num_workers 10 \ 42 | --dali \ 43 | --name mocov2plus-domainnet_quickdraw-linear-eval \ 44 | --pretrained_feature_extractor $PRETRAINED_PATH \ 45 | --project ever-learn \ 46 | --entity unitn-mhug \ 47 | --wandb \ 48 | --save_checkpoint 49 | 50 | # clipart 51 | python3 main_linear.py \ 52 | --dataset domainnet \ 53 | --encoder resnet18 \ 54 | --data_dir $DATA_DIR/domainnet \ 55 | --split_strategy domain \ 56 | --domain clipart \ 57 | --max_epochs 100 \ 58 | --gpus 0 \ 59 | --precision 16 \ 60 | --optimizer sgd \ 61 | --scheduler step \ 62 | --lr 3.0 \ 63 | --lr_decay_steps 60 80 \ 64 | --weight_decay 0 \ 65 | --batch_size 256 \ 66 | --num_workers 10 \ 67 | --dali \ 68 | --name mocov2plus-domainnet_clipart-linear-eval \ 69 | --pretrained_feature_extractor $PRETRAINED_PATH \ 70 | --project ever-learn \ 71 | --entity unitn-mhug \ 72 | --wandb \ 73 | --save_checkpoint 74 | 75 | # infograph 76 | python3 main_linear.py \ 77 | --dataset domainnet \ 78 | --encoder resnet18 \ 79 | --data_dir $DATA_DIR/domainnet \ 80 | --split_strategy domain \ 81 | --domain infograph \ 82 | --max_epochs 100 \ 83 | --gpus 0 \ 84 | --precision 16 \ 85 | --optimizer sgd \ 86 | --scheduler step \ 87 | --lr 3.0 \ 88 | --lr_decay_steps 60 80 \ 89 | --weight_decay 0 \ 
90 | --batch_size 256 \ 91 | --num_workers 10 \ 92 | --dali \ 93 | --name mocov2plus-domainnet_infograph-linear-eval \ 94 | --pretrained_feature_extractor $PRETRAINED_PATH \ 95 | --project ever-learn \ 96 | --entity unitn-mhug \ 97 | --wandb \ 98 | --save_checkpoint 99 | 100 | # painting 101 | python3 main_linear.py \ 102 | --dataset domainnet \ 103 | --encoder resnet18 \ 104 | --data_dir $DATA_DIR/domainnet \ 105 | --split_strategy domain \ 106 | --domain painting \ 107 | --max_epochs 100 \ 108 | --gpus 0 \ 109 | --precision 16 \ 110 | --optimizer sgd \ 111 | --scheduler step \ 112 | --lr 3.0 \ 113 | --lr_decay_steps 60 80 \ 114 | --weight_decay 0 \ 115 | --batch_size 256 \ 116 | --num_workers 10 \ 117 | --dali \ 118 | --name mocov2plus-domainnet_painting-linear-eval \ 119 | --pretrained_feature_extractor $PRETRAINED_PATH \ 120 | --project ever-learn \ 121 | --entity unitn-mhug \ 122 | --wandb \ 123 | --save_checkpoint 124 | 125 | # real 126 | python3 main_linear.py \ 127 | --dataset domainnet \ 128 | --encoder resnet18 \ 129 | --data_dir $DATA_DIR/domainnet \ 130 | --split_strategy domain \ 131 | --domain real \ 132 | --max_epochs 100 \ 133 | --gpus 0 \ 134 | --precision 16 \ 135 | --optimizer sgd \ 136 | --scheduler step \ 137 | --lr 3.0 \ 138 | --lr_decay_steps 60 80 \ 139 | --weight_decay 0 \ 140 | --batch_size 256 \ 141 | --num_workers 10 \ 142 | --dali \ 143 | --name mocov2plus-domainnet_real-linear-eval \ 144 | --pretrained_feature_extractor $PRETRAINED_PATH \ 145 | --project ever-learn \ 146 | --entity unitn-mhug \ 147 | --wandb \ 148 | --save_checkpoint 149 | 150 | # sketch 151 | python3 main_linear.py \ 152 | --dataset domainnet \ 153 | --encoder resnet18 \ 154 | --data_dir $DATA_DIR/domainnet \ 155 | --split_strategy domain \ 156 | --domain sketch \ 157 | --max_epochs 100 \ 158 | --gpus 0 \ 159 | --precision 16 \ 160 | --optimizer sgd \ 161 | --scheduler step \ 162 | --lr 3.0 \ 163 | --lr_decay_steps 60 80 \ 164 | --weight_decay 0 \ 165 | --batch_size 256 \ 166 | --num_workers 10 \ 167 | --dali \ 168 | --name mocov2plus-domainnet_sketch-linear-eval \ 169 | --pretrained_feature_extractor $PRETRAINED_PATH \ 170 | --project ever-learn \ 171 | --entity unitn-mhug \ 172 | --wandb \ 173 | --save_checkpoint -------------------------------------------------------------------------------- /cassle/methods/simsiam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any, Dict, List, Sequence 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from cassle.losses.simsiam import simsiam_loss_func 8 | from cassle.methods.base import BaseModel 9 | 10 | 11 | class SimSiam(BaseModel): 12 | def __init__( 13 | self, 14 | output_dim: int, 15 | proj_hidden_dim: int, 16 | pred_hidden_dim: int, 17 | **kwargs, 18 | ): 19 | """Implements SimSiam (https://arxiv.org/abs/2011.10566). 20 | 21 | Args: 22 | output_dim (int): number of dimensions of projected features. 23 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 24 | pred_hidden_dim (int): number of neurons of the hidden layers of the predictor. 
25 | """ 26 | 27 | super().__init__(**kwargs) 28 | 29 | # projector 30 | self.projector = nn.Sequential( 31 | nn.Linear(self.features_dim, proj_hidden_dim, bias=False), 32 | nn.BatchNorm1d(proj_hidden_dim), 33 | nn.ReLU(), 34 | nn.Linear(proj_hidden_dim, proj_hidden_dim, bias=False), 35 | nn.BatchNorm1d(proj_hidden_dim), 36 | nn.ReLU(), 37 | nn.Linear(proj_hidden_dim, output_dim), 38 | nn.BatchNorm1d(output_dim, affine=False), 39 | ) 40 | self.projector[6].bias.requires_grad = False # hack: not use bias as it is followed by BN 41 | 42 | # predictor 43 | self.predictor = nn.Sequential( 44 | nn.Linear(output_dim, pred_hidden_dim, bias=False), 45 | nn.BatchNorm1d(pred_hidden_dim), 46 | nn.ReLU(), 47 | nn.Linear(pred_hidden_dim, output_dim), 48 | ) 49 | 50 | @staticmethod 51 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 52 | parent_parser = super(SimSiam, SimSiam).add_model_specific_args(parent_parser) 53 | parser = parent_parser.add_argument_group("simsiam") 54 | 55 | # projector 56 | parser.add_argument("--output_dim", type=int, default=128) 57 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 58 | 59 | # predictor 60 | parser.add_argument("--pred_hidden_dim", type=int, default=512) 61 | return parent_parser 62 | 63 | @property 64 | def learnable_params(self) -> List[dict]: 65 | """Adds projector and predictor parameters to the parent's learnable parameters. 66 | 67 | Returns: 68 | List[dict]: list of learnable parameters. 69 | """ 70 | 71 | extra_learnable_params: List[dict] = [ 72 | {"params": self.projector.parameters()}, 73 | {"params": self.predictor.parameters(), "static_lr": True}, 74 | ] 75 | return super().learnable_params + extra_learnable_params 76 | 77 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 78 | """Performs the forward pass of the encoder, the projector and the predictor. 79 | 80 | Args: 81 | X (torch.Tensor): a batch of images in the tensor format. 82 | 83 | Returns: 84 | Dict[str, Any]: 85 | a dict containing the outputs of the parent 86 | and the projected and predicted features. 87 | """ 88 | 89 | out = super().forward(X, *args, **kwargs) 90 | z = self.projector(out["feats"]) 91 | p = self.predictor(z) 92 | return {**out, "z": z, "p": p} 93 | 94 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 95 | """Training step for SimSiam reusing BaseModel training step. 
96 | 97 | Args: 98 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 99 | [X] is a list of size self.num_crops containing batches of images 100 | batch_idx (int): index of the batch 101 | 102 | Returns: 103 | torch.Tensor: total loss composed of SimSiam loss and classification loss 104 | """ 105 | 106 | out = super().training_step(batch, batch_idx) 107 | feats1, feats2 = out["feats"] 108 | 109 | z1 = self.projector(feats1) 110 | z2 = self.projector(feats2) 111 | 112 | p1 = self.predictor(z1) 113 | p2 = self.predictor(z2) 114 | 115 | # ------- contrastive loss ------- 116 | neg_cos_sim = simsiam_loss_func(p1, z2) / 2 + simsiam_loss_func(p2, z1) / 2 117 | 118 | # calculate std of features 119 | z1_std = F.normalize(z1, dim=-1).std(dim=0).mean() 120 | z2_std = F.normalize(z2, dim=-1).std(dim=0).mean() 121 | z_std = (z1_std + z2_std) / 2 122 | 123 | metrics = { 124 | "train_neg_cos_sim": neg_cos_sim, 125 | "train_z_std": z_std, 126 | } 127 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 128 | 129 | out.update({"loss": out["loss"] + neg_cos_sim, "z": [z1, z2]}) 130 | return out 131 | -------------------------------------------------------------------------------- /cassle/losses/simclr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional 4 | 5 | 6 | def simclr_distill_loss_func( 7 | p1: torch.Tensor, 8 | p2: torch.Tensor, 9 | z1: torch.Tensor, 10 | z2: torch.Tensor, 11 | temperature: float = 0.1, 12 | ) -> torch.Tensor: 13 | 14 | device = z1.device 15 | 16 | b = z1.size(0) 17 | 18 | p = F.normalize(torch.cat([p1, p2]), dim=-1) 19 | z = F.normalize(torch.cat([z1, z2]), dim=-1) 20 | 21 | logits = torch.einsum("if, jf -> ij", p, z) / temperature 22 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 23 | logits = logits - logits_max.detach() 24 | 25 | # positive mask are matches i, j (i from aug1, j from aug2), where i == j and matches j, i 26 | pos_mask = torch.zeros((2 * b, 2 * b), dtype=torch.bool, device=device) 27 | pos_mask.fill_diagonal_(True) 28 | 29 | # all matches excluding the main diagonal 30 | logit_mask = torch.ones_like(pos_mask, device=device) 31 | logit_mask.fill_diagonal_(True) 32 | logit_mask[:, b:].fill_diagonal_(True) 33 | logit_mask[b:, :].fill_diagonal_(True) 34 | 35 | exp_logits = torch.exp(logits) * logit_mask 36 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 37 | 38 | # compute mean of log-likelihood over positives 39 | mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1) 40 | # loss 41 | loss = -mean_log_prob_pos.mean() 42 | return loss 43 | 44 | 45 | def simclr_loss_func( 46 | z1: torch.Tensor, 47 | z2: torch.Tensor, 48 | temperature: float = 0.1, 49 | extra_pos_mask: Optional[torch.Tensor] = None, 50 | ) -> torch.Tensor: 51 | """Computes SimCLR's loss given batch of projected features z1 from view 1 and 52 | projected features z2 from view 2. 53 | 54 | Args: 55 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 56 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 57 | temperature (float): temperature factor for the loss. Defaults to 0.1. 58 | extra_pos_mask (Optional[torch.Tensor]): boolean mask containing extra positives other 59 | than normal across-view positives. Defaults to None. 60 | 61 | Returns: 62 | torch.Tensor: SimCLR loss. 
63 | """ 64 | 65 | device = z1.device 66 | 67 | b = z1.size(0) 68 | z = torch.cat((z1, z2), dim=0) 69 | z = F.normalize(z, dim=-1) 70 | 71 | logits = torch.einsum("if, jf -> ij", z, z) / temperature 72 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 73 | logits = logits - logits_max.detach() 74 | 75 | # positive mask are matches i, j (i from aug1, j from aug2), where i == j and matches j, i 76 | pos_mask = torch.zeros((2 * b, 2 * b), dtype=torch.bool, device=device) 77 | pos_mask[:, b:].fill_diagonal_(True) 78 | pos_mask[b:, :].fill_diagonal_(True) 79 | 80 | # if we have extra "positives" 81 | if extra_pos_mask is not None: 82 | pos_mask = torch.bitwise_or(pos_mask, extra_pos_mask) 83 | 84 | # all matches excluding the main diagonal 85 | logit_mask = torch.ones_like(pos_mask, device=device).fill_diagonal_(0) 86 | 87 | exp_logits = torch.exp(logits) * logit_mask 88 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 89 | 90 | # compute mean of log-likelihood over positives 91 | mean_log_prob_pos = (pos_mask * log_prob).sum(1) / pos_mask.sum(1) 92 | # loss 93 | loss = -mean_log_prob_pos.mean() 94 | return loss 95 | 96 | 97 | def manual_simclr_loss_func( 98 | z: torch.Tensor, pos_mask: torch.Tensor, neg_mask: torch.Tensor, temperature: float = 0.1 99 | ) -> torch.Tensor: 100 | """Manually computes SimCLR's loss given batch of projected features z 101 | from different views, a positive boolean mask of all positives and 102 | a negative boolean mask of all negatives. 103 | 104 | Args: 105 | z (torch.Tensor): NxViewsxD Tensor containing projected features from the views. 106 | pos_mask (torch.Tensor): boolean mask containing all positives for z * z.T. 107 | neg_mask (torch.Tensor): boolean mask containing all negatives for z * z.T. 108 | temperature (float): temperature factor for the loss. 109 | 110 | Return: 111 | torch.Tensor: manual SimCLR loss. 112 | """ 113 | 114 | z = F.normalize(z, dim=-1) 115 | 116 | logits = torch.einsum("if, jf -> ij", z, z) / temperature 117 | logits_max, _ = torch.max(logits, dim=1, keepdim=True) 118 | logits = logits - logits_max.detach() 119 | 120 | negatives = torch.sum(torch.exp(logits) * neg_mask, dim=1, keepdim=True) 121 | exp_logits = torch.exp(logits) 122 | log_prob = torch.log(exp_logits / (exp_logits + negatives)) 123 | 124 | # compute mean of log-likelihood over positive 125 | mean_log_prob_pos = (pos_mask * log_prob).sum(1) 126 | 127 | indexes = pos_mask.sum(1) > 0 128 | pos_mask = pos_mask[indexes] 129 | mean_log_prob_pos = mean_log_prob_pos[indexes] / pos_mask.sum(1) 130 | 131 | # loss 132 | loss = -mean_log_prob_pos.mean() 133 | return loss 134 | --------------------------------------------------------------------------------