├── solo ├── models │ ├── __init__.py │ ├── model_with_linear.py │ └── wide_resnet.py ├── args │ ├── __init__.py │ ├── dataset.py │ ├── setup.py │ └── utils.py ├── __init__.py ├── methods │ ├── __init__.py │ ├── dali.py │ └── mocov2_distillation_AT.py ├── losses │ ├── wmse.py │ ├── byol.py │ ├── simsiam.py │ ├── nnclr.py │ ├── __init__.py │ ├── moco.py │ ├── swav.py │ ├── deepclusterv2.py │ ├── ressl.py │ ├── barlow.py │ ├── simclr.py │ ├── vibcreg.py │ ├── vicreg.py │ └── dino.py └── utils │ ├── __init__.py │ ├── metrics.py │ ├── auto_resumer.py │ ├── momentum.py │ ├── sinkhorn_knopp.py │ ├── lars.py │ ├── checkpointer.py │ ├── misc.py │ ├── kmeans.py │ ├── knn.py │ ├── whitening.py │ ├── classification_dataloader.py │ ├── classification_dataloader_AdvTraining.py │ └── auto_umap.py ├── TeacherCKPT └── README.md ├── figure └── DeACL.png ├── .gitignore ├── requirements.txt ├── bash_files ├── eval_cifar10_resnet18.sh ├── DeACL_cifar100_resnet18.sh ├── DeACL_cifar10_resnet18.sh └── DeACL_cifar10_resnet50.sh ├── trades └── trades.py ├── README.md └── main_pretrain_AdvTraining.py /solo/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /TeacherCKPT/README.md: -------------------------------------------------------------------------------- 1 | Put the teacher model ckpt here. 
-------------------------------------------------------------------------------- /figure/DeACL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pantheon5100/DeACL/HEAD/figure/DeACL.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | 3 | 4 | wandb/ 5 | code/ 6 | trained_models/ 7 | ckeckpoint/ 8 | logger/ 9 | data/ 10 | 11 | checkpoint/ 12 | 13 | /TeacherCKPT/*.ckpt 14 | 15 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | pytorch-lightning==1.5.3 3 | torchmetrics==0.6.0 4 | lightning-bolts>=0.4.0 5 | tqdm 6 | wandb 7 | scipy 8 | timm==0.5.4 9 | setuptools==59.5.0 10 | -------------------------------------------------------------------------------- /bash_files/eval_cifar10_resnet18.sh: -------------------------------------------------------------------------------- 1 | # For SLF 2 | # python adv_finetune.py \ 3 | # --ckpt DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt \ 4 | # --mode slf 5 | 6 | # For AFF 7 | # python adv_finetune.py \ 8 | # --ckpt DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt \ 9 | # --mode aff 10 | 11 | # For ALF 12 | # python adv_finetune.py \ 13 | # --ckpt DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt \ 14 | # --mode alf 15 | 16 | 17 | CKPT="trained_models/simclr/DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt" 18 | # run the SLF evaluation 5 times; a plain word list (not the bash-only {1..5}) so it also loops 5 times under POSIX sh 19 | for run in 1 2 3 4 5 20 | do 21 | CUDA_VISIBLE_DEVICES=1 python adv_finetune.py \ 22 | --ckpt "$CKPT" \ 23 | --mode slf \ 24 | --learning_rate 0.1 25 | done 26 | -------------------------------------------------------------------------------- /bash_files/DeACL_cifar100_resnet18.sh:
-------------------------------------------------------------------------------- 1 | # NOTE(review): --dataset is cifar100 but --name and --distillation_teacher both say cifar10; confirm a CIFAR-10 teacher is intended for this run. 2 | python3 main_pretrain_AdvTraining.py \ 3 | --dataset cifar100 \ 4 | --backbone resnet18 \ 5 | --data_dir ./data \ 6 | --max_epochs 100 \ 7 | --gpus 1 \ 8 | --accelerator gpu \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler warmup_cosine \ 12 | --lr 0.5 \ 13 | --classifier_lr 0.5 \ 14 | --weight_decay 5e-4 \ 15 | --batch_size 256 \ 16 | --num_workers 4 \ 17 | --brightness 0.4 \ 18 | --contrast 0.4 \ 19 | --saturation 0.4 \ 20 | --hue 0.1 \ 21 | --gaussian_prob 0.0 0.0 \ 22 | --crop_size 32 \ 23 | --num_crops_per_aug 1 1 \ 24 | --name "res18_simclr-cifar10" \ 25 | --save_checkpoint \ 26 | --method mocov2_kd_at \ 27 | --limit_val_batches 0.2 \ 28 | --distillation_teacher "simclr_cifar10" \ 29 | --trades_k 2 30 | -------------------------------------------------------------------------------- /solo/models/model_with_linear.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class ModelwithLinear(nn.Module): 5 | def __init__(self, model, inplanes, num_classes=10): 6 | super(ModelwithLinear, self).__init__() 7 | self.model = model 8 | # the classifier consumes model(img) directly, so inplanes must equal the last dim of that output 9 | self.classifier = nn.Linear(inplanes, num_classes) 10 | 11 | def forward(self, img): 12 | x = self.model(img) 13 | out = self.classifier(x) 14 | return out 15 | 16 | class LinearClassifier(nn.Module): 17 | """Linear head mapping feat_dim features to num_classes logits. `name` is unused and kept only for call compatibility.""" 18 | 19 | def __init__(self, name='resnet50', feat_dim=512, num_classes=10): 20 | super(LinearClassifier, self).__init__() 21 | # feat_dim must be passed explicitly; it is not derived from `name` 22 | self.classifier = nn.Linear(feat_dim, num_classes) 23 | 24 | def forward(self, features): 25 | return self.classifier(features) 26 | -------------------------------------------------------------------------------- /bash_files/DeACL_cifar10_resnet18.sh: -------------------------------------------------------------------------------- 1 | export WANDB_API_KEY="${WANDB_API_KEY:-}" 2 | # keep a key the caller already exported instead of overwriting it with an empty string 3 | python3
main_pretrain_AdvTraining.py \ 4 | --dataset cifar10 \ 5 | --backbone resnet18 \ 6 | --data_dir ./data \ 7 | --max_epochs 100 \ 8 | --gpus 1 \ 9 | --accelerator gpu \ 10 | --precision 32 \ 11 | --optimizer sgd \ 12 | --scheduler warmup_cosine \ 13 | --lr 0.5 \ 14 | --classifier_lr 0.5 \ 15 | --weight_decay 5e-4 \ 16 | --batch_size 256 \ 17 | --num_workers 4 \ 18 | --brightness 0.4 \ 19 | --contrast 0.4 \ 20 | --saturation 0.4 \ 21 | --hue 0.1 \ 22 | --gaussian_prob 0.0 0.0 \ 23 | --crop_size 32 \ 24 | --num_crops_per_aug 1 1 \ 25 | --name "res18_simclr-cifar10-fp32" \ 26 | --save_checkpoint \ 27 | --method mocov2_kd_at \ 28 | --limit_val_batches 0.2 \ 29 | --distillation_teacher "simclr_cifar10" \ 30 | --trades_k 2 31 | -------------------------------------------------------------------------------- /bash_files/DeACL_cifar10_resnet50.sh: -------------------------------------------------------------------------------- 1 | python3 main_pretrain_AdvTraining.py \ 2 | --dataset cifar10 \ 3 | --backbone resnet50 \ 4 | --data_dir ./data \ 5 | --max_epochs 100 \ 6 | --gpus 4 \ 7 | --accelerator gpu \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.5 \ 12 | --classifier_lr 0.5 \ 13 | --weight_decay 5e-4 \ 14 | --batch_size 256 \ 15 | --num_workers 4 \ 16 | --brightness 0.4 \ 17 | --contrast 0.4 \ 18 | --saturation 0.4 \ 19 | --hue 0.1 \ 20 | --gaussian_prob 0.0 0.0 \ 21 | --crop_size 32 \ 22 | --num_crops_per_aug 1 1 \ 23 | --name "res50_simclr-cifar10" \ 24 | --save_checkpoint \ 25 | --method mocov2_kd_at \ 26 | --queue_size 32768 \ 27 | --temperature 0.2 \ 28 | --limit_val_batches 0.2 \ 29 | --distillation_teacher "simclr_cifar10_resnet50" \ 30 | --trades_k 2 31 | -------------------------------------------------------------------------------- /solo/args/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.args import dataset, setup, utils 21 | 22 | __all__ = ["dataset", "setup", "utils"] 23 | -------------------------------------------------------------------------------- /solo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | from solo import args, losses, methods, utils 22 | 23 | __all__ = ["args", "losses", "methods", "utils"] 24 | -------------------------------------------------------------------------------- /solo/methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.methods.base_for_adversarial_training import BaseMethod 21 | # NOTE(review): base_for_adversarial_training.py is not listed in this dump's directory tree; confirm it is committed 22 | from solo.methods.mocov2_distillation_AT import MoCoV2KDAT 23 | 24 | # Registry: method name (e.g. the value passed via --method in bash_files/*.sh) -> method class 25 | METHODS = { 26 | # base classes 27 | "base": BaseMethod, 28 | # methods 29 | "mocov2_kd_at": MoCoV2KDAT, 30 | } 31 | __all__ = [ 32 | "BaseMethod", 33 | "MoCoV2KDAT", 34 | ] 35 | # DALI support is optional: the module is exported only when it imports without ImportError 36 | try: 37 | from solo.methods import dali # noqa: F401 38 | except ImportError: 39 | pass 40 | else: 41 | __all__.append("dali") 42 | -------------------------------------------------------------------------------- /solo/losses/wmse.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def wmse_loss_func(z1: torch.Tensor, z2: torch.Tensor, simplified: bool = True) -> torch.Tensor: 25 | """Computes W-MSE's loss given two batches of whitened features z1 and z2. 26 | 27 | Args: 28 | z1 (torch.Tensor): NxD Tensor containing whitened features from view 1. 29 | z2 (torch.Tensor): NxD Tensor containing whitened features from view 2. 30 | simplified (bool): if True, use F.cosine_similarity instead of explicit L2 normalization. NOTE(review): only this path detaches z2, so both paths give the same loss value (up to eps handling) but different gradients w.r.t. z2; confirm which is intended. Defaults to True. 31 | 32 | Returns: 33 | torch.Tensor: scalar W-MSE loss, 2 - 2 * mean cosine similarity between matching rows of z1 and z2.
34 | """ 35 | # fast path: gradients are blocked through z2 36 | if simplified: 37 | return 2 - 2 * F.cosine_similarity(z1, z2.detach(), dim=-1).mean() 38 | # reference path: z2 is NOT detached here, so gradients also flow into z2 39 | z1 = F.normalize(z1, dim=-1) 40 | z2 = F.normalize(z2, dim=-1) 41 | 42 | return 2 - 2 * (z1 * z2).sum(dim=-1).mean() 43 | -------------------------------------------------------------------------------- /solo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE.
19 | 20 | from solo.utils import ( 21 | backbones, 22 | checkpointer, 23 | classification_dataloader, 24 | knn, 25 | lars, 26 | metrics, 27 | misc, 28 | momentum, 29 | pretrain_dataloader, 30 | sinkhorn_knopp, 31 | ) 32 | 33 | __all__ = [ 34 | "backbones", 35 | "classification_dataloader", 36 | "pretrain_dataloader", 37 | "checkpointer", 38 | "knn", 39 | "misc", 40 | "lars", 41 | "metrics", 42 | "momentum", 43 | "sinkhorn_knopp", 44 | ] 45 | 46 | try: 47 | from solo.utils import dali_dataloader # noqa: F401 48 | except ImportError: 49 | pass 50 | else: 51 | __all__.append("dali_dataloader") 52 | 53 | try: 54 | from solo.utils import auto_umap # noqa: F401 55 | except ImportError: 56 | pass 57 | else: 58 | __all__.append("auto_umap") 59 | -------------------------------------------------------------------------------- /solo/losses/byol.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor: 25 | """Computes BYOL's loss given batch of predicted features p and projected momentum features z. 26 | 27 | Args: 28 | p (torch.Tensor): NxD Tensor containing predicted features from view 1 29 | z (torch.Tensor): NxD Tensor containing projected momentum features from view 2 30 | simplified (bool): faster computation, but with same result. Defaults to True. 31 | 32 | Returns: 33 | torch.Tensor: BYOL's loss. 34 | """ 35 | 36 | if simplified: 37 | return 2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1).mean() 38 | 39 | p = F.normalize(p, dim=-1) 40 | z = F.normalize(z, dim=-1) 41 | 42 | return 2 - 2 * (p * z.detach()).sum(dim=1).mean() 43 | -------------------------------------------------------------------------------- /solo/losses/simsiam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def simsiam_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor: 25 | """Computes SimSiam's loss given batch of predicted features p from view 1 and 26 | a batch of projected features z from view 2. 27 | 28 | Args: 29 | p (torch.Tensor): Tensor containing predicted features from view 1. 30 | z (torch.Tensor): Tensor containing projected features from view 2. 31 | simplified (bool): faster computation, but with same result. 32 | 33 | Returns: 34 | torch.Tensor: SimSiam loss. 35 | """ 36 | 37 | if simplified: 38 | return -F.cosine_similarity(p, z.detach(), dim=-1).mean() 39 | 40 | p = F.normalize(p, dim=-1) 41 | z = F.normalize(z, dim=-1) 42 | 43 | return -(p * z.detach()).sum(dim=1).mean() 44 | -------------------------------------------------------------------------------- /solo/losses/nnclr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def nnclr_loss_func(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor: 25 | """Computes NNCLR's loss given batch of nearest-neighbors nn from view 1 and 26 | predicted features p from view 2. 27 | 28 | Args: 29 | nn (torch.Tensor): NxD Tensor containing nearest neighbors' features from view 1. 30 | p (torch.Tensor): NxD Tensor containing predicted features from view 2. 31 | temperature (float, optional): temperature of the softmax in the contrastive loss. Defaults 32 | to 0.1. 33 | 34 | Returns: 35 | torch.Tensor: scalar NNCLR loss, the cross-entropy that classifies row i of nn against all rows of p with p[i] as the positive.
36 | """ 37 | 38 | nn = F.normalize(nn, dim=-1) 39 | p = F.normalize(p, dim=-1) 40 | # both inputs are L2-normalized, so this is an NxN cosine-similarity matrix scaled by 1/temperature 41 | logits = nn @ p.T / temperature 42 | # positives sit on the diagonal: nn[i] is paired with p[i] 43 | n = p.size(0) 44 | labels = torch.arange(n, device=p.device) 45 | 46 | loss = F.cross_entropy(logits, labels) 47 | return loss 48 | -------------------------------------------------------------------------------- /solo/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE.
19 | 20 | from solo.losses.barlow import barlow_loss_func 21 | from solo.losses.byol import byol_loss_func 22 | from solo.losses.deepclusterv2 import deepclusterv2_loss_func 23 | from solo.losses.dino import DINOLoss 24 | from solo.losses.moco import moco_loss_func 25 | from solo.losses.nnclr import nnclr_loss_func 26 | from solo.losses.ressl import ressl_loss_func 27 | from solo.losses.simclr import simclr_loss_func 28 | from solo.losses.simsiam import simsiam_loss_func 29 | from solo.losses.swav import swav_loss_func 30 | from solo.losses.vibcreg import vibcreg_loss_func 31 | from solo.losses.vicreg import vicreg_loss_func 32 | from solo.losses.wmse import wmse_loss_func 33 | 34 | __all__ = [ 35 | "barlow_loss_func", 36 | "byol_loss_func", 37 | "deepclusterv2_loss_func", 38 | "DINOLoss", 39 | "moco_loss_func", 40 | "nnclr_loss_func", 41 | "ressl_loss_func", 42 | "simclr_loss_func", 43 | "simsiam_loss_func", 44 | "swav_loss_func", 45 | "vibcreg_loss_func", 46 | "vicreg_loss_func", 47 | "wmse_loss_func", 48 | ] 49 | -------------------------------------------------------------------------------- /solo/losses/moco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def moco_loss_func( 25 | query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1 26 | ) -> torch.Tensor: 27 | """Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2 and a 28 | queue of past elements. 29 | 30 | Args: 31 | query (torch.Tensor): NxD Tensor containing the queries from view 1. 32 | key (torch.Tensor): NxD Tensor containing the queries from view 2. 33 | queue (torch.Tensor): a queue of negative samples for the contrastive loss. 34 | temperature (float, optional): [description]. temperature of the softmax in the contrastive 35 | loss. Defaults to 0.1. 36 | 37 | Returns: 38 | torch.Tensor: MoCo loss. 39 | """ 40 | 41 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1) 42 | neg = torch.einsum("nc,ck->nk", [query, queue]) 43 | logits = torch.cat([pos, neg], dim=1) 44 | logits /= temperature 45 | targets = torch.zeros(query.size(0), device=query.device, dtype=torch.long) 46 | return F.cross_entropy(logits, targets) 47 | -------------------------------------------------------------------------------- /solo/losses/swav.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import List 21 | 22 | import numpy as np 23 | import torch 24 | 25 | 26 | def swav_loss_func( 27 | preds: List[torch.Tensor], assignments: List[torch.Tensor], temperature: float = 0.1 28 | ) -> torch.Tensor: 29 | """Computes SwAV's loss given list of batch predictions from multiple views 30 | and a list of cluster assignments from the same multiple views. 31 | 32 | Args: 33 | preds (torch.Tensor): list of NxC Tensors containing nearest neighbors' features from 34 | view 1. 35 | assignments (torch.Tensor): list of NxC Tensor containing predicted features from view 2. 36 | temperature (torch.Tensor): softmax temperature for the loss. Defaults to 0.1. 37 | 38 | Returns: 39 | torch.Tensor: SwAV loss. 
40 | """ 41 | 42 | losses = [] 43 | for v1 in range(len(preds)): 44 | for v2 in np.delete(np.arange(len(preds)), v1): 45 | a = assignments[v1] 46 | p = preds[v2] / temperature 47 | loss = -torch.mean(torch.sum(a * torch.log_softmax(p, dim=1), dim=1)) 48 | losses.append(loss) 49 | return sum(losses) / len(losses) 50 | -------------------------------------------------------------------------------- /solo/losses/deepclusterv2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def deepclusterv2_loss_func( 25 | outputs: torch.Tensor, assignments: torch.Tensor, temperature: float = 0.1 26 | ) -> torch.Tensor: 27 | """Computes DeepClusterV2's loss given a tensor containing logits from multiple views 28 | and a tensor containing cluster assignments from the same multiple views. 29 | 30 | Args: 31 | outputs (torch.Tensor): tensor of size PxVxNxC where P is the number of prototype 32 | layers and V is the number of views. 33 | assignments (torch.Tensor): tensor of size PxVxNxC containing the assignments 34 | generated using k-means. 35 | temperature (float, optional): softmax temperature for the loss. Defaults to 0.1. 36 | 37 | Returns: 38 | torch.Tensor: DeepClusterV2 loss. 39 | """ 40 | loss = 0 41 | for h in range(outputs.size(0)): 42 | scores = outputs[h].view(-1, outputs.size(-1)) / temperature 43 | targets = assignments[h].repeat(outputs.size(1)).to(outputs.device, non_blocking=True) 44 | loss += F.cross_entropy(scores, targets, ignore_index=-1) 45 | return loss / outputs.size(0) 46 | -------------------------------------------------------------------------------- /solo/losses/ressl.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
import torch.nn.functional as F


def ressl_loss_func(
    q: torch.Tensor,
    k: torch.Tensor,
    queue: torch.Tensor,
    temperature_q: float = 0.1,
    temperature_k: float = 0.04,
) -> torch.Tensor:
    """Computes ReSSL's loss given a batch of queries from view 1, a batch of keys from view 2 and a
    queue of past elements.

    The key-to-queue similarities, sharpened with ``temperature_k`` and detached, form the
    target distribution; the loss is the cross-entropy between it and the softmax of the
    query-to-queue similarities at ``temperature_q``, averaged over the batch.

    Args:
        q (torch.Tensor): NxD Tensor containing the queries from view 1.
        k (torch.Tensor): NxD Tensor containing the keys from view 2. No gradient flows
            through it.
        queue (torch.Tensor): KxD Tensor containing a queue of past elements.
        temperature_q (float, optional): temperature of the softmax for the query.
            Defaults to 0.1.
        temperature_k (float, optional): temperature of the softmax for the key.
            Defaults to 0.04.

    Returns:
        torch.Tensor: ReSSL loss.
    """

    logits_q = torch.einsum("nc,kc->nk", [q, queue])
    logits_k = torch.einsum("nc,kc->nk", [k, queue])

    # the key distribution is a fixed target: stop its gradient
    targets = F.softmax(logits_k.detach() / temperature_k, dim=1)
    log_probs = F.log_softmax(logits_q / temperature_q, dim=1)

    loss = -torch.sum(targets * log_probs, dim=1).mean()
    return loss


# --------------------------------------------------------------------------------
# /solo/losses/barlow.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch

import torch.distributed as dist


def barlow_loss_func(
    z1: torch.Tensor, z2: torch.Tensor, lamb: float = 5e-3, scale_loss: float = 0.025
) -> torch.Tensor:
    """Computes Barlow Twins' loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    Both views are batch-standardized, their DxD cross-correlation matrix is built (and
    averaged over processes when torch.distributed is initialized), and its squared
    distance to the identity is summed, with the off-diagonal terms scaled by ``lamb``.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        lamb (float, optional): off-diagonal scaling factor for the cross-covariance matrix.
            Defaults to 5e-3.
        scale_loss (float, optional): final scaling factor of the loss. Defaults to 0.025.

    Returns:
        torch.Tensor: Barlow Twins' loss.
    """

    batch_size, feat_dim = z1.size()

    # parameter-free batch standardization (to match the original code); a freshly built
    # BatchNorm1d is in training mode, so each call uses that view's own batch statistics
    standardize = torch.nn.BatchNorm1d(feat_dim, affine=False).to(z1.device)
    z1_bn = standardize(z1)
    z2_bn = standardize(z2)

    cross_corr = torch.einsum("bi, bj -> ij", z1_bn, z2_bn) / batch_size

    # average the cross-correlation matrix over all processes when running distributed
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(cross_corr)
        cross_corr /= dist.get_world_size()

    identity = torch.eye(feat_dim, device=cross_corr.device)
    off_diagonal = ~identity.bool()
    squared_gap = (cross_corr - identity).pow(2)
    # invariance term on the diagonal, redundancy-reduction term (scaled by lamb) off it
    weighted_gap = torch.where(off_diagonal, squared_gap * lamb, squared_gap)
    return scale_loss * weighted_gap.sum()


# --------------------------------------------------------------------------------
# /solo/losses/simclr.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import torch
import torch.nn.functional as F
from solo.utils.misc import gather, get_rank


def simclr_loss_func(
    z: torch.Tensor, indexes: torch.Tensor, temperature: float = 0.1
) -> torch.Tensor:
    """Computes SimCLR's loss given a batch of projected features z from different views
    and the identifier of each element.

    Elements sharing an identifier are positives of each other (an element is never its
    own positive); all the other elements are negatives. Features and identifiers are
    gathered with ``gather`` so the comparison also covers the other processes' batches.

    Args:
        z (torch.Tensor): (N*views) x D Tensor containing projected features from the views.
        indexes (torch.Tensor): unique identifiers for each crop (unsupervised)
            or targets of each crop (supervised).
        temperature (float, optional): temperature applied to the cosine similarities.
            Defaults to 0.1.

    Return:
        torch.Tensor: SimCLR loss.
    """

    z = F.normalize(z, dim=-1)
    gathered_z = gather(z)

    logits = torch.einsum("if, jf -> ij", z, gathered_z) / temperature

    gathered_indexes = gather(indexes)

    indexes = indexes.unsqueeze(0)
    gathered_indexes = gathered_indexes.unsqueeze(0)
    # positives
    pos_mask = indexes.t() == gathered_indexes
    # remove each element's similarity with itself; assumes gather() concatenates the
    # per-rank batches in rank order, so local rows start at column rank * batch_size
    pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0)
    # negatives
    neg_mask = indexes.t() != gathered_indexes

    # log(pos / (pos + neg)) evaluated with logsumexp over the masked logits instead of
    # exponentiating raw logits: exp(1 / temperature) (self-similarity is 1) overflows
    # float32 once temperature < ~0.0113, which turned the loss into nan
    log_pos = torch.logsumexp(logits.masked_fill(~pos_mask, float("-inf")), dim=1)
    log_all = torch.logsumexp(logits.masked_fill(~(pos_mask | neg_mask), float("-inf")), dim=1)
    loss = -torch.mean(log_pos - log_all)
    return loss


# --------------------------------------------------------------------------------
# /solo/utils/metrics.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Dict, List, Sequence

import torch


def accuracy_at_k(
    outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
) -> Sequence[int]:
    """Measures the top-k accuracy, in percent, for every k in ``top_k``.

    Args:
        outputs (torch.Tensor): classifier scores (logits or probabilities), one row per sample.
        targets (torch.Tensor): ground-truth class index of each sample.
        top_k (Sequence[int], optional): values of k to evaluate. Defaults to (1, 5).

    Returns:
        Sequence[int]: one single-element tensor per k holding the percentage of samples
        whose target is among the k highest-scoring classes.
    """

    with torch.no_grad():
        widest_k = max(top_k)
        num_samples = targets.size(0)

        # (widest_k, num_samples): row r holds the r-th best class of every sample
        ranked = outputs.topk(widest_k, 1, True, True)[1].t()
        hits = ranked.eq(targets.view(1, -1).expand_as(ranked))

        return [
            hits[:k].contiguous().view(-1).float().sum(0, keepdim=True).mul_(100.0 / num_samples)
            for k in top_k
        ]


def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
    """Averages ``key`` over a list of step outputs, weighting each entry by its batch size.

    Args:
        outputs (List[Dict]): list of dicts containing the outputs of a validation step.
        key (str): key of the metric of interest.
        batch_size_key (str): key of batch size values.

    Returns:
        float: weighted mean of the values stored under ``key`` (squeezed along dim 0).
    """

    weighted_total = 0
    total_count = 0
    for step_out in outputs:
        step_size = step_out[batch_size_key]
        weighted_total += step_size * step_out[key]
        total_count += step_size
    return (weighted_total / total_count).squeeze(0)


# --------------------------------------------------------------------------------
# /solo/losses/vibcreg.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import torch
from solo.losses.vicreg import invariance_loss, variance_loss
from torch import Tensor
from torch.nn import functional as F


def covariance_loss(z1: Tensor, z2: Tensor) -> Tensor:
    """Computes normalized covariance loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.

    Returns:
        torch.Tensor: covariance regularization loss.
    """

    norm_z1 = z1 - z1.mean(dim=0)
    norm_z2 = z2 - z2.mean(dim=0)
    norm_z1 = F.normalize(norm_z1, p=2, dim=0)  # (batch * feature); l2-norm
    norm_z2 = F.normalize(norm_z2, p=2, dim=0)
    fxf_cov_z1 = torch.mm(norm_z1.T, norm_z1)  # (feature * feature)
    fxf_cov_z2 = torch.mm(norm_z2.T, norm_z2)
    fxf_cov_z1.fill_diagonal_(0.0)
    fxf_cov_z2.fill_diagonal_(0.0)
    cov_loss = (fxf_cov_z1 ** 2).mean() + (fxf_cov_z2 ** 2).mean()
    return cov_loss


def vibcreg_loss_func(
    z1: torch.Tensor,
    z2: torch.Tensor,
    sim_loss_weight: float = 25.0,
    var_loss_weight: float = 25.0,
    cov_loss_weight: float = 200.0,
) -> torch.Tensor:
    """Computes VIbCReg's loss given batch of projected features z1 from view 1 and
    projected features z2 from view 2.

    Args:
        z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
        z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        sim_loss_weight (float): invariance loss weight.
        var_loss_weight (float): variance loss weight.
        cov_loss_weight (float): covariance loss weight.

    Returns:
        torch.Tensor: VIbCReg loss.
    """

    sim_loss = invariance_loss(z1, z2)
    var_loss = variance_loss(z1, z2)
    cov_loss = covariance_loss(z1, z2)

    loss = sim_loss_weight * sim_loss + var_loss_weight * var_loss + cov_loss_weight * cov_loss
    return loss


# --------------------------------------------------------------------------------
# /solo/utils/auto_resumer.py:
# --------------------------------------------------------------------------------
import json
import os
from argparse import ArgumentParser, Namespace
from collections import namedtuple
from datetime import datetime, timedelta
from pathlib import Path
from typing import Union

Checkpoint = namedtuple("Checkpoint", ["creation_time", "args", "checkpoint"])


class AutoResumer:
    # settings that must be identical for a stored checkpoint to be resumed
    SHOULD_MATCH = [
        "batch_size",
        "weight_decay",
        "lr",
        "dataset",
        "backbone",
        "max_epochs",
        "method",
        "name",
        "project",
        "entity",
    ]

    def __init__(
        self,
        checkpoint_dir: Union[str, Path] = Path("trained_models"),
        max_hours: int = 36,
    ):
        """Autoresumer object that automatically tries to find a checkpoint
        that is at most max_hours old.

        Args:
            checkpoint_dir (Union[str, Path], optional): base directory to store checkpoints.
                Defaults to "trained_models".
            max_hours (int): maximum elapsed hours to consider checkpoint as valid.
        """

        self.checkpoint_dir = checkpoint_dir
        self.max_hours = timedelta(hours=max_hours)

    @staticmethod
    def add_autoresumer_args(parent_parser: ArgumentParser):
        """Adds user-required arguments to a parser.

        Args:
            parent_parser (ArgumentParser): parser to add new args to.

        Returns:
            ArgumentParser: the same parser, extended with an "autoresumer" argument group.
        """

        parser = parent_parser.add_argument_group("autoresumer")
        parser.add_argument("--auto_resumer_max_hours", default=36, type=int)
        return parent_parser

    def find_checkpoint(self, args: Namespace):
        """Finds the most recent valid checkpoint whose settings match ``args``.

        Every directory under ``checkpoint_dir`` holding a ``.ckpt`` file created less than
        ``max_hours`` ago is a candidate. Its ``args.json`` is compared with ``args`` on
        every key of ``SHOULD_MATCH``. Candidates without an ``args.json``, or whose stored
        args lack one of those keys, are skipped.

        Args:
            args (Namespace): namespace object containing all settings of the model.

        Returns:
            Optional[Path]: path of the newest matching checkpoint, or None if none matches.
        """

        current_time = datetime.now()

        possible_checkpoints = []
        for rootdir, _, files in os.walk(self.checkpoint_dir):
            rootdir = Path(rootdir)
            # take the first .ckpt file of the directory; skip directories without one
            checkpoint_file = next((rootdir / f for f in files if f.endswith(".ckpt")), None)
            if checkpoint_file is None:
                continue

            creation_time = datetime.fromtimestamp(os.path.getctime(checkpoint_file))
            if current_time - creation_time < self.max_hours:
                possible_checkpoints.append(
                    Checkpoint(
                        creation_time=creation_time,
                        args=rootdir / "args.json",
                        checkpoint=checkpoint_file,
                    )
                )

        # most recent first
        possible_checkpoints.sort(key=lambda ck: ck.creation_time, reverse=True)

        missing = object()  # sentinel: never equal to any real setting
        for checkpoint in possible_checkpoints:
            # a checkpoint without its args.json cannot be validated
            if not checkpoint.args.is_file():
                continue
            with open(checkpoint.args, encoding="utf-8") as f:
                checkpoint_args = Namespace(**json.load(f))
            if all(
                getattr(checkpoint_args, param, missing) == getattr(args, param)
                for param in AutoResumer.SHOULD_MATCH
            ):
                return checkpoint.checkpoint

        return None


# --------------------------------------------------------------------------------
# /solo/utils/momentum.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import math

import torch
from torch import nn


@torch.no_grad()
def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module):
    """Turns the momentum network into a frozen copy of the online network.

    Parameters are paired positionally over ``parameters()``, copied in place, and then
    excluded from gradient computation.

    Args:
        online_net (nn.Module): online network (e.g. online backbone, online projection, etc...).
        momentum_net (nn.Module): momentum network (e.g. momentum backbone,
            momentum projection, etc...).
    """

    paired = zip(online_net.parameters(), momentum_net.parameters())
    for online_param, momentum_param in paired:
        momentum_param.data.copy_(online_param.data)
        momentum_param.requires_grad = False


class MomentumUpdater:
    def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0):
        """Keeps momentum parameters as an exponential moving average of online ones.

        Args:
            base_tau (float, optional): starting value of the moving-average coefficient
                tau (must lie in [0, 1]). Defaults to 0.996.
            final_tau (float, optional): value tau reaches at the end of the schedule
                (must lie in [base_tau, 1]). Defaults to 1.0.
        """

        super().__init__()

        # a valid schedule satisfies 0 <= base_tau <= final_tau <= 1
        assert 0 <= base_tau <= final_tau <= 1

        self.base_tau = base_tau
        self.cur_tau = base_tau
        self.final_tau = final_tau

    @torch.no_grad()
    def update(self, online_net: nn.Module, momentum_net: nn.Module):
        """Moves each momentum parameter towards its online counterpart.

        Every momentum parameter becomes ``cur_tau * momentum + (1 - cur_tau) * online``.

        Args:
            online_net (nn.Module): online network (e.g. online backbone, online projection, etc...).
            momentum_net (nn.Module): momentum network (e.g. momentum backbone,
                momentum projection, etc...).
        """

        tau = self.cur_tau
        for online_param, momentum_param in zip(online_net.parameters(), momentum_net.parameters()):
            momentum_param.data = tau * momentum_param.data + (1 - tau) * online_param.data

    def update_tau(self, cur_step: int, max_steps: int):
        """Recomputes tau with a cosine schedule going from base_tau (step 0) to final_tau
        (step max_steps).

        Args:
            cur_step (int): number of gradient steps so far.
            max_steps (int): overall number of gradient steps in the whole training.
        """

        cosine = math.cos(math.pi * cur_step / max_steps)
        self.cur_tau = self.final_tau - (self.final_tau - self.base_tau) * (cosine + 1) / 2


# --------------------------------------------------------------------------------
# /solo/utils/sinkhorn_knopp.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Adapted from https://github.com/facebookresearch/swav.

import torch
import torch.distributed as dist


class SinkhornKnopp(torch.nn.Module):
    def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1):
        """Approximates optimal transport with the Sinkhorn-Knopp algorithm.

        The matrix is pushed towards a doubly stochastic one by alternately rescaling its
        rows and its columns.

        Args:
            num_iters (int, optional): number of row/column normalization rounds.
                Defaults to 3.
            epsilon (float, optional): weight of the entropy regularization. Defaults to 0.05.
            world_size (int, optional): number of nodes for distributed training. Defaults to 1.
        """

        super().__init__()
        self.num_iters = num_iters
        self.epsilon = epsilon
        self.world_size = world_size

    @torch.no_grad()
    def forward(self, Q: torch.Tensor) -> torch.Tensor:
        """Turns sample/prototype similarities into soft assignments via Sinkhorn-Knopp.

        The similarities are exponentiated (entropy regularization), normalized to unit mass,
        and then rows and columns are rescaled alternately ``num_iters`` times. A final column
        rescaling makes every sample's weights sum to 1.

        Args:
            Q (torch.Tensor): cosine similarities between the features of the
                samples and the prototypes.

        Returns:
            torch.Tensor: assignment of samples to prototypes according to optimal transport.
        """

        is_distributed = dist.is_available() and dist.is_initialized()

        # (samples x prototypes) similarities -> (prototypes x samples) positive kernel
        Q = torch.exp(Q / self.epsilon).t()
        global_samples = Q.shape[1] * self.world_size
        num_prototypes = Q.shape[0]

        # scale the whole matrix (summed over every process) to unit mass
        total_mass = torch.sum(Q)
        if is_distributed:
            dist.all_reduce(total_mass)
        Q /= total_mass

        for _ in range(self.num_iters):
            # every prototype (row) must carry a 1/num_prototypes share of the mass
            prototype_mass = torch.sum(Q, dim=1, keepdim=True)
            if is_distributed:
                dist.all_reduce(prototype_mass)
            Q /= prototype_mass
            Q /= num_prototypes

            # every sample (column) must carry a 1/global_samples share of the mass
            Q /= torch.sum(Q, dim=0, keepdim=True)
            Q /= global_samples

        # rescale so each column sums to 1, i.e. is an assignment over prototypes
        Q *= global_samples
        return Q.t()


# --------------------------------------------------------------------------------
# /solo/args/dataset.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from argparse import ArgumentParser 21 | from pathlib import Path 22 | 23 | 24 | def dataset_args(parser: ArgumentParser): 25 | """Adds dataset-related arguments to a parser. 26 | 27 | Args: 28 | parser (ArgumentParser): parser to add dataset args to. 29 | """ 30 | 31 | SUPPORTED_DATASETS = [ 32 | "cifar10", 33 | "cifar100", 34 | "stl10", 35 | "imagenet", 36 | "imagenet100", 37 | "custom", 38 | ] 39 | 40 | parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, type=str, required=True) 41 | 42 | # dataset path 43 | parser.add_argument("--data_dir", type=Path, required=True) 44 | parser.add_argument("--train_dir", type=Path, default=None) 45 | parser.add_argument("--val_dir", type=Path, default=None) 46 | 47 | # dali (imagenet-100/imagenet/custom only) 48 | parser.add_argument("--dali", action="store_true") 49 | parser.add_argument("--dali_device", type=str, default="gpu") 50 | 51 | 52 | def augmentations_args(parser: ArgumentParser): 53 | """Adds augmentation-related arguments to a parser. 54 | 55 | Args: 56 | parser (ArgumentParser): parser to add augmentation args to. 
57 | """ 58 | 59 | # cropping 60 | parser.add_argument("--num_crops_per_aug", type=int, default=[2], nargs="+") 61 | 62 | # color jitter 63 | parser.add_argument("--brightness", type=float, required=True, nargs="+") 64 | parser.add_argument("--contrast", type=float, required=True, nargs="+") 65 | parser.add_argument("--saturation", type=float, required=True, nargs="+") 66 | parser.add_argument("--hue", type=float, required=True, nargs="+") 67 | parser.add_argument("--color_jitter_prob", type=float, default=[0.8], nargs="+") 68 | 69 | # other augmentation probabilities 70 | parser.add_argument("--gray_scale_prob", type=float, default=[0.2], nargs="+") 71 | parser.add_argument("--horizontal_flip_prob", type=float, default=[0.5], nargs="+") 72 | parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+") 73 | parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+") 74 | 75 | # cropping 76 | parser.add_argument("--crop_size", type=int, default=[224], nargs="+") 77 | parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+") 78 | parser.add_argument("--max_scale", type=float, default=[1.0], nargs="+") 79 | 80 | # debug 81 | parser.add_argument("--debug_augmentations", action="store_true") 82 | 83 | 84 | def linear_augmentations_args(parser: ArgumentParser): 85 | parser.add_argument("--crop_size", type=int, default=[224], nargs="+") 86 | 87 | 88 | def custom_dataset_args(parser: ArgumentParser): 89 | """Adds custom data-related arguments to a parser. 90 | 91 | Args: 92 | parser (ArgumentParser): parser to add augmentation args to. 
93 | """ 94 | 95 | # custom dataset only 96 | parser.add_argument("--no_labels", action="store_true") 97 | 98 | # for custom dataset 99 | parser.add_argument("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+") 100 | parser.add_argument("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+") 101 | -------------------------------------------------------------------------------- /solo/losses/vicreg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def invariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: 25 | """Computes mse loss given batch of projected features z1 from view 1 and 26 | projected features z2 from view 2. 
27 | 28 | Args: 29 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 30 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 31 | 32 | Returns: 33 | torch.Tensor: invariance loss (mean squared error). 34 | """ 35 | 36 | return F.mse_loss(z1, z2) 37 | 38 | 39 | def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: 40 | """Computes variance loss given batch of projected features z1 from view 1 and 41 | projected features z2 from view 2. 42 | 43 | Args: 44 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 45 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 46 | 47 | Returns: 48 | torch.Tensor: variance regularization loss. 49 | """ 50 | 51 | eps = 1e-4 52 | std_z1 = torch.sqrt(z1.var(dim=0) + eps) 53 | std_z2 = torch.sqrt(z2.var(dim=0) + eps) 54 | std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2)) 55 | return std_loss 56 | 57 | 58 | def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor: 59 | """Computes covariance loss given batch of projected features z1 from view 1 and 60 | projected features z2 from view 2. 61 | 62 | Args: 63 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 64 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 65 | 66 | Returns: 67 | torch.Tensor: covariance regularization loss. 
68 | """ 69 | 70 | N, D = z1.size() 71 | 72 | z1 = z1 - z1.mean(dim=0) 73 | z2 = z2 - z2.mean(dim=0) 74 | cov_z1 = (z1.T @ z1) / (N - 1) 75 | cov_z2 = (z2.T @ z2) / (N - 1) 76 | 77 | diag = torch.eye(D, device=z1.device) 78 | cov_loss = cov_z1[~diag.bool()].pow_(2).sum() / D + cov_z2[~diag.bool()].pow_(2).sum() / D 79 | return cov_loss 80 | 81 | 82 | def vicreg_loss_func( 83 | z1: torch.Tensor, 84 | z2: torch.Tensor, 85 | sim_loss_weight: float = 25.0, 86 | var_loss_weight: float = 25.0, 87 | cov_loss_weight: float = 1.0, 88 | ) -> torch.Tensor: 89 | """Computes VICReg's loss given batch of projected features z1 from view 1 and 90 | projected features z2 from view 2. 91 | 92 | Args: 93 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1. 94 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2. 95 | sim_loss_weight (float): invariance loss weight. 96 | var_loss_weight (float): variance loss weight. 97 | cov_loss_weight (float): covariance loss weight. 98 | 99 | Returns: 100 | torch.Tensor: VICReg loss. 
101 | """ 102 | 103 | sim_loss = invariance_loss(z1, z2) 104 | var_loss = variance_loss(z1, z2) 105 | cov_loss = covariance_loss(z1, z2) 106 | 107 | loss = sim_loss_weight * sim_loss + var_loss_weight * var_loss + cov_loss_weight * cov_loss 108 | return loss 109 | -------------------------------------------------------------------------------- /trades/trades.py: -------------------------------------------------------------------------------- 1 | # This code is adapted from https://github.com/yaodongyu/TRADES/blob/master/trades.py 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | import torch.optim as optim 9 | 10 | 11 | def squared_l2_norm(x): 12 | flattened = x.view(x.unsqueeze(0).shape[0], -1) 13 | return (flattened ** 2).sum(1) 14 | 15 | 16 | def l2_norm(x): 17 | return squared_l2_norm(x).sqrt() 18 | 19 | 20 | def trades_loss(model, 21 | model_backbone, 22 | model_linear, 23 | x_natural, 24 | y, 25 | optimizer, 26 | step_size=0.003, 27 | epsilon=0.031, 28 | perturb_steps=10, 29 | beta=1.0, 30 | distance='l_inf'): 31 | # define KL-loss 32 | criterion_kl = nn.KLDivLoss(size_average=False) 33 | model_backbone.eval() 34 | model_linear.eval() 35 | batch_size = len(x_natural) 36 | # generate adversarial example 37 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach() 38 | if distance == 'l_inf': 39 | for _ in range(perturb_steps): 40 | x_adv.requires_grad_() 41 | with torch.enable_grad(): 42 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1), 43 | F.softmax(model(x_natural), dim=1)) 44 | grad = torch.autograd.grad(loss_kl, [x_adv])[0] 45 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach()) 46 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon) 47 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 48 | elif distance == 'l_2': 49 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach() 50 | delta = Variable(delta.data, 
requires_grad=True) 51 | 52 | # Setup optimizers 53 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2) 54 | 55 | for _ in range(perturb_steps): 56 | adv = x_natural + delta 57 | 58 | # optimize 59 | optimizer_delta.zero_grad() 60 | with torch.enable_grad(): 61 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1), 62 | F.softmax(model(x_natural), dim=1)) 63 | loss.backward() 64 | # renorming gradient 65 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1) 66 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1)) 67 | # avoid nan or inf if gradient is 0 68 | if (grad_norms == 0).any(): 69 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0]) 70 | optimizer_delta.step() 71 | 72 | # projection 73 | delta.data.add_(x_natural) 74 | delta.data.clamp_(0, 1).sub_(x_natural) 75 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon) 76 | x_adv = Variable(x_natural + delta, requires_grad=False) 77 | else: 78 | x_adv = torch.clamp(x_adv, 0.0, 1.0) 79 | # model.train() 80 | model_backbone.train() 81 | model_linear.train() 82 | 83 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False) 84 | # zero gradient 85 | optimizer.zero_grad() 86 | # calculate robust loss 87 | logits = model(x_natural) 88 | loss_natural = F.cross_entropy(logits, y) 89 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1), 90 | F.softmax(model(x_natural), dim=1)) 91 | loss = loss_natural + beta * loss_robust 92 | return loss, logits 93 | 94 | 95 | # This function is adapted from https://github.com/yaodongyu/TRADES/blob/master/pgd_attack_cifar10.py#L54 96 | def _pgd_whitebox(model, 97 | X, 98 | y, 99 | epsilon, 100 | num_steps, 101 | step_size, 102 | random, 103 | device, 104 | ): 105 | out_clean = model(X) 106 | X_pgd = Variable(X.data, requires_grad=True) 107 | if random: 108 | random_noise = torch.FloatTensor(*X_pgd.shape).uniform_(-epsilon, epsilon).to(device) 109 | X_pgd = Variable(X_pgd.data + random_noise, 
requires_grad=True) 110 | 111 | for _ in range(num_steps): 112 | opt = optim.SGD([X_pgd], lr=1e-3) 113 | opt.zero_grad() 114 | 115 | with torch.enable_grad(): 116 | loss = nn.CrossEntropyLoss()(model(X_pgd), y) 117 | loss.backward() 118 | eta = step_size * X_pgd.grad.data.sign() 119 | X_pgd = Variable(X_pgd.data + eta, requires_grad=True) 120 | eta = torch.clamp(X_pgd.data - X.data, -epsilon, epsilon) 121 | X_pgd = Variable(X.data + eta, requires_grad=True) 122 | X_pgd = Variable(torch.clamp(X_pgd, 0, 1.0), requires_grad=True) 123 | 124 | out_pgd = model(X_pgd) 125 | return out_clean, out_pgd 126 | 127 | -------------------------------------------------------------------------------- /solo/utils/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
# Copied from Pytorch Lightning (https://github.com/PyTorchLightning/pytorch-lightning/)
# with extra documentations.
# NOTE: `step` differs from the upstream copy: a closure is evaluated *before* the LARS
# adaptation so that the adaptation is applied to the gradients the closure produces.


import torch
from torch.optim import Optimizer


class LARSWrapper:
    def __init__(
        self,
        optimizer: Optimizer,
        eta: float = 1e-3,
        clip: bool = False,
        eps: float = 1e-8,
        exclude_bias_n_norm: bool = False,
    ):
        """Wrapper that adds LARS scheduling to any optimizer.
        This helps stability with huge batch sizes.

        Args:
            optimizer (Optimizer): torch optimizer.
            eta (float, optional): trust coefficient. Defaults to 1e-3.
            clip (bool, optional): if True, the LARS learning rate is divided by the param
                group's lr and capped at 1 before scaling the gradient. Defaults to False.
            eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
            exclude_bias_n_norm (bool, optional): exclude 1-D parameters (biases and
                normalization weights) from lars. Defaults to False.
        """

        self.optim = optimizer
        self.eta = eta
        self.eps = eps
        self.clip = clip
        self.exclude_bias_n_norm = exclude_bias_n_norm

        # transfer optim methods
        self.state_dict = self.optim.state_dict
        self.load_state_dict = self.optim.load_state_dict
        self.zero_grad = self.optim.zero_grad
        self.add_param_group = self.optim.add_param_group

        self.__setstate__ = self.optim.__setstate__  # type: ignore
        self.__getstate__ = self.optim.__getstate__  # type: ignore
        self.__repr__ = self.optim.__repr__  # type: ignore

    @property
    def defaults(self):
        return self.optim.defaults

    @defaults.setter
    def defaults(self, defaults):
        self.optim.defaults = defaults

    @property  # type: ignore
    def __class__(self):
        # lets the wrapper pass `isinstance(x, Optimizer)` checks
        return Optimizer

    @property
    def state(self):
        return self.optim.state

    @state.setter
    def state(self, state):
        self.optim.state = state

    @property
    def param_groups(self):
        return self.optim.param_groups

    @param_groups.setter
    def param_groups(self, value):
        self.optim.param_groups = value

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single LARS-adapted optimization step.

        Args:
            closure (callable, optional): re-evaluates the model and returns the loss. It is
                called first, with gradients enabled, so the LARS adaptation below acts on the
                gradients it computes.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.
        """

        loss = None
        if closure is not None:
            # A closure recomputes (and usually first zeroes) the gradients. Passing it to the
            # wrapped optimizer after the adaptation would make the adaptation act on stale
            # gradients that the closure then overwrites, silently disabling LARS.
            with torch.enable_grad():
                loss = closure()

        weight_decays = []
        try:
            for group in self.optim.param_groups:
                weight_decay = group.get("weight_decay", 0)
                weight_decays.append(weight_decay)

                # reset weight decay: update_p folds it into the gradient instead
                group["weight_decay"] = 0

                # update the parameters
                for p in group["params"]:
                    if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
                        self.update_p(p, group, weight_decay)

            # update the optimizer (gradients are already computed and adapted)
            self.optim.step()
        finally:
            # return weight decay control to optimizer, even if the step failed
            for group, weight_decay in zip(self.optim.param_groups, weight_decays):
                group["weight_decay"] = weight_decay

        return loss

    def update_p(self, p, group, weight_decay):
        """Adds weight decay to ``p.grad`` and rescales it in place by the LARS learning rate.

        Parameters whose weight norm or gradient norm is zero are left untouched (in that case
        no weight decay is added either).
        """

        # calculate new norms
        p_norm = torch.norm(p.data)
        g_norm = torch.norm(p.grad.data)

        if p_norm != 0 and g_norm != 0:
            # calculate new lr
            new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)

            # clip lr
            if self.clip:
                new_lr = min(new_lr / group["lr"], 1)

            # update params with clipped lr
            p.grad.data += weight_decay * p.data
            p.grad.data *= new_lr


# -------------------------------------------------------------------------------- /solo/losses/dino.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team.
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class DINOLoss(nn.Module):
    def __init__(
        self,
        num_prototypes: int,
        warmup_teacher_temp: float,
        teacher_temp: float,
        warmup_teacher_temp_epochs: float,
        num_epochs: int,
        student_temp: float = 0.1,
        num_large_crops: int = 2,
        center_momentum: float = 0.9,
    ):
        """Auxiliary module to compute DINO's loss.

        Args:
            num_prototypes (int): number of prototypes.
            warmup_teacher_temp (float): base temperature for the temperature schedule
                of the teacher.
            teacher_temp (float): final temperature for the teacher.
            warmup_teacher_temp_epochs (float): number of epochs of the linear warm-up from
                warmup_teacher_temp to teacher_temp (truncated to an int).
            num_epochs (int): total number of epochs.
            student_temp (float, optional): temperature for the student. Defaults to 0.1.
            num_large_crops (int, optional): number of crops/views. Defaults to 2.
            center_momentum (float, optional): momentum for the EMA update of the center of
                mass of the teacher. Defaults to 0.9.
        """

        super().__init__()
        self.epoch = 0
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.num_large_crops = num_large_crops
        self.register_buffer("center", torch.zeros(1, num_prototypes))

        # np.linspace / np.ones require integer lengths, so integral floats are cast here
        warmup_epochs = int(warmup_teacher_temp_epochs)
        num_epochs = int(num_epochs)
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training instable at the beginning
        self.teacher_temp_schedule = np.concatenate(
            (
                np.linspace(warmup_teacher_temp, teacher_temp, warmup_epochs),
                np.ones(max(0, num_epochs - warmup_epochs)) * teacher_temp,
            )
        )

    def forward(self, student_output: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor:
        """Computes DINO's loss given a batch of logits of the student and a batch of logits of the
        teacher.

        Args:
            student_output (torch.Tensor): NxP Tensor containing student logits for all views.
            teacher_output (torch.Tensor): NxP Tensor containing teacher logits for all views.

        Returns:
            torch.Tensor: DINO loss (cross-entropy averaged over all teacher/student view pairs
            that come from different views).
        """

        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.num_large_crops)

        # teacher centering and sharpening; epochs beyond the end of the schedule
        # (e.g. extra or resumed epochs) reuse its last temperature instead of raising
        epoch = min(self.epoch, len(self.teacher_temp_schedule) - 1)
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        # NOTE(review): the teacher is always split into 2 chunks while the student is split
        # into num_large_crops; they agree for the default of 2 — confirm for other values.
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for iv, v in enumerate(student_out):
                if iv == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(v, dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output: torch.Tensor):
        """Updates the center for DINO's loss using exponential moving average.

        Args:
            teacher_output (torch.Tensor): NxP Tensor containing teacher logits of all views.
        """

        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        if dist.is_available() and dist.is_initialized():
            # average the per-process sums so every process sees the global batch mean
            dist.all_reduce(batch_center)
            batch_center = batch_center / dist.get_world_size()
        batch_center = batch_center / len(teacher_output)

        # ema update
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)


# -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness 2 | 3 | Chaoning Zhang*, Kang Zhang*, Chenshuang Zhang, Axi Niu, Jiu Feng, Chang D.
Yoo, In So Kweon 4 | (*Equal contribution) 5 | 6 | This is the official implementation of the paper "Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness," which was accepted for an oral presentation at ECCV 2022. 7 | 8 | The DeACL framework consists of two stages. In the first stage, DeACL performs standard self-supervised learning (SSL) to obtain a non-robust encoder. In the second stage, the pretrained encoder acts as a teacher model, generating pseudo-targets to guide supervised adversarial training (AT) on a student model. The student model, trained through these two stages, is the final model of interest. DeACL is a general framework that can be applied to any SSL method and AT method. 9 | 10 | 11 | Paper links: [arXiv](https://arxiv.org/abs/2207.10899), [ECCV2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136900716.pdf), [supplementary material](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136900716-supp.pdf) 12 | 13 | 14 | # Change log 15 | *2024.7.14* We have rewritten the SLF and AFF code to enhance its readability and usability. The new code is more modular, making it easier to extend to other datasets and models. Using this updated code, we conducted experiments with SimCLR and ResNet18 on the CIFAR10 dataset (ckpt [here](https://drive.google.com/file/d/1yc38miWGY57sHS6W6aY_k5t69Gt5v5fm/view?usp=sharing)). The code was executed five times, and the average results are reported below. The updated code can be found in the `adv_finetune.py` file. 16 | 17 | *2023.3.2* The ResNet model was defined differently for pre-training and for SLF, which made the forward and backward passes differ. As a result, our previous code could produce results different from those reported in the paper.
We fixed the bug by changing the ResNet used in the SLF setting and released a pre-trained model trained with the new code, whose performance differs slightly from that reported in the paper (SLF on CIFAR10 (AA, RA, SA) reported in the paper: `45.31, 53.95, 80.17` -> with the current code: `45.57, 55.43, 79.53`). (We apologize for not providing the model used in the paper; we accidentally deleted the original file.) We have also updated the environment configuration to help you reproduce our results. 18 | 19 | # 🔧 Environment 20 | We use [conda](https://docs.conda.io/en/latest/miniconda.html) for Python environment management. After installing conda, 21 | 22 | 1. conda create -n deacl python=3.8 23 | 24 | 2. conda activate deacl 25 | 26 | 3. pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 27 | 28 | 4. pip install -r requirements.txt 29 | 30 | 31 | # ⚡ Training and Evaluation 32 | 33 | ## 1. Prepare the pretrained self-supervised teacher model 34 | You can download a pretrained checkpoint from [solo-learn](https://github.com/vturrisi/solo-learn#cifar-10) or train one yourself. 35 | 36 | A SimCLR model pretrained by solo-learn is available [here](https://drive.google.com/drive/folders/1mcvWr8P2WNJZ7TVpdLHA_Q91q4VK3y8O?usp=sharing). 37 | 38 | Put the downloaded model into the folder `TeacherCKPT`. 39 | 40 | ## 2. Train DeACL with ResNet18 on the CIFAR10 dataset 41 | 42 | Use the script `bash_files/DeACL_cifar10_resnet18.sh`. 43 | 44 | You need to specify `--project xxx`, provide your wandb API key, and add `--wandb` to enable wandb logging. 45 | 46 | ## 3. Test robustness against PGD and AutoAttack under standard linear fine-tuning (SLF) and adversarial full fine-tuning (AFF) 47 | First install the [autoattack](https://github.com/fra31/auto-attack) package: `pip install git+https://github.com/fra31/auto-attack` 48 | 49 | ### a.
Evaluate the trained model from step 2 50 | Use the following commands, replacing `CKPT_PATH` with the path to the trained model. 51 | 52 | ```bash 53 | # SLF 54 | python adv_finetune.py --ckpt CKPT_PATH --mode SLF --learning_rate 0.1 55 | # AFF 56 | python adv_finetune.py --ckpt CKPT_PATH --mode AFF --learning_rate 0.01 57 | ``` 58 | 59 | ### b. Evaluate the pretrained model provided by us 60 | We provide our model pretrained on CIFAR10 with SimCLR as the teacher [here](https://drive.google.com/file/d/1yc38miWGY57sHS6W6aY_k5t69Gt5v5fm/view?usp=sharing); you can download it and evaluate it with the commands in (a). 61 | 62 | The results obtained with this pretrained model and the above code are as follows (average of 5 runs). For SLF, the initial learning rate is 0.1; for AFF and ALF, the initial learning rate is 0.01 with beta = 6 in the TRADES loss. In all three modes, training runs for 25 epochs, and the learning rate is decayed by a factor of 10 at epochs 15 and 20. 63 | | Mode | AA | RA | SA | 64 | | --- | --- | --- | --- | 65 | | SLF | 46.14 ± 0.054 | 53.45 ± 0.095 | 80.82 ± 0.090 | 66 | | AFF | 50.75 ± 0.150 | 54.23 ± 0.238 | 83.64 ± 0.139 | 67 | | ALF | 45.55 ± 0.134 | 55.30 ± 0.142 | 79.39 ± 0.140 | 68 | 69 | 70 | # Acknowledgement 71 | This code is developed based on [solo-learn](https://github.com/vturrisi/solo-learn) for training and [AdvCL](https://github.com/LijieFan/AdvCL.git), [AutoAttack](https://github.com/fra31/auto-attack) and [TRADES](https://github.com/yaodongyu/TRADES) for testing.
72 | 73 | 82 | 83 | # See also our other works 84 | 85 | Dual Temperature Helps Contrastive Learning Without Many Negative Samples: Towards Understanding and Simplifying MoCo (Accepted by CVPR2022) [GitHub](https://github.com/ChaoningZhang/Dual-temperature.git), [arXiv](https://arxiv.org/abs/2203.17248) 86 | 87 | 88 | 89 | # Citation 90 | ``` 91 | @inproceedings{zhang2022decoupled, 92 | title={Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness}, 93 | author={Zhang, Chaoning and Zhang, Kang and Zhang, Chenshuang and Niu, Axi and Feng, Jiu and Yoo, Chang D and Kweon, In So}, 94 | booktitle={ECCV 2022}, 95 | pages={725--742}, 96 | year={2022}, 97 | organization={Springer} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /solo/utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import json
import os
import random
import string
import time
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Optional, Union

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


def random_string(letter_count=4, digit_count=4):
    """Returns ``letter_count`` lowercase letters and ``digit_count`` digits in shuffled order.

    The characters come from a ``random.Random`` instance seeded with the current time.
    """

    tmp_random = random.Random(time.time())
    rand_str = "".join(tmp_random.choice(string.ascii_lowercase) for _ in range(letter_count))
    rand_str += "".join(tmp_random.choice(string.digits) for _ in range(digit_count))
    rand_str = list(rand_str)
    tmp_random.shuffle(rand_str)
    return "".join(rand_str)


class Checkpointer(Callback):
    def __init__(
        self,
        args: Namespace,
        logdir: Union[str, Path] = Path("trained_models"),
        frequency: int = 1,
        keep_previous_checkpoints: bool = False,
    ):
        """Custom checkpointer callback that stores checkpoints in an easier to access way.

        Args:
            args (Namespace): namespace object containing at least an attribute name.
            logdir (Union[str, Path], optional): base directory to store checkpoints.
                Defaults to "trained_models".
            frequency (int, optional): number of epochs between each checkpoint. Defaults to 1.
            keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not.
                Defaults to False.
        """

        super().__init__()

        self.args = args
        self.logdir = Path(logdir)
        self.frequency = frequency
        self.keep_previous_checkpoints = keep_previous_checkpoints

    @staticmethod
    def add_checkpointer_args(parent_parser: ArgumentParser):
        """Adds user-required arguments to a parser.

        Args:
            parent_parser (ArgumentParser): parser to add new args to.
        """

        parser = parent_parser.add_argument_group("checkpointer")
        parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path)
        parser.add_argument("--checkpoint_frequency", default=1, type=int)
        return parent_parser

    def initial_setup(self, trainer: pl.Trainer):
        """Creates the directories and does the initial setup needed.

        Without a logger, a fresh "offline-<random>" version not yet present in ``logdir`` is
        used; otherwise the logger's version names the sub-directory. If the logger reports no
        version, checkpoints go directly into ``logdir``.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        if trainer.logger is None:
            if self.logdir.exists():
                existing_versions = set(os.listdir(self.logdir))
            else:
                existing_versions = set()
            version = "offline-" + random_string()
            while version in existing_versions:
                version = "offline-" + random_string()
        else:
            logger_version = trainer.logger.version
            # str(None) would yield "None" and make the `version is None` branch unreachable
            version = None if logger_version is None else str(logger_version)
        if version is not None:
            self.path = self.logdir / version
            self.ckpt_placeholder = f"{self.args.name}-{version}" + "-ep={}.ckpt"
        else:
            self.path = self.logdir
            self.ckpt_placeholder = f"{self.args.name}" + "-ep={}.ckpt"
        self.last_ckpt: Optional[str] = None

        # create logging dirs
        if trainer.is_global_zero:
            os.makedirs(self.path, exist_ok=True)

    def save_args(self, trainer: pl.Trainer):
        """Stores arguments into a json file; values that are not JSON-serializable become "".

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        if trainer.is_global_zero:
            args = vars(self.args)
            json_path = self.path / "args.json"
            # context manager so the file is flushed and closed deterministically
            with open(json_path, "w") as f:
                json.dump(args, f, default=lambda o: "")

    def save(self, trainer: pl.Trainer):
        """Saves current checkpoint, deleting the previous one unless asked to keep it.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        if trainer.is_global_zero and not trainer.sanity_checking:
            epoch = trainer.current_epoch  # type: ignore
            ckpt = self.path / self.ckpt_placeholder.format(epoch)
            trainer.save_checkpoint(ckpt)

            if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints:
                os.remove(self.last_ckpt)
            self.last_ckpt = ckpt

    def on_train_start(self, trainer: pl.Trainer, _):
        """Executes initial setup and saves arguments.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        self.initial_setup(trainer)
        self.save_args(trainer)

    def on_train_epoch_end(self, trainer: pl.Trainer, _):
        """Tries to save current checkpoint at the end of each train epoch.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        epoch = trainer.current_epoch  # type: ignore
        if epoch % self.frequency == 0:
            self.save(trainer)


# -------------------------------------------------------------------------------- /solo/args/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team.
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import argparse

import pytorch_lightning as pl
from solo.args.dataset import (
    augmentations_args,
    custom_dataset_args,
    dataset_args,
    linear_augmentations_args,
)
from solo.args.utils import additional_setup_linear, additional_setup_pretrain
from solo.methods import METHODS
from solo.utils.auto_resumer import AutoResumer
from solo.utils.checkpointer import Checkpointer

try:
    from solo.utils.auto_umap import AutoUMAP
except ImportError:
    _umap_available = False
else:
    _umap_available = True


def parse_args_pretrain() -> argparse.Namespace:
    """Parses dataset, augmentation, pytorch lightning, model specific and additional args.

    First adds shared args such as dataset, augmentation and pytorch lightning args, then pulls the
    model name from the command and proceeds to add model specific args from the desired class.
    Checkpointer, AutoUMAP and AutoResumer args are added only when ``--save_checkpoint``,
    ``--auto_umap`` (and UMAP is importable) or ``--auto_resume`` are given, respectively.
    Finally, adds additional non-user given parameters.

    Returns:
        argparse.Namespace: a namespace containing all args needed for pretraining.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("--aux_data", action='store_true')

    # add shared arguments
    dataset_args(parser)
    augmentations_args(parser)
    custom_dataset_args(parser)

    # add pytorch lightning trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # add method-specific arguments
    parser.add_argument("--method", type=str)

    # THIS LINE IS KEY TO PULL THE MODEL NAME
    temp_args, _ = parser.parse_known_args()

    # report a missing/unknown method as a CLI error instead of a bare KeyError
    if temp_args.method not in METHODS:
        parser.error(f"--method must be one of {sorted(METHODS)}, got {temp_args.method!r}")

    # add model specific args
    parser = METHODS[temp_args.method].add_model_specific_args(parser)

    # add auto checkpoint/umap args
    parser.add_argument("--save_checkpoint", action="store_true")
    parser.add_argument("--auto_umap", action="store_true")
    parser.add_argument("--auto_resume", action="store_true")
    temp_args, _ = parser.parse_known_args()

    # optionally add checkpointer and AutoUMAP args
    if temp_args.save_checkpoint:
        parser = Checkpointer.add_checkpointer_args(parser)

    if _umap_available and temp_args.auto_umap:
        parser = AutoUMAP.add_auto_umap_args(parser)

    if temp_args.auto_resume:
        parser = AutoResumer.add_autoresumer_args(parser)

    # parse args
    args = parser.parse_args()

    # prepare arguments with additional setup
    additional_setup_pretrain(args)

    return args


def parse_args_linear() -> argparse.Namespace:
    """Parses feature extractor, dataset, pytorch lightning, linear eval specific and additional args.

    First adds an arg for the pretrained feature extractor, then adds dataset, pytorch lightning
    and linear eval specific args. If ``--save_checkpoint`` is given, it adds checkpointer args.
    Finally, adds additional non-user given parameters.

    Returns:
        argparse.Namespace: a namespace containing all args needed for linear evaluation.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument("--pretrained_feature_extractor", type=str)

    # add shared arguments
    dataset_args(parser)
    linear_augmentations_args(parser)
    custom_dataset_args(parser)

    # add pytorch lightning trainer args
    parser = pl.Trainer.add_argparse_args(parser)

    # linear model
    parser = METHODS["linear"].add_model_specific_args(parser)

    # THIS LINE IS KEY TO PULL SAVE_CHECKPOINT
    parser.add_argument("--save_checkpoint", action="store_true")
    temp_args, _ = parser.parse_known_args()

    # optionally add checkpointer
    if temp_args.save_checkpoint:
        parser = Checkpointer.add_checkpointer_args(parser)

    # parse args
    args = parser.parse_args()
    additional_setup_linear(args)

    return args


def parse_args_knn() -> argparse.Namespace:
    """Parses arguments for offline K-NN.

    Returns:
        argparse.Namespace: a namespace containing all args needed for offline K-NN evaluation.
    """

    parser = argparse.ArgumentParser()

    # add knn args
    parser.add_argument("--pretrained_checkpoint_dir", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=10)
    parser.add_argument("--k", type=int, nargs="+")
    parser.add_argument("--temperature", type=float, nargs="+")
    parser.add_argument("--distance_function", type=str, nargs="+")
    parser.add_argument("--feature_type", type=str, nargs="+")

    # add shared arguments
    dataset_args(parser)
    custom_dataset_args(parser)

    # parse args
    args = parser.parse_args()

    return args


def parse_args_umap() -> argparse.Namespace:
    """Parses arguments for offline UMAP.

    Returns:
        argparse.Namespace: a namespace containing all args needed for offline UMAP.
    """

    parser = argparse.ArgumentParser()

    # add umap args
    parser.add_argument("--pretrained_checkpoint_dir", type=str)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=10)

    # add shared arguments
    dataset_args(parser)
    custom_dataset_args(parser)

    # parse args
    args = parser.parse_args()

    return args


# -------------------------------------------------------------------------------- /solo/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | import math 21 | import warnings 22 | from typing import List, Tuple 23 | 24 | import torch 25 | import torch.distributed as dist 26 | import torch.nn as nn 27 | 28 | 29 | def _1d_filter(tensor: torch.Tensor) -> torch.Tensor: 30 | return tensor.isfinite() 31 | 32 | 33 | def _2d_filter(tensor: torch.Tensor) -> torch.Tensor: 34 | return tensor.isfinite().all(dim=1) 35 | 36 | 37 | def _single_input_filter(tensor: torch.Tensor) -> Tuple[torch.Tensor]: 38 | if len(tensor.size()) == 1: 39 | filter_func = _1d_filter 40 | elif len(tensor.size()) == 2: 41 | filter_func = _2d_filter 42 | else: 43 | raise RuntimeError("Only 1d and 2d tensors are supported.") 44 | 45 | selected = filter_func(tensor) 46 | tensor = tensor[selected] 47 | 48 | return tensor, selected 49 | 50 | 51 | def _multi_input_filter(tensors: List[torch.Tensor]) -> Tuple[torch.Tensor]: 52 | if len(tensors[0].size()) == 1: 53 | filter_func = _1d_filter 54 | elif len(tensors[0].size()) == 2: 55 | filter_func = _2d_filter 56 | else: 57 | raise RuntimeError("Only 1d and 2d tensors are supported.") 58 | 59 | selected = filter_func(tensors[0]) 60 | for tensor in tensors[1:]: 61 | selected = torch.logical_and(selected, filter_func(tensor)) 62 | tensors = [tensor[selected] for tensor in tensors] 63 | 64 | return tensors, selected 65 | 66 | 67 | def filter_inf_n_nan(tensors: List[torch.Tensor], return_indexes: bool = False): 68 | """Filters out inf and nans from any tensor. 69 | This is usefull when there are instability issues, 70 | which cause a small number of values to go bad. 71 | 72 | Args: 73 | tensor (List): tensor to remove nans and infs from. 74 | 75 | Returns: 76 | torch.Tensor: filtered view of the tensor without nans or infs. 
77 | """ 78 | 79 | if isinstance(tensors, torch.Tensor): 80 | tensors, selected = _single_input_filter(tensors) 81 | else: 82 | tensors, selected = _multi_input_filter(tensors) 83 | 84 | if return_indexes: 85 | return tensors, selected 86 | return tensors 87 | 88 | 89 | class FilterInfNNan(nn.Module): 90 | def __init__(self, module): 91 | """Layer that filters out inf and nans from any tensor. 92 | This is usefull when there are instability issues, 93 | which cause a small number of values to go bad. 94 | 95 | Args: 96 | tensor (List): tensor to remove nans and infs from. 97 | 98 | Returns: 99 | torch.Tensor: filtered view of the tensor without nans or infs. 100 | """ 101 | super().__init__() 102 | 103 | self.module = module 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | out = self.module(x) 107 | out = filter_inf_n_nan(out) 108 | return out 109 | 110 | def __getattr__(self, name): 111 | try: 112 | return super().__getattr__(name) 113 | except AttributeError: 114 | if name == "module": 115 | raise AttributeError() 116 | return getattr(self.module, name) 117 | 118 | 119 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 120 | """Copy & paste from PyTorch official master until it's in a few official releases - RW 121 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 122 | """ 123 | 124 | def norm_cdf(x): 125 | """Computes standard normal cumulative distribution function""" 126 | 127 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 128 | 129 | if (mean < a - 2 * std) or (mean > b + 2 * std): 130 | warnings.warn( 131 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 132 | "The distribution of values may be incorrect.", 133 | stacklevel=2, 134 | ) 135 | 136 | with torch.no_grad(): 137 | # Values are generated by using a truncated uniform distribution and 138 | # then using the inverse CDF for the normal distribution. 
139 | # Get upper and lower cdf values 140 | l = norm_cdf((a - mean) / std) 141 | u = norm_cdf((b - mean) / std) 142 | 143 | # Uniformly fill tensor with values from [l, u], then translate to 144 | # [2l-1, 2u-1]. 145 | tensor.uniform_(2 * l - 1, 2 * u - 1) 146 | 147 | # Use inverse cdf transform for normal distribution to get truncated 148 | # standard normal 149 | tensor.erfinv_() 150 | 151 | # Transform to proper mean, std 152 | tensor.mul_(std * math.sqrt(2.0)) 153 | tensor.add_(mean) 154 | 155 | # Clamp to ensure it's in the proper range 156 | tensor.clamp_(min=a, max=b) 157 | return tensor 158 | 159 | 160 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 161 | """Copy & paste from PyTorch official master until it's in a few official releases - RW 162 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 163 | """ 164 | 165 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 166 | 167 | 168 | class GatherLayer(torch.autograd.Function): 169 | """Gathers tensors from all processes, supporting backward propagation.""" 170 | 171 | @staticmethod 172 | def forward(ctx, inp): 173 | ctx.save_for_backward(inp) 174 | if dist.is_available() and dist.is_initialized(): 175 | output = [torch.zeros_like(inp) for _ in range(dist.get_world_size())] 176 | dist.all_gather(output, inp) 177 | else: 178 | output = [inp] 179 | return tuple(output) 180 | 181 | @staticmethod 182 | def backward(ctx, *grads): 183 | (inp,) = ctx.saved_tensors 184 | if dist.is_available() and dist.is_initialized(): 185 | grad_out = torch.zeros_like(inp) 186 | grad_out[:] = grads[dist.get_rank()] 187 | else: 188 | grad_out = grads[0] 189 | return grad_out 190 | 191 | 192 | def gather(X, dim=0): 193 | """Gathers tensors from all processes, supporting backward propagation.""" 194 | return torch.cat(GatherLayer.apply(X), dim=dim) 195 | 196 | 197 | def get_rank(): 198 | if dist.is_available() and dist.is_initialized(): 199 | return dist.get_rank() 200 | 
return 0 201 | -------------------------------------------------------------------------------- /solo/utils/kmeans.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Any, Sequence 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributed as dist 25 | import torch.nn.functional as F 26 | from scipy.sparse import csr_matrix 27 | 28 | 29 | class KMeans: 30 | def __init__( 31 | self, 32 | world_size: int, 33 | rank: int, 34 | num_large_crops: int, 35 | dataset_size: int, 36 | proj_features_dim: int, 37 | num_prototypes: int, 38 | kmeans_iters: int = 10, 39 | ): 40 | """Class that performs K-Means on the hypersphere. 41 | 42 | Args: 43 | world_size (int): world size. 44 | rank (int): rank of the current process. 
45 | num_large_crops (int): number of crops. 46 | dataset_size (int): total size of the dataset (number of samples). 47 | proj_features_dim (int): number of dimensions of the projected features. 48 | num_prototypes (int): number of prototypes. 49 | kmeans_iters (int, optional): number of iterations for the k-means clustering. 50 | Defaults to 10. 51 | """ 52 | self.world_size = world_size 53 | self.rank = rank 54 | self.num_large_crops = num_large_crops 55 | self.dataset_size = dataset_size 56 | self.proj_features_dim = proj_features_dim 57 | self.num_prototypes = num_prototypes 58 | self.kmeans_iters = kmeans_iters 59 | 60 | @staticmethod 61 | def get_indices_sparse(data: np.ndarray): 62 | cols = np.arange(data.size) 63 | M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size)) 64 | return [np.unravel_index(row.data, data.shape) for row in M] 65 | 66 | def cluster_memory( 67 | self, 68 | local_memory_index: torch.Tensor, 69 | local_memory_embeddings: torch.Tensor, 70 | ) -> Sequence[Any]: 71 | """Performs K-Means clustering on the hypersphere and returns centroids and 72 | assignments for each sample. 73 | 74 | Args: 75 | local_memory_index (torch.Tensor): memory bank cointaining indices of the 76 | samples. 77 | local_memory_embeddings (torch.Tensor): memory bank cointaining embeddings 78 | of the samples. 79 | 80 | Returns: 81 | Sequence[Any]: assignments and centroids. 
82 | """ 83 | j = 0 84 | device = local_memory_embeddings.device 85 | assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long() 86 | centroids_list = [] 87 | with torch.no_grad(): 88 | for i_K, K in enumerate(self.num_prototypes): 89 | # run distributed k-means 90 | 91 | # init centroids with elements from memory bank of rank 0 92 | centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True) 93 | if self.rank == 0: 94 | random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K] 95 | assert len(random_idx) >= K, "please reduce the number of centroids" 96 | centroids = local_memory_embeddings[j][random_idx] 97 | if dist.is_available() and dist.is_initialized(): 98 | dist.broadcast(centroids, 0) 99 | 100 | for n_iter in range(self.kmeans_iters + 1): 101 | 102 | # E step 103 | dot_products = torch.mm(local_memory_embeddings[j], centroids.t()) 104 | _, local_assignments = dot_products.max(dim=1) 105 | 106 | # finish 107 | if n_iter == self.kmeans_iters: 108 | break 109 | 110 | # M step 111 | where_helper = self.get_indices_sparse(local_assignments.cpu().numpy()) 112 | counts = torch.zeros(K).to(device, non_blocking=True).int() 113 | emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True) 114 | for k in range(len(where_helper)): 115 | if len(where_helper[k][0]) > 0: 116 | emb_sums[k] = torch.sum( 117 | local_memory_embeddings[j][where_helper[k][0]], 118 | dim=0, 119 | ) 120 | counts[k] = len(where_helper[k][0]) 121 | if dist.is_available() and dist.is_initialized(): 122 | dist.all_reduce(counts) 123 | dist.all_reduce(emb_sums) 124 | mask = counts > 0 125 | centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1) 126 | 127 | # normalize centroids 128 | centroids = F.normalize(centroids, dim=1, p=2) 129 | 130 | centroids_list.append(centroids) 131 | 132 | if dist.is_available() and dist.is_initialized(): 133 | # gather the assignments 134 | assignments_all = torch.empty( 135 | self.world_size, 136 | 
local_assignments.size(0), 137 | dtype=local_assignments.dtype, 138 | device=local_assignments.device, 139 | ) 140 | assignments_all = list(assignments_all.unbind(0)) 141 | 142 | dist_process = dist.all_gather( 143 | assignments_all, local_assignments, async_op=True 144 | ) 145 | dist_process.wait() 146 | assignments_all = torch.cat(assignments_all).cpu() 147 | 148 | # gather the indexes 149 | indexes_all = torch.empty( 150 | self.world_size, 151 | local_memory_index.size(0), 152 | dtype=local_memory_index.dtype, 153 | device=local_memory_index.device, 154 | ) 155 | indexes_all = list(indexes_all.unbind(0)) 156 | dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True) 157 | dist_process.wait() 158 | indexes_all = torch.cat(indexes_all).cpu() 159 | 160 | else: 161 | assignments_all = local_assignments 162 | indexes_all = local_memory_index 163 | 164 | # log assignments 165 | assignments[i_K][indexes_all] = assignments_all 166 | 167 | # next memory bank to use 168 | j = (j + 1) % self.num_large_crops 169 | 170 | return assignments, centroids_list 171 | -------------------------------------------------------------------------------- /main_pretrain_AdvTraining.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import os 21 | from pprint import pprint 22 | 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import WandbLogger 26 | 27 | from solo.args.setup import parse_args_pretrain 28 | from solo.methods import METHODS 29 | from solo.utils.auto_resumer import AutoResumer 30 | import wandb 31 | try: 32 | from solo.methods.dali import PretrainABC 33 | except ImportError as e: 34 | print(e) 35 | _dali_avaliable = False 36 | else: 37 | _dali_avaliable = True 38 | 39 | try: 40 | from solo.utils.auto_umap import AutoUMAP 41 | except ImportError: 42 | _umap_available = False 43 | else: 44 | _umap_available = True 45 | 46 | import shutil 47 | import types 48 | 49 | from torchvision import transforms 50 | 51 | from solo.utils.checkpointer import Checkpointer 52 | # from solo.utils.classification_dataloader import prepare_data as prepare_data_classification 53 | from solo.utils.classification_dataloader_AdvTraining import \ 54 | prepare_data as prepare_data_classification 55 | from solo.utils.pretrain_dataloader_AdvTraining import ( 56 | prepare_dataloader, prepare_datasets, prepare_n_crop_transform, 57 | prepare_transform) 58 | 59 | 60 | def main(): 61 | seed_everything(5) 62 | 63 | args = parse_args_pretrain() 64 | 65 | print(args) 66 | 67 | assert args.method in METHODS, f"Choose from {METHODS.keys()}" 68 | 69 | if args.num_large_crops != 2: 70 | assert 
args.method == "wmse" 71 | 72 | MethodClass = METHODS[args.method] 73 | if args.dali: 74 | assert ( 75 | _dali_avaliable 76 | ), "Dali is not currently avaiable, please install it first with [dali]." 77 | MethodClass = types.new_class(f"Dali{MethodClass.__name__}", (PretrainABC, MethodClass)) 78 | 79 | model = MethodClass(**args.__dict__) 80 | 81 | # pretrain dataloader 82 | if not args.dali: 83 | # asymmetric augmentations 84 | if args.unique_augs > 1: 85 | 86 | if args.dataset in ["cifar10", "cifar100", "svhn"]: 87 | crop_size = 32 88 | elif args.dataset in ["tinyimagenet"]: 89 | crop_size = 64 90 | else: 91 | raise 92 | 93 | weak_transforms = transforms.Compose([ 94 | transforms.RandomCrop(crop_size, padding=4), 95 | transforms.RandomHorizontalFlip(), 96 | transforms.ToTensor(), 97 | ]) 98 | transform = [ 99 | weak_transforms 100 | ] 101 | 102 | else: 103 | transform = [prepare_transform(args.dataset, **args.transform_kwargs)] 104 | 105 | transform = prepare_n_crop_transform(transform, num_crops_per_aug=args.num_crops_per_aug) 106 | if args.debug_augmentations: 107 | print("Transforms:") 108 | pprint(transform) 109 | 110 | train_dataset = prepare_datasets( 111 | args.dataset, 112 | transform, 113 | data_dir=args.data_dir, 114 | train_dir=args.train_dir, 115 | no_labels=args.no_labels, 116 | ) 117 | train_loader = prepare_dataloader( 118 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers 119 | ) 120 | 121 | 122 | 123 | # normal dataloader for when it is available 124 | if args.dataset == "custom" and (args.no_labels or args.val_dir is None): 125 | val_loader = None 126 | elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None: 127 | val_loader = None 128 | else: 129 | _, val_loader = prepare_data_classification( 130 | args.dataset, 131 | data_dir=args.data_dir, 132 | train_dir=args.train_dir, 133 | val_dir=args.val_dir, 134 | batch_size=args.batch_size, 135 | num_workers=args.num_workers, 136 | ) 137 | 138 | callbacks = [] 
139 | 140 | # wandb logging 141 | if args.wandb: 142 | wandb_logger = WandbLogger( 143 | name=args.name, 144 | project=args.project, 145 | offline=args.offline, 146 | settings=wandb.Settings(start_method="fork") 147 | ) 148 | wandb_logger.watch(model, log="gradients", log_freq=100) 149 | wandb_logger.log_hyperparams(args) 150 | 151 | # lr logging 152 | lr_monitor = LearningRateMonitor(logging_interval="epoch") 153 | callbacks.append(lr_monitor) 154 | 155 | if args.save_checkpoint: 156 | # save checkpoint on last epoch only 157 | ckpt = Checkpointer( 158 | args, 159 | logdir=os.path.join(args.checkpoint_dir, args.method), 160 | frequency=args.checkpoint_frequency, 161 | ) 162 | callbacks.append(ckpt) 163 | 164 | if args.auto_umap: 165 | assert ( 166 | _umap_available 167 | ), "UMAP is not currently avaiable, please install it first with [umap]." 168 | auto_umap = AutoUMAP( 169 | args, 170 | logdir=os.path.join(args.auto_umap_dir, args.method), 171 | frequency=args.auto_umap_frequency, 172 | ) 173 | callbacks.append(auto_umap) 174 | 175 | # 1.7 will deprecate resume_from_checkpoint, but for the moment 176 | # the argument is the same, but we need to pass it as ckpt_path to trainer.fit 177 | ckpt_path = None 178 | 179 | 180 | trainer = Trainer.from_argparse_args( 181 | args, 182 | logger=wandb_logger if args.wandb else None, 183 | callbacks=callbacks, 184 | enable_checkpointing=False, 185 | ) 186 | 187 | # File backup 188 | if args.wandb: 189 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_{trainer.logger.version}" 190 | args.codepath = experimentdir 191 | else: 192 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_test" 193 | 194 | if not os.path.exists("code"): 195 | os.mkdir("code") 196 | 197 | if os.path.exists(experimentdir): 198 | print(experimentdir + ' : exists. 
overwrite it.') 199 | shutil.rmtree(experimentdir) 200 | os.mkdir(experimentdir) 201 | else: 202 | os.mkdir(experimentdir) 203 | 204 | shutil.copytree(f"solo", os.path.join(experimentdir, 'solo')) 205 | shutil.copytree(f"bash_files", os.path.join(experimentdir, 'bash_files')) 206 | shutil.copyfile(f"main_pretrain_AdvTraining.py", os.path.join(experimentdir, 'main_pretrain_AdvTraining.py')) 207 | shutil.copyfile(f"adv_slf.py", os.path.join(experimentdir, 'adv_slf.py')) 208 | 209 | 210 | if args.dali: 211 | trainer.fit(model, val_dataloaders=val_loader, ckpt_path=ckpt_path) 212 | else: 213 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path) 214 | 215 | 216 | 217 | if __name__ == "__main__": 218 | main() 219 | -------------------------------------------------------------------------------- /solo/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | # Adapted from timm https://github.com/rwightman/pytorch-image-models/blob/master/timm/ 21 | 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from timm.models.registry import register_model 27 | 28 | 29 | def normalize_fn(tensor, mean, std): 30 | """Differentiable version of torchvision.functional.normalize""" 31 | # here we assume the color channel is in at dim=1 32 | mean = mean[None, :, None, None] 33 | std = std[None, :, None, None] 34 | # import ipdb; ipdb.set_trace() 35 | return tensor.sub(mean).div(std) 36 | 37 | 38 | class NormalizeByChannelMeanStd(nn.Module): 39 | def __init__(self, mean, std): 40 | super(NormalizeByChannelMeanStd, self).__init__() 41 | if not isinstance(mean, torch.Tensor): 42 | mean = torch.tensor(mean) 43 | if not isinstance(std, torch.Tensor): 44 | std = torch.tensor(std) 45 | self.register_buffer("mean", mean) 46 | self.register_buffer("std", std) 47 | 48 | def forward(self, tensor): 49 | # self.mean = self.mean.to("cuda") 50 | # self.std = self.std.to("cuda") 51 | 52 | return normalize_fn(tensor, self.mean, self.std) 53 | 54 | def extra_repr(self): 55 | return 'mean={}, std={}'.format(self.mean, self.std) 56 | 57 | class WideResnetBasicBlock(nn.Module): 58 | def __init__( 59 | self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False 60 | ): 61 | super(WideResnetBasicBlock, self).__init__() 62 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001) 63 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False) 64 | self.conv1 = nn.Conv2d( 65 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True 66 | ) 67 | self.bn2 = nn.BatchNorm2d(out_planes, 
momentum=0.001, eps=0.001) 68 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False) 69 | self.conv2 = nn.Conv2d( 70 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True 71 | ) 72 | self.drop_rate = drop_rate 73 | self.equalInOut = in_planes == out_planes 74 | self.convShortcut = ( 75 | (not self.equalInOut) 76 | and nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=True) 77 | or None 78 | ) 79 | self.activate_before_residual = activate_before_residual 80 | 81 | def forward(self, x): 82 | if not self.equalInOut and self.activate_before_residual: 83 | x = self.relu1(self.bn1(x)) 84 | else: 85 | out = self.relu1(self.bn1(x)) 86 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x))) 87 | if self.drop_rate > 0: 88 | out = F.dropout(out, p=self.drop_rate, training=self.training) 89 | out = self.conv2(out) 90 | return torch.add(x if self.equalInOut else self.convShortcut(x), out) 91 | 92 | 93 | class WideResnetNetworkBlock(nn.Module): 94 | def __init__( 95 | self, 96 | nb_layers, 97 | in_planes, 98 | out_planes, 99 | block, 100 | stride, 101 | drop_rate=0.0, 102 | activate_before_residual=False, 103 | ): 104 | super(WideResnetNetworkBlock, self).__init__() 105 | self.layer = self._make_layer( 106 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual 107 | ) 108 | 109 | def _make_layer( 110 | self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual 111 | ): 112 | layers = [] 113 | for i in range(int(nb_layers)): 114 | layers.append( 115 | block( 116 | i == 0 and in_planes or out_planes, 117 | out_planes, 118 | i == 0 and stride or 1, 119 | drop_rate, 120 | activate_before_residual, 121 | ) 122 | ) 123 | return nn.Sequential(*layers) 124 | 125 | def forward(self, x): 126 | return self.layer(x) 127 | 128 | 129 | class WideResNet(nn.Module): 130 | def __init__(self, first_stride=1, depth=28, widen_factor=2, drop_rate=0.0, **kwargs): 131 
| super(WideResNet, self).__init__() 132 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor] 133 | self.num_features = channels[-1] 134 | self.inplanes = self.num_features 135 | assert (depth - 4) % 6 == 0 136 | n = (depth - 4) / 6 137 | 138 | self.normalize = NormalizeByChannelMeanStd( 139 | mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261]) 140 | 141 | block = WideResnetBasicBlock 142 | # 1st conv before any network block 143 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, padding=1, bias=True) 144 | # import ipdb; ipdb.set_trace() 145 | # 1st block 146 | self.block1 = WideResnetNetworkBlock( 147 | n, 148 | channels[0], 149 | channels[1], 150 | block, 151 | first_stride, 152 | drop_rate, 153 | activate_before_residual=True, 154 | ) 155 | # 2nd block 156 | self.block2 = WideResnetNetworkBlock(n, channels[1], channels[2], block, 2, drop_rate) 157 | # 3rd block 158 | self.block3 = WideResnetNetworkBlock(n, channels[2], channels[3], block, 2, drop_rate) 159 | # global average pooling 160 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001, eps=0.001) 161 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False) 162 | 163 | for m in self.modules(): 164 | if isinstance(m, nn.Conv2d): 165 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu") 166 | elif isinstance(m, nn.BatchNorm2d): 167 | m.weight.data.fill_(1) 168 | m.bias.data.zero_() 169 | elif isinstance(m, nn.Linear): 170 | nn.init.xavier_normal_(m.weight.data) 171 | if m.bias is not None: 172 | m.bias.data.zero_() 173 | 174 | def forward(self, x): 175 | x = self.normalize(x) 176 | out = self.conv1(x) 177 | out = self.block1(out) 178 | out = self.block2(out) 179 | out = self.block3(out) 180 | out = self.relu(self.bn1(out)) 181 | out = F.adaptive_avg_pool2d(out, 1) 182 | x = out.view(-1, self.num_features) 183 | return x 184 | 185 | 186 | @register_model 187 | def wide_resnet28w2(**kwargs): 188 | encoder = WideResNet(depth=28, 
widen_factor=2, **kwargs) 189 | return encoder 190 | 191 | 192 | @register_model 193 | def wide_resnet28w8(**kwargs): 194 | encoder = WideResNet(depth=28, widen_factor=8, **kwargs) 195 | return encoder 196 | 197 | 198 | @register_model 199 | def wide_resnet28w10(**kwargs): 200 | encoder = WideResNet(depth=28, widen_factor=10, **kwargs) 201 | return encoder 202 | -------------------------------------------------------------------------------- /solo/utils/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
from typing import Tuple

import torch
import torch.nn.functional as F
from torchmetrics.metric import Metric


class WeightedKNNClassifier(Metric):
    """Weighted k-NN accuracy metric.

    Train and test features/targets are accumulated across ``update`` calls; ``compute``
    classifies every test feature from its k nearest train features and returns top-1/top-5
    accuracy (in percent).
    """

    def __init__(
        self,
        k: int = 20,
        T: float = 0.07,
        max_distance_matrix_size: int = int(5e6),
        distance_fx: str = "cosine",
        epsilon: float = 0.00001,
        dist_sync_on_step: bool = False,
    ):
        """Implements the weighted k-NN classifier used for evaluation.

        Args:
            k (int, optional): number of neighbors. Defaults to 20.
            T (float, optional): temperature for the exponential. Only used with cosine
                distance. Defaults to 0.07.
            max_distance_matrix_size (int, optional): maximum number of elements in the
                distance matrix. Defaults to 5e6.
            distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
                "euclidean". Defaults to "cosine".
            epsilon (float, optional): Small value for numerical stability. Only used with
                euclidean distance. Defaults to 0.00001.
            dist_sync_on_step (bool, optional): whether to sync distributed values at every
                step. Defaults to False.
        """

        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)

        self.k = k
        self.T = T
        self.max_distance_matrix_size = max_distance_matrix_size
        self.distance_fx = distance_fx
        self.epsilon = epsilon

        # memory banks; lists of per-batch tensors concatenated in ``compute``
        self.add_state("train_features", default=[], persistent=False)
        self.add_state("train_targets", default=[], persistent=False)
        self.add_state("test_features", default=[], persistent=False)
        self.add_state("test_targets", default=[], persistent=False)

    def update(
        self,
        train_features: torch.Tensor = None,
        train_targets: torch.Tensor = None,
        test_features: torch.Tensor = None,
        test_targets: torch.Tensor = None,
    ):
        """Updates the memory banks. If train (test) features are passed as input, the
        corresponding train (test) targets must be passed as well.

        Args:
            train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
            train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
            test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
            test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
        """
        assert (train_features is None) == (train_targets is None)
        assert (test_features is None) == (test_targets is None)

        if train_features is not None:
            assert train_features.size(0) == train_targets.size(0)
            self.train_features.append(train_features.detach())
            self.train_targets.append(train_targets.detach())

        if test_features is not None:
            assert test_features.size(0) == test_targets.size(0)
            self.test_features.append(test_features.detach())
            self.test_targets.append(test_targets.detach())

    @torch.no_grad()
    def compute(self) -> Tuple[float]:
        """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
        the weight is computed using the exponential of the temperature scaled cosine
        distance of the samples. If euclidean distance is selected, the weight corresponds
        to the inverse of the euclidean distance.

        The memory banks are reset after the computation.

        Returns:
            Tuple[float]: k-NN accuracy @1 and @5.

        Raises:
            NotImplementedError: if ``distance_fx`` is neither "cosine" nor "euclidean".
        """

        if self.distance_fx not in ("cosine", "euclidean"):
            raise NotImplementedError

        train_features = torch.cat(self.train_features)
        train_targets = torch.cat(self.train_targets)
        test_features = torch.cat(self.test_features)
        test_targets = torch.cat(self.test_targets)

        if self.distance_fx == "cosine":
            train_features = F.normalize(train_features)
            test_features = F.normalize(test_features)

        # Size the one-hot table by the largest label seen in either split. Counting the
        # unique *test* labels undercounts when the test set misses a class (or labels are
        # not contiguous), which made ``scatter_`` below index out of bounds.
        num_classes = max(int(train_targets.max()), int(test_targets.max())) + 1
        num_train_images = train_targets.size(0)
        num_test_images = test_targets.size(0)
        # process test samples in chunks so the (chunk x num_train) similarity matrix
        # stays below ``max_distance_matrix_size`` elements
        chunk_size = min(
            max(1, self.max_distance_matrix_size // num_train_images),
            num_test_images,
        )
        k = min(self.k, num_train_images)

        top1, top5, total = 0.0, 0.0, 0
        for idx in range(0, num_test_images, chunk_size):
            # get the features for test images (slicing past the end is clamped)
            features = test_features[idx : idx + chunk_size]
            targets = test_targets[idx : idx + chunk_size]
            batch_size = targets.size(0)

            # calculate the similarities and compute top-k neighbors
            if self.distance_fx == "cosine":
                similarities = torch.mm(features, train_features.t())
            else:  # euclidean: inverse distance acts as similarity
                similarities = 1 / (torch.cdist(features, train_features) + self.epsilon)

            similarities, indices = similarities.topk(k, largest=True, sorted=True)
            candidates = train_targets.view(1, -1).expand(batch_size, -1)
            retrieved_neighbors = torch.gather(candidates, 1, indices)

            # one row per (test sample, neighbour) pair, one-hot over the neighbour's label
            retrieval_one_hot = torch.zeros(
                batch_size * k, num_classes, device=train_features.device
            )
            retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)

            if self.distance_fx == "cosine":
                similarities = (similarities / self.T).exp()

            # per-class score = sum of the weights of the neighbours voting for that class
            probs = torch.sum(
                torch.mul(
                    retrieval_one_hot.view(batch_size, -1, num_classes),
                    similarities.view(batch_size, -1, 1),
                ),
                1,
            )
            _, predictions = probs.sort(1, True)

            # find the predictions that match the target
            correct = predictions.eq(targets.view(-1, 1))
            top1 = top1 + correct.narrow(1, 0, 1).sum().item()
            # top5 does not make sense if k < 5 (or fewer than 5 classes exist)
            top5 = top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
            total += batch_size

        top1 = top1 * 100.0 / total
        top5 = top5 * 100.0 / total

        self.reset()

        return top1, top5

# --------------------------------------------------------------------------------
# /solo/utils/whitening.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.


from typing import Optional

import torch
import torch.nn as nn
from torch.cuda.amp import custom_fwd
from torch.nn.functional import conv2d


class Whitening2d(nn.Module):
    def __init__(self, output_dim: int, eps: float = 0.0):
        """Layer that computes hard whitening for W-MSE using the Cholesky decomposition.

        Args:
            output_dim (int): number of dimension of projected features.
            eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults
                to 0.0.
        """

        super(Whitening2d, self).__init__()
        self.output_dim = output_dim
        self.eps = eps

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Performs whitening using the Cholesky decomposition.

        Args:
            x (torch.Tensor): a batch or slice of projected features.

        Returns:
            torch.Tensor: a batch or slice of whitened features.
        """

        # (N, D) -> (N, D, 1, 1) so the whitening matrix can be applied as a 1x1 conv
        x = x.unsqueeze(2).unsqueeze(3)
        m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1)
        xn = x - m

        # unbiased covariance of the centered features
        T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1)
        f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)

        eye = torch.eye(self.output_dim).type(f_cov.type())

        # shrink towards the identity for numerical stability
        f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye

        # inverse of the lower Cholesky factor L (cov = L L^T) whitens the features
        inv_sqrt = torch.triangular_solve(eye, torch.cholesky(f_cov_shrinked), upper=False)[0]
        inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1)

        decorrelated = conv2d(xn, inv_sqrt)

        return decorrelated.squeeze(2).squeeze(2)


class iterative_normalization_py(torch.autograd.Function):
    """Autograd function for IterNorm: whitening through Newton's iterations with a
    hand-written backward pass."""

    @staticmethod
    def forward(ctx, *args) -> torch.Tensor:
        X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args

        # change NxCxHxW to (G x D) x(NxHxW), i.e., g*d*m
        ctx.g = X.size(1) // nc
        x = X.transpose(0, 1).contiguous().view(ctx.g, nc, -1)
        _, d, m = x.size()
        saved = []
        if training:
            # calculate centered activation by subtracted mini-batch mean
            mean = x.mean(-1, keepdim=True)
            xc = x - mean
            saved.append(xc)
            # calculate covariance matrix
            P = [None] * (ctx.T + 1)
            P[0] = torch.eye(d).to(X).expand(ctx.g, d, d)
            Sigma = torch.baddbmm(
                beta=eps,
                input=P[0],
                alpha=1.0 / m,
                batch1=xc,
                batch2=xc.transpose(1, 2),
            )
            # reciprocal of trace of Sigma: shape [g, 1, 1]
            rTr = (Sigma * P[0]).sum((1, 2), keepdim=True).reciprocal_()
            saved.append(rTr)
            Sigma_N = Sigma * rTr
            saved.append(Sigma_N)
            # Newton iterations: P[k+1] = 1.5 * P[k] - 0.5 * P[k]^3 @ Sigma_N
            for k in range(ctx.T):
                P[k + 1] = torch.baddbmm(
                    beta=1.5,
                    input=P[k],
                    alpha=-0.5,
                    batch1=torch.matrix_power(P[k], 3),
                    batch2=Sigma_N,
                )
            saved.extend(P)
            # NOTE: mul_ is in place, so the saved P[T] is the rescaled whitening matrix,
            # which is what backward reads as ``wm``
            wm = P[ctx.T].mul_(
                rTr.sqrt()
            )  # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2}

            running_mean.copy_(momentum * mean + (1.0 - momentum) * running_mean)
            running_wmat.copy_(momentum * wm + (1.0 - momentum) * running_wmat)
        else:
            xc = x - running_mean
            wm = running_wmat
        xn = wm.matmul(xc)
        Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous()
        ctx.save_for_backward(*saved)
        return Xn

    @staticmethod
    def backward(ctx, *grad_outputs):
        (grad,) = grad_outputs
        saved = ctx.saved_tensors
        # nothing was saved in eval mode: no gradients
        if len(saved) == 0:
            return None, None, None, None, None, None, None, None

        xc = saved[0]  # centered input
        rTr = saved[1]  # trace of Sigma
        sn = saved[2].transpose(-2, -1)  # normalized Sigma
        P = saved[3:]  # middle result matrix,
        g, d, m = xc.size()

        g_ = grad.transpose(0, 1).contiguous().view_as(xc)
        g_wm = g_.matmul(xc.transpose(-2, -1))
        g_P = g_wm * rTr.sqrt()
        wm = P[ctx.T]
        g_sn = 0
        # back-propagate through the Newton iterations in reverse order
        for k in range(ctx.T, 1, -1):
            P[k - 1].transpose_(-2, -1)
            P2 = P[k - 1].matmul(P[k - 1])
            g_sn += P2.matmul(P[k - 1]).matmul(g_P)
            g_tmp = g_P.matmul(sn)
            g_P.baddbmm_(beta=1.5, alpha=-0.5, batch1=g_tmp, batch2=P2)
            g_P.baddbmm_(beta=1, alpha=-0.5, batch1=P2, batch2=g_tmp)
            g_P.baddbmm_(beta=1, alpha=-0.5, batch1=P[k - 1].matmul(g_tmp), batch2=P[k - 1])
        g_sn += g_P
        g_tr = ((-sn.matmul(g_sn) + g_wm.transpose(-2, -1).matmul(wm)) * P[0]).sum(
            (1, 2), keepdim=True
        ) * P[0]
        g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2.0 * g_tr) * (-0.5 / m * rTr)
        g_x = torch.baddbmm(wm.matmul(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc)
        grad_input = (
            g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous()
        )
        return grad_input, None, None, None, None, None, None, None


class IterNorm(torch.nn.Module):
    """Iterative (Newton) group whitening layer with optional affine transform.

    Args:
        num_features (int): number of channels of the input.
        num_groups (int, optional): requested number of whitening groups. Defaults to 64.
        num_channels (Optional[int], optional): channels per group; derived from
            ``num_groups`` when None. Defaults to None.
        T (int, optional): number of Newton iterations. Defaults to 5.
        dim (int, optional): rank of the input (2 for NxC). Defaults to 2.
        eps (float, optional): added to the covariance diagonal. Defaults to 1e-5.
        momentum (float, optional): momentum of the running statistics. Defaults to 0.1.
        affine (bool, optional): whether to learn a per-channel scale and bias. Defaults
            to True.
    """

    def __init__(
        self,
        num_features: int,
        num_groups: int = 64,
        num_channels: Optional[int] = None,
        T: int = 5,
        dim: int = 2,
        eps: float = 1.0e-5,
        momentum: float = 0.1,
        affine: bool = True,
    ):
        super(IterNorm, self).__init__()
        # assert dim == 4, 'IterNorm does not support 2D'
        self.T = T
        self.eps = eps
        self.momentum = momentum
        self.num_features = num_features
        self.affine = affine
        self.dim = dim
        if num_channels is None:
            num_channels = (num_features - 1) // num_groups + 1
        num_groups = num_features // num_channels
        # shrink the group size until it evenly divides the number of features
        while num_features % num_channels != 0:
            num_channels //= 2
            num_groups = num_features // num_channels
        assert (
            num_groups > 0 and num_features % num_groups == 0
        ), f"num features={num_features}, num groups={num_groups}"
        self.num_groups = num_groups
        self.num_channels = num_channels
        shape = [1] * dim
        shape[1] = self.num_features
        if self.affine:
            # uninitialized here; filled by reset_parameters()
            self.weight = nn.Parameter(torch.Tensor(*shape))
            self.bias = nn.Parameter(torch.Tensor(*shape))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

        self.register_buffer("running_mean", torch.zeros(num_groups, num_channels, 1))
        # running whiten matrix
        self.register_buffer(
            "running_wm",
            torch.eye(num_channels).expand(num_groups, num_channels, num_channels).clone(),
        )

        self.reset_parameters()

    def reset_parameters(self):
        """Initializes the affine scale to ones and the bias to zeros."""
        if self.affine:
            torch.nn.init.ones_(self.weight)
            torch.nn.init.zeros_(self.bias)

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Whitens ``X`` group-wise, then applies the optional affine transform."""
        X_hat = iterative_normalization_py.apply(
            X,
            self.running_mean,
            self.running_wm,
            self.num_channels,
            self.T,
            self.eps,
            self.momentum,
            self.training,
        )
        # affine
        if self.affine:
            return X_hat * self.weight + self.bias

        return X_hat

    def extra_repr(self):
        # both segments must be f-strings; the second one previously printed the literal
        # text "momentum={momentum}, affine={affine}"
        return (
            f"{self.num_features}, num_channels={self.num_channels}, T={self.T}, eps={self.eps}, "
            f"momentum={self.momentum}, affine={self.affine}"
        )

# --------------------------------------------------------------------------------
# /solo/args/utils.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import os
from argparse import Namespace
from contextlib import suppress


N_CLASSES_PER_DATASET = {
    "cifar10": 10,
    "cifar100": 100,
    "stl10": 10,
    "imagenet": 1000,
    "imagenet100": 100,
}

# Per-pipeline augmentation arguments. Each one is a list holding either a single value
# (replicated to every augmentation pipeline) or one value per pipeline. This order is also
# the key order of the generated ``transform_kwargs`` dicts.
_AUGMENTATION_PARAMS = (
    "brightness",
    "contrast",
    "saturation",
    "hue",
    "color_jitter_prob",
    "gray_scale_prob",
    "horizontal_flip_prob",
    "gaussian_prob",
    "solarization_prob",
    "crop_size",
    "min_scale",
    "max_scale",
)


def _infer_num_classes(args: Namespace) -> int:
    """Returns the number of classes for ``args.dataset``.

    Known datasets use ``N_CLASSES_PER_DATASET``; otherwise the sub-directories of
    ``args.data_dir / args.train_dir`` are counted (at least 1).
    """

    if args.dataset in N_CLASSES_PER_DATASET:
        return N_CLASSES_PER_DATASET[args.dataset]

    # hack to maintain the current pipeline
    # even if the custom dataset doesn't have any labels
    dir_path = args.data_dir / args.train_dir
    with os.scandir(dir_path) as entries:
        # ``is_dir`` must be called: the bound method itself is always truthy, which used
        # to make regular files count as classes
        num_dirs = sum(1 for entry in entries if entry.is_dir())
    return max(1, num_dirs)


def _parse_gpus(gpus):
    """Normalizes ``gpus`` given as an int or a comma-separated string into a list of ints.

    Any other value (e.g. an existing list) is returned unchanged.
    """

    if isinstance(gpus, int):
        return [gpus]
    if isinstance(gpus, str):
        return [int(gpu) for gpu in gpus.split(",") if gpu]
    return gpus


def additional_setup_pretrain(args: Namespace):
    """Provides final setup for pretraining to non-user given parameters by changing args.

    Parses arguments to extract the number of classes of a dataset, create
    transformations kwargs, correctly parse gpus, identify if a cifar dataset
    is being used and adjust the lr.

    Args:
        args (Namespace): object that needs to contain, at least:
        - dataset: dataset name.
        - brightness, contrast, saturation, hue, min_scale: required augmentations
            settings.
        - dali: flag to use dali.
        - optimizer: optimizer name being used.
        - gpus: list of gpus to use.
        - lr: learning rate.

        [optional]
        - gaussian_prob, solarization_prob: optional augmentations settings.

    Raises:
        AssertionError: if the per-pipeline augmentation lists have inconsistent lengths,
            or dali is requested for an unsupported dataset.
    """

    args.num_classes = _infer_num_classes(args)

    unique_augs = max(len(getattr(args, p)) for p in _AUGMENTATION_PARAMS)
    assert len(args.num_crops_per_aug) == unique_augs

    # assert that either all unique augmentation pipelines have a unique
    # parameter or that a single parameter is replicated to all pipelines
    for p in _AUGMENTATION_PARAMS:
        values = getattr(args, p)
        n = len(values)
        assert n == unique_augs or n == 1

        if n == 1:
            setattr(args, p, values * unique_augs)

    args.unique_augs = unique_augs

    if unique_augs > 1:
        # one kwargs dict per pipeline, taking the i-th value of every parameter
        per_param_values = [getattr(args, p) for p in _AUGMENTATION_PARAMS]
        args.transform_kwargs = [
            dict(zip(_AUGMENTATION_PARAMS, pipeline_values))
            for pipeline_values in zip(*per_param_values)
        ]

        # find number of big/small crops: crops of the first pipeline's size are "large"
        big_size = args.crop_size[0]
        num_large_crops = num_small_crops = 0
        for size, n_crops in zip(args.crop_size, args.num_crops_per_aug):
            if big_size == size:
                num_large_crops += n_crops
            else:
                num_small_crops += n_crops
        args.num_large_crops = num_large_crops
        args.num_small_crops = num_small_crops
    else:
        args.transform_kwargs = {p: getattr(args, p)[0] for p in _AUGMENTATION_PARAMS}

        # find number of big/small crops
        args.num_large_crops = args.num_crops_per_aug[0]
        args.num_small_crops = 0

    # add support for custom mean and std
    if args.dataset == "custom":
        if isinstance(args.transform_kwargs, dict):
            args.transform_kwargs["mean"] = args.mean
            args.transform_kwargs["std"] = args.std
        else:
            for kwargs in args.transform_kwargs:
                kwargs["mean"] = args.mean
                kwargs["std"] = args.std

    # create backbone-specific arguments
    args.backbone_args = {"cifar": args.dataset in ["cifar10", "cifar100"]}
    if "resnet" in args.backbone:
        args.backbone_args["zero_init_residual"] = args.zero_init_residual
    else:
        # dataset related for all transformers
        crop_size = args.crop_size[0]
        args.backbone_args["img_size"] = crop_size
        if "vit" in args.backbone:
            args.backbone_args["patch_size"] = args.patch_size

    with suppress(AttributeError):
        del args.zero_init_residual
    with suppress(AttributeError):
        del args.patch_size

    if args.dali:
        assert args.dataset in ["imagenet100", "imagenet", "custom"]

    args.extra_optimizer_args = {}
    if args.optimizer == "sgd":
        args.extra_optimizer_args["momentum"] = 0.9

    args.gpus = _parse_gpus(args.gpus)

    # adjust lr according to batch size
    args.lr = args.lr * args.batch_size * len(args.gpus) / 256


def additional_setup_linear(args: Namespace):
    """Provides final setup for linear evaluation to non-user given parameters by changing args.

    Parses arguments to extract the number of classes of a dataset, correctly parse gpus, identify
    if a cifar dataset is being used and adjust the lr.

    Args:
        args: Namespace object that needs to contain, at least:
        - dataset: dataset name.
        - optimizer: optimizer name being used.
        - gpus: list of gpus to use.
        - lr: learning rate.

    Raises:
        AssertionError: if dali is requested for an unsupported dataset.
    """

    args.num_classes = _infer_num_classes(args)

    # create backbone-specific arguments
    args.backbone_args = {"cifar": args.dataset in ["cifar10", "cifar100"]}

    if "resnet" not in args.backbone:
        # dataset related for all transformers
        crop_size = args.crop_size[0]
        args.backbone_args["img_size"] = crop_size

        if "vit" in args.backbone:
            args.backbone_args["patch_size"] = args.patch_size

    with suppress(AttributeError):
        del args.patch_size

    if args.dali:
        assert args.dataset in ["imagenet100", "imagenet", "custom"]

    args.extra_optimizer_args = {}
    if args.optimizer == "sgd":
        args.extra_optimizer_args["momentum"] = 0.9

    args.gpus = _parse_gpus(args.gpus)

# --------------------------------------------------------------------------------
# /solo/utils/classification_dataloader.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import os
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import STL10, ImageFolder

# ImageNet normalization statistics used by the imagenet and custom pipelines.
# NOTE(review): the first std value (0.228) differs from torchvision's usual 0.229; it is kept
# unchanged, presumably to match the pretraining pipeline — confirm before changing.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.228, 0.224, 0.225)


def build_custom_pipeline():
    """Builds augmentation pipelines for custom data.
    If you want to do esoteric augmentations, you can just re-write this function.
    Needs to return a dict with the same structure.
    """

    pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.Resize(256),  # resize shorter
                transforms.CenterCrop(224),  # take center crop
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
    }
    return pipeline


def prepare_transforms(dataset: str) -> Tuple[nn.Module, nn.Module]:
    """Prepares pre-defined train and test transformation pipelines for some datasets.

    Args:
        dataset (str): dataset name.

    Returns:
        Tuple[nn.Module, nn.Module]: training and validation transformation pipelines.

    Raises:
        AssertionError: if ``dataset`` has no pre-defined pipeline.
    """

    cifar_pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
            ]
        ),
    }

    stl_pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
            ]
        ),
    }

    imagenet_pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.Resize(256),  # resize shorter
                transforms.CenterCrop(224),  # take center crop
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
    }

    custom_pipeline = build_custom_pipeline()

    pipelines = {
        "cifar10": cifar_pipeline,
        "cifar100": cifar_pipeline,
        "stl10": stl_pipeline,
        "imagenet100": imagenet_pipeline,
        "imagenet": imagenet_pipeline,
        "custom": custom_pipeline,
    }

    assert dataset in pipelines

    pipeline = pipelines[dataset]
    return pipeline["T_train"], pipeline["T_val"]


def prepare_datasets(
    dataset: str,
    T_train: Callable,
    T_val: Callable,
    data_dir: Optional[Union[str, Path]] = None,
    train_dir: Optional[Union[str, Path]] = None,
    val_dir: Optional[Union[str, Path]] = None,
    download: bool = True,
) -> Tuple[Dataset, Dataset]:
    """Prepares train and val datasets.

    Args:
        dataset (str): dataset name.
        T_train (Callable): pipeline of transformations for training dataset.
        T_val (Callable): pipeline of transformations for validation dataset.
        data_dir Optional[Union[str, Path]]: path where to download/locate the dataset.
            Defaults to ``<package dir>/datasets``.
        train_dir Optional[Union[str, Path]]: subpath where the training data is located.
            Defaults to ``<dataset>/train``.
        val_dir Optional[Union[str, Path]]: subpath where the validation data is located.
            Defaults to ``<dataset>/val``.
        download (bool): whether torchvision datasets (cifar, stl10) may be downloaded
            when missing. Defaults to True.

    Returns:
        Tuple[Dataset, Dataset]: training dataset and validation dataset.

    Raises:
        AssertionError: if ``dataset`` is not supported.
    """

    if data_dir is None:
        sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
        data_dir = sandbox_dir / "datasets"
    else:
        data_dir = Path(data_dir)

    train_dir = Path(f"{dataset}/train") if train_dir is None else Path(train_dir)
    val_dir = Path(f"{dataset}/val") if val_dir is None else Path(val_dir)

    assert dataset in ["cifar10", "cifar100", "stl10", "imagenet", "imagenet100", "custom"]

    if dataset in ["cifar10", "cifar100"]:
        # torchvision exposes these as CIFAR10 / CIFAR100
        DatasetClass = getattr(torchvision.datasets, dataset.upper())
        train_dataset = DatasetClass(
            data_dir / train_dir,
            train=True,
            download=download,
            transform=T_train,
        )

        val_dataset = DatasetClass(
            data_dir / val_dir,
            train=False,
            download=download,
            transform=T_val,
        )

    elif dataset == "stl10":
        # honour ``download`` for the train split too (it used to be hard-coded to True)
        train_dataset = STL10(
            data_dir / train_dir,
            split="train",
            download=download,
            transform=T_train,
        )
        val_dataset = STL10(
            data_dir / val_dir,
            split="test",
            download=download,
            transform=T_val,
        )

    elif dataset in ["imagenet", "imagenet100", "custom"]:
        train_dir = data_dir / train_dir
        val_dir = data_dir / val_dir

        train_dataset = ImageFolder(train_dir, T_train)
        val_dataset = ImageFolder(val_dir, T_val)

    return train_dataset, val_dataset


def prepare_dataloaders(
    train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
) -> Tuple[DataLoader, DataLoader]:
    """Wraps a train and a validation dataset with a DataLoader.

    The training loader shuffles and drops the last incomplete batch; the validation loader
    keeps the dataset order and all samples.

    Args:
        train_dataset (Dataset): object containing training data.
        val_dataset (Dataset): object containing validation data.
        batch_size (int): batch size.
        num_workers (int): number of parallel workers.
    Returns:
        Tuple[DataLoader, DataLoader]: training dataloader and validation dataloader.
    """

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )
    return train_loader, val_loader


def prepare_data(
    dataset: str,
    data_dir: Optional[Union[str, Path]] = None,
    train_dir: Optional[Union[str, Path]] = None,
    val_dir: Optional[Union[str, Path]] = None,
    batch_size: int = 64,
    num_workers: int = 4,
    download: bool = True,
) -> Tuple[DataLoader, DataLoader]:
    """Prepares transformations, creates dataset objects and wraps them in dataloaders.

    Args:
        dataset (str): dataset name.
        data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
            Defaults to None.
        train_dir (Optional[Union[str, Path]], optional): subpath where the
            training data is located. Defaults to None.
        val_dir (Optional[Union[str, Path]], optional): subpath where the
            validation data is located. Defaults to None.
        batch_size (int, optional): batch size. Defaults to 64.
        num_workers (int, optional): number of parallel workers. Defaults to 4.
        download (bool, optional): whether torchvision datasets may be downloaded when
            missing. Defaults to True.

    Returns:
        Tuple[DataLoader, DataLoader]: prepared training and validation dataloader.
    """

    T_train, T_val = prepare_transforms(dataset)
    train_dataset, val_dataset = prepare_datasets(
        dataset,
        T_train,
        T_val,
        data_dir=data_dir,
        train_dir=train_dir,
        val_dir=val_dir,
        download=download,
    )
    train_loader, val_loader = prepare_dataloaders(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )
    return train_loader, val_loader

# --------------------------------------------------------------------------------
# /solo/utils/classification_dataloader_AdvTraining.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import os
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import torchvision
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import STL10, ImageFolder

# Normalization statistics shared by the ImageNet-style pipelines below.
# NOTE(review): the commonly used ImageNet std for the first channel is 0.229, not 0.228;
# the value is kept unchanged here so the input normalization stays exactly as before.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.228, 0.224, 0.225)

# Normalization statistics used by the STL-10 pipeline.
_STL_MEAN = (0.4914, 0.4823, 0.4466)
_STL_STD = (0.247, 0.243, 0.261)

# Dataset names accepted by prepare_datasets.
_SUPPORTED_DATASETS = ("cifar10", "cifar100", "stl10", "imagenet", "imagenet100", "custom")


def build_custom_pipeline():
    """Builds augmentation pipelines for custom data.

    If you want to use esoteric augmentations, rewrite this function; it must return a
    dict with the same structure, i.e. the keys ``"T_train"`` and ``"T_val"``.

    Returns:
        dict: training (``"T_train"``) and validation (``"T_val"``) pipelines.
    """

    return {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.Resize(256),  # resize the shorter side
                transforms.CenterCrop(224),  # take the center crop
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
    }


def prepare_transforms(dataset: str) -> Tuple[nn.Module, nn.Module]:
    """Prepares pre-defined train and test transformation pipelines for some datasets.

    The CIFAR pipelines end with ``ToTensor`` and apply no ``Normalize``, so images stay
    in the [0, 1] range.
    NOTE(review): presumably normalization is done elsewhere (e.g. inside the model) so
    that adversarial perturbations are applied in pixel space — confirm.

    Args:
        dataset (str): dataset name; one of cifar10, cifar100, stl10, imagenet,
            imagenet100 or custom.

    Returns:
        Tuple[nn.Module, nn.Module]: training and validation transformation pipelines.

    Raises:
        AssertionError: if ``dataset`` has no pre-defined pipeline.
    """

    cifar_pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        ),
        "T_val": transforms.Compose([transforms.ToTensor()]),
    }

    stl_pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(_STL_MEAN, _STL_STD),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                transforms.Normalize(_STL_MEAN, _STL_STD),
            ]
        ),
    }

    imagenet_pipeline = {
        "T_train": transforms.Compose(
            [
                transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
        "T_val": transforms.Compose(
            [
                transforms.Resize(256),  # resize the shorter side
                transforms.CenterCrop(224),  # take the center crop
                transforms.ToTensor(),
                transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ]
        ),
    }

    pipelines = {
        "cifar10": cifar_pipeline,
        "cifar100": cifar_pipeline,
        "stl10": stl_pipeline,
        "imagenet100": imagenet_pipeline,
        "imagenet": imagenet_pipeline,
        "custom": build_custom_pipeline(),
    }

    assert dataset in pipelines, f"no pre-defined transformations for dataset: {dataset}"

    pipeline = pipelines[dataset]
    return pipeline["T_train"], pipeline["T_val"]


def prepare_datasets(
    dataset: str,
    T_train: Callable,
    T_val: Callable,
    data_dir: Optional[Union[str, Path]] = None,
    train_dir: Optional[Union[str, Path]] = None,
    val_dir: Optional[Union[str, Path]] = None,
    download: bool = True,
) -> Tuple[Dataset, Dataset]:
    """Prepares train and val datasets.

    Args:
        dataset (str): dataset name.
        T_train (Callable): pipeline of transformations for training dataset.
        T_val (Callable): pipeline of transformations for validation dataset.
        data_dir Optional[Union[str, Path]]: path where to download/locate the dataset.
            Defaults to a ``datasets`` folder in the parent of this file's directory.
        train_dir Optional[Union[str, Path]]: subpath where the training data is located.
            Defaults to ``<dataset>/train``.
        val_dir Optional[Union[str, Path]]: subpath where the validation data is located.
            Defaults to ``<dataset>/val``.
        download (bool): whether torchvision may download CIFAR/STL-10 if missing.
            Defaults to True.

    Returns:
        Tuple[Dataset, Dataset]: training dataset and validation dataset.

    Raises:
        AssertionError: if ``dataset`` is not supported.
    """

    if data_dir is None:
        sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
        data_dir = sandbox_dir / "datasets"
    else:
        data_dir = Path(data_dir)

    train_dir = Path(f"{dataset}/train") if train_dir is None else Path(train_dir)
    val_dir = Path(f"{dataset}/val") if val_dir is None else Path(val_dir)

    assert dataset in _SUPPORTED_DATASETS, f"unsupported dataset: {dataset}"

    if dataset in ("cifar10", "cifar100"):
        # torchvision exposes these as CIFAR10 / CIFAR100
        DatasetClass = getattr(torchvision.datasets, dataset.upper())
        train_dataset = DatasetClass(
            data_dir / train_dir,
            train=True,
            download=download,
            transform=T_train,
        )
        val_dataset = DatasetClass(
            data_dir / val_dir,
            train=False,
            download=download,
            transform=T_val,
        )

    elif dataset == "stl10":
        # both splits honour the caller's ``download`` flag
        train_dataset = STL10(
            data_dir / train_dir,
            split="train",
            download=download,
            transform=T_train,
        )
        val_dataset = STL10(
            data_dir / val_dir,
            split="test",
            download=download,
            transform=T_val,
        )

    else:  # imagenet, imagenet100, custom: class-per-folder layout
        train_dataset = ImageFolder(data_dir / train_dir, T_train)
        val_dataset = ImageFolder(data_dir / val_dir, T_val)

    return train_dataset, val_dataset


def prepare_dataloaders(
    train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
) -> Tuple[DataLoader, DataLoader]:
    """Wraps a train and a validation dataset with a DataLoader.

    The training loader shuffles and drops the last incomplete batch; the validation
    loader keeps the dataset order and every sample.

    Args:
        train_dataset (Dataset): object containing training data.
        val_dataset (Dataset): object containing validation data.
        batch_size (int): batch size.
        num_workers (int): number of parallel workers.

    Returns:
        Tuple[DataLoader, DataLoader]: training dataloader and validation dataloader.
    """

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )
    return train_loader, val_loader


def prepare_data(
    dataset: str,
    data_dir: Optional[Union[str, Path]] = None,
    train_dir: Optional[Union[str, Path]] = None,
    val_dir: Optional[Union[str, Path]] = None,
    batch_size: int = 64,
    num_workers: int = 4,
    download: bool = True,
) -> Tuple[DataLoader, DataLoader]:
    """Prepares transformations, creates dataset objects and wraps them in dataloaders.

    Args:
        dataset (str): dataset name.
        data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
            Defaults to None.
        train_dir (Optional[Union[str, Path]], optional): subpath where the
            training data is located. Defaults to None.
        val_dir (Optional[Union[str, Path]], optional): subpath where the
            validation data is located. Defaults to None.
        batch_size (int, optional): batch size. Defaults to 64.
        num_workers (int, optional): number of parallel workers. Defaults to 4.
        download (bool, optional): whether to download the dataset if missing.
            Defaults to True.

    Returns:
        Tuple[DataLoader, DataLoader]: prepared training and validation dataloaders.
    """

    T_train, T_val = prepare_transforms(dataset)
    train_dataset, val_dataset = prepare_datasets(
        dataset,
        T_train,
        T_val,
        data_dir=data_dir,
        train_dir=train_dir,
        val_dir=val_dir,
        download=download,
    )
    return prepare_dataloaders(
        train_dataset,
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
    )


# --------------------------------------------------------------------------------
# /solo/utils/auto_umap.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import math
import os
import random
import string
import time
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Optional, Union

import pandas as pd
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import umap
import wandb
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import Callback
from tqdm import tqdm

from .misc import gather


def random_string(letter_count=4, digit_count=4):
    """Returns a shuffled string of ``letter_count`` lowercase letters and ``digit_count`` digits.

    A private ``random.Random`` instance seeded with the current time is used, so the
    global ``random`` state is left untouched.
    """

    rng = random.Random(time.time())
    chars = [rng.choice(string.ascii_lowercase) for _ in range(letter_count)]
    chars += [rng.choice(string.digits) for _ in range(digit_count)]
    rng.shuffle(chars)
    return "".join(chars)


def _umap_scatter(feats: torch.Tensor, labels: torch.Tensor, color_palette: str):
    """Projects ``feats`` to 2D with UMAP and draws a scatter plot coloured by ``labels``.

    The plot is drawn on a new 9x9 matplotlib figure that is left open as the current
    figure, so the caller can log, save and close it.

    Args:
        feats (torch.Tensor): CPU feature matrix with one row per sample (assumed 2D).
        labels (torch.Tensor): CPU tensor with one class label per row of ``feats``.
        color_palette (str): seaborn palette name used for the classes.

    Returns:
        the matplotlib axes returned by ``sns.scatterplot``.
    """

    num_classes = len(torch.unique(labels))
    embedding = umap.UMAP(n_components=2).fit_transform(feats.numpy())

    # passing to dataframe
    df = pd.DataFrame()
    df["feat_1"] = embedding[:, 0]
    df["feat_2"] = embedding[:, 1]
    df["Y"] = labels.numpy()
    plt.figure(figsize=(9, 9))
    ax = sns.scatterplot(
        x="feat_1",
        y="feat_2",
        hue="Y",
        palette=sns.color_palette(color_palette, num_classes),
        data=df,
        legend="full",
        alpha=0.3,
    )
    ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[])
    ax.tick_params(left=False, right=False, bottom=False, top=False)

    # manually improve quality of imagenet umaps: leave more room for long legends
    anchor = (0.5, 1.8) if num_classes > 100 else (0.5, 1.35)

    plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10))
    plt.tight_layout()
    return ax


class AutoUMAP(Callback):
    def __init__(
        self,
        args: Namespace,
        logdir: Union[str, Path] = Path("auto_umap"),
        frequency: int = 1,
        keep_previous: bool = False,
        color_palette: str = "hls",
    ):
        """UMAP callback that automatically runs UMAP on the validation dataset and uploads the
        figure to wandb.

        Args:
            args (Namespace): namespace object containing at least an attribute name.
            logdir (Union[str, Path], optional): base directory to store the plots.
                Defaults to Path("auto_umap").
            frequency (int, optional): number of epochs between each UMAP. Defaults to 1.
            keep_previous (bool, optional): whether to keep previous plots or not.
                Defaults to False.
            color_palette (str, optional): color scheme for the classes. Defaults to "hls".
        """

        super().__init__()

        self.args = args
        self.logdir = Path(logdir)
        self.frequency = frequency
        self.color_palette = color_palette
        self.keep_previous = keep_previous

    @staticmethod
    def add_auto_umap_args(parent_parser: ArgumentParser):
        """Adds user-required arguments to a parser.

        Args:
            parent_parser (ArgumentParser): parser to add new args to.

        Returns:
            ArgumentParser: the same parser, with an ``auto_umap`` argument group.
        """

        parser = parent_parser.add_argument_group("auto_umap")
        parser.add_argument("--auto_umap_dir", default=Path("auto_umap"), type=Path)
        parser.add_argument("--auto_umap_frequency", default=1, type=int)
        return parent_parser

    def initial_setup(self, trainer: pl.Trainer):
        """Chooses the output directory and file-name pattern, and creates the directory.

        Without a logger, a fresh ``offline-XXXXXXXX`` version name is drawn that does not
        collide with existing entries of ``logdir``. With a logger, its version is used;
        if the logger reports no version, plots go directly into ``logdir``.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        if trainer.logger is None:
            existing_versions = set(os.listdir(self.logdir)) if self.logdir.exists() else set()
            version = "offline-" + random_string()
            while version in existing_versions:
                version = "offline-" + random_string()
        else:
            # keep None as None: str(None) would create a literal "None" directory
            logger_version = trainer.logger.version
            version = None if logger_version is None else str(logger_version)

        if version is not None:
            self.path = self.logdir / version
            self.umap_placeholder = f"{self.args.name}-{version}" + "-ep={}.pdf"
        else:
            self.path = self.logdir
            self.umap_placeholder = f"{self.args.name}" + "-ep={}.pdf"
        self.last_ckpt: Optional[str] = None

        # create logging dirs
        if trainer.is_global_zero:
            os.makedirs(self.path, exist_ok=True)

    def on_train_start(self, trainer: pl.Trainer, _):
        """Performs initial setup on training start.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
        """

        self.initial_setup(trainer)

    def plot(self, trainer: pl.Trainer, module: pl.LightningModule):
        """Produces a UMAP visualization by forwarding all data of the
        first validation dataloader through the module.

        The module is put in eval mode while collecting features and set back to train
        mode afterwards. Only the global-zero process plots, logs to wandb (when a
        WandbLogger is used) and saves the figure to ``self.path``.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
            module (pl.LightningModule): current module object.
        """

        device = module.device
        data = []
        Y = []

        # set module to eval mode and collect all feature representations
        module.eval()
        with torch.no_grad():
            for x, y in trainer.val_dataloaders[0]:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                feats = module(x)["feats"]

                # NOTE(review): ``gather`` (from .misc) presumably all-gathers across
                # processes so rank zero sees every sample — confirm.
                feats = gather(feats)
                y = gather(y)

                data.append(feats.cpu())
                Y.append(y.cpu())
        module.train()

        if trainer.is_global_zero and len(data):
            ax = _umap_scatter(torch.cat(data, dim=0), torch.cat(Y, dim=0), self.color_palette)

            if isinstance(trainer.logger, pl.loggers.WandbLogger):
                wandb.log(
                    {"validation_umap": wandb.Image(ax)},
                    commit=False,
                )

            # save plot locally as well
            epoch = trainer.current_epoch  # type: ignore
            plt.savefig(self.path / self.umap_placeholder.format(epoch))
            plt.close()

    def on_validation_end(self, trainer: pl.Trainer, module: pl.LightningModule):
        """Tries to generate an up-to-date UMAP visualization of the features
        at the end of each validation epoch.

        Plots every ``self.frequency`` epochs and never during the sanity check.

        Args:
            trainer (pl.Trainer): pytorch lightning trainer object.
            module (pl.LightningModule): current module object.
        """

        epoch = trainer.current_epoch  # type: ignore
        if epoch % self.frequency == 0 and not trainer.sanity_checking:
            self.plot(trainer, module)


class OfflineUMAP:
    def __init__(self, color_palette: str = "hls"):
        """Offline UMAP helper.

        Args:
            color_palette (str, optional): color scheme for the classes. Defaults to "hls".
        """

        self.color_palette = color_palette

    def plot(
        self,
        device: str,
        model: nn.Module,
        dataloader: torch.utils.data.DataLoader,
        plot_path: str,
    ):
        """Produces a UMAP visualization by forwarding all data of the
        given dataloader through the model.
        **Note: the model should produce features for the forward() function.

        Args:
            device (str): gpu/cpu device.
            model (nn.Module): current model.
            dataloader (torch.utils.data.Dataloader): current dataloader containing data.
            plot_path (str): path to save the figure.
        """

        data = []
        Y = []

        # set model to eval mode and collect all feature representations
        model.eval()
        with torch.no_grad():
            for x, y in tqdm(dataloader, desc="Collecting features"):
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)

                feats = model(x)
                data.append(feats.cpu())
                Y.append(y.cpu())
        model.train()

        feats = torch.cat(data, dim=0)
        labels = torch.cat(Y, dim=0)

        print("Creating UMAP")
        _umap_scatter(feats, labels, self.color_palette)

        # save plot locally
        plt.savefig(plot_path)
        plt.close()


# --------------------------------------------------------------------------------
# /solo/methods/dali.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import math
from abc import ABC
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
from solo.utils.dali_dataloader import (
    CustomNormalPipeline,
    CustomTransform,
    ImagenetTransform,
    NormalPipeline,
    PretrainPipeline,
)


def _select_for_dataset(dataset: str, imagenet_choice, custom_choice):
    """Returns ``imagenet_choice`` for imagenet/imagenet100 and ``custom_choice`` for custom.

    Raises:
        ValueError: if ``dataset`` is none of imagenet, imagenet100 or custom.
    """

    if dataset in ("imagenet100", "imagenet"):
        return imagenet_choice
    if dataset == "custom":
        return custom_choice
    raise ValueError(f"{dataset} is not supported, use one of [imagenet, imagenet100, custom]")


def _classification_loader(module, dir_key: str, validation: bool, last_batch_policy):
    """Builds a DALI classification loader for ``module`` (a ClassificationABC user).

    Args:
        module: object exposing ``local_rank``, ``global_rank``, ``trainer``,
            ``batch_size`` and the ``extra_args`` dict.
        dir_key (str): ``extra_args`` key holding the data subdirectory
            ("train_dir" or "val_dir").
        validation (bool): forwarded to the pipeline's ``validation`` flag.
        last_batch_policy (LastBatchPolicy): how the last incomplete batch is handled.

    Returns:
        Wrapper: iterator yielding ``(x, target)`` pairs.
    """

    device_id = module.local_rank
    shard_id = module.global_rank
    num_shards = module.trainer.world_size

    extra_args = module.extra_args
    num_workers = extra_args["num_workers"]
    dali_device = extra_args["dali_device"]
    data_dir = Path(extra_args["data_dir"])
    sub_dir = Path(extra_args[dir_key])

    # handle custom data by creating the needed pipeline
    pipeline_class = _select_for_dataset(
        extra_args["dataset"], NormalPipeline, CustomNormalPipeline
    )

    pipeline = pipeline_class(
        data_dir / sub_dir,
        validation=validation,
        batch_size=module.batch_size,
        device=dali_device,
        device_id=device_id,
        shard_id=shard_id,
        num_shards=num_shards,
        num_threads=num_workers,
    )
    return Wrapper(
        pipeline,
        output_map=["x", "label"],
        reader_name="Reader",
        last_batch_policy=last_batch_policy,
        auto_reset=True,
    )


class BaseWrapper(DALIGenericIterator):
    """Temporary fix to handle LastBatchPolicy.DROP."""

    def __len__(self):
        """Returns the number of batches per epoch.

        With DROP the unpadded size is split across shards and incomplete batches are
        not counted; otherwise incomplete batches count as one.
        """

        drop_last = self._last_batch_policy == LastBatchPolicy.DROP
        size = self._size_no_pad // self._shards_num if drop_last else self.size
        # without a reader name, one step consumes a batch on every GPU
        samples_per_step = (
            self.batch_size if self._reader_name else self._num_gpus * self.batch_size
        )
        if drop_last:
            return size // samples_per_step
        return math.ceil(size / samples_per_step)


class PretrainWrapper(BaseWrapper):
    def __init__(
        self,
        model_batch_size: int,
        model_rank: int,
        model_device: str,
        conversion_map: Optional[List[int]] = None,
        *args,
        **kwargs,
    ):
        """Adds indices to a batch fetched from the parent.

        Args:
            model_batch_size (int): batch size.
            model_rank (int): rank of the current process.
            model_device (str): id of the current device.
            conversion_map (Optional[List[int]], optional): list of integers that map each
                index to a class label. If nothing is passed, no label mapping needs to be
                done. Defaults to None.
        """

        super().__init__(*args, **kwargs)
        self.model_batch_size = model_batch_size
        self.model_rank = model_rank
        self.model_device = model_device
        self.conversion_map = conversion_map
        if self.conversion_map is not None:
            # store the index -> label table as a (frozen) embedding on the model device
            table = torch.tensor(
                self.conversion_map, dtype=torch.float32, device=self.model_device
            ).reshape(-1, 1)
            self.conversion_map = nn.Embedding.from_pretrained(table)

    def __next__(self):
        """Returns ``[indexes, all_X, targets]`` for the next batch.

        When a conversion map is set, the "label" output carries encoded image indexes
        and the targets are looked up from it; otherwise the indexes are dummy values
        derived from the rank and batch size.
        """

        batch = super().__next__()[0]
        # PyTorch Lightning does double buffering
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1316,
        # and as DALI owns the tensors it returns the content of it is trashed so the copy needs,
        # to be made before returning.

        if self.conversion_map is not None:
            *all_X, indexes = [batch[v] for v in self.output_map]
            targets = self.conversion_map(indexes).flatten().long().detach().clone()
            indexes = indexes.flatten().long().detach().clone()
        else:
            *all_X, targets = [batch[v] for v in self.output_map]
            targets = targets.squeeze(-1).long().detach().clone()
            # creates dummy indexes
            indexes = (
                (
                    torch.arange(self.model_batch_size, device=self.model_device)
                    + (self.model_rank * self.model_batch_size)
                )
                .detach()
                .clone()
            )

        all_X = [x.detach().clone() for x in all_X]
        return [indexes, all_X, targets]


class Wrapper(BaseWrapper):
    def __next__(self):
        """Returns ``(x, target)`` copies of the next batch, with integer targets."""

        batch = super().__next__()
        x, target = batch[0]["x"], batch[0]["label"]
        target = target.squeeze(-1).long()
        # PyTorch Lightning does double buffering
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1316,
        # and as DALI owns the tensors it returns the content of it is trashed so the copy needs,
        # to be made before returning.
        x = x.detach().clone()
        target = target.detach().clone()
        return x, target


class PretrainABC(ABC):
    """Abstract pretrain class that returns a train_dataloader using dali."""

    def train_dataloader(self) -> DALIGenericIterator:
        """Returns a train dataloader using dali. Supports multi-crop and asymmetric augmentations.

        Also stores the reader's epoch size in ``self.dali_epoch_size``.

        Returns:
            DALIGenericIterator: a train dataloader in the form of a dali pipeline object wrapped
                with PretrainWrapper.

        Raises:
            ValueError: if the dataset is not imagenet, imagenet100 or custom.
        """

        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size

        # get data arguments from model
        dali_device = self.extra_args["dali_device"]

        # data augmentations
        unique_augs = self.extra_args["unique_augs"]
        transform_kwargs = self.extra_args["transform_kwargs"]
        num_crops_per_aug = self.extra_args["num_crops_per_aug"]

        num_workers = self.extra_args["num_workers"]
        data_dir = Path(self.extra_args["data_dir"])
        train_dir = Path(self.extra_args["train_dir"])

        # hack to encode image indexes into the labels
        self.encode_indexes_into_labels = self.extra_args["encode_indexes_into_labels"]

        # handle custom data by creating the needed pipeline
        transform_pipeline = _select_for_dataset(
            self.extra_args["dataset"], ImagenetTransform, CustomTransform
        )

        # one transform per kwargs dict when there are several augmentation families
        if unique_augs > 1:
            transforms = [
                transform_pipeline(device=dali_device, **kwargs) for kwargs in transform_kwargs
            ]
        else:
            transforms = [transform_pipeline(device=dali_device, **transform_kwargs)]

        train_pipeline = PretrainPipeline(
            data_dir / train_dir,
            batch_size=self.batch_size,
            transforms=transforms,
            num_crops_per_aug=num_crops_per_aug,
            device=dali_device,
            device_id=device_id,
            shard_id=shard_id,
            num_shards=num_shards,
            num_threads=num_workers,
            no_labels=self.extra_args["no_labels"],
            encode_indexes_into_labels=self.encode_indexes_into_labels,
        )
        output_map = (
            [f"large{i}" for i in range(self.num_large_crops)]
            + [f"small{i}" for i in range(self.num_small_crops)]
            + ["label"]
        )

        conversion_map = train_pipeline.conversion_map if self.encode_indexes_into_labels else None
        # NOTE(review): the local rank is passed as model_rank, so the dummy indexes built
        # by PretrainWrapper may repeat across nodes in multi-node runs; confirm whether
        # the global rank (shard_id) was intended.
        train_loader = PretrainWrapper(
            model_batch_size=self.batch_size,
            model_rank=device_id,
            model_device=self.device,
            conversion_map=conversion_map,
            pipelines=train_pipeline,
            output_map=output_map,
            reader_name="Reader",
            last_batch_policy=LastBatchPolicy.DROP,
            auto_reset=True,
        )

        self.dali_epoch_size = train_pipeline.epoch_size("Reader")

        return train_loader


class ClassificationABC(ABC):
    """Abstract classification class that returns a train_dataloader and val_dataloader using
    dali."""

    def train_dataloader(self) -> DALIGenericIterator:
        """Returns the DALI training loader (incomplete last batch dropped).

        Raises:
            ValueError: if the dataset is not imagenet, imagenet100 or custom.
        """

        return _classification_loader(
            self, "train_dir", validation=False, last_batch_policy=LastBatchPolicy.DROP
        )

    def val_dataloader(self) -> DALIGenericIterator:
        """Returns the DALI validation loader (incomplete last batch kept).

        Raises:
            ValueError: if the dataset is not imagenet, imagenet100 or custom.
        """

        return _classification_loader(
            self, "val_dir", validation=True, last_batch_policy=LastBatchPolicy.PARTIAL
        )


# --------------------------------------------------------------------------------
# /solo/methods/mocov2_distillation_AT.py:
# --------------------------------------------------------------------------------
# Copyright 2021 solo-learn development team.

# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to use,
# copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
# Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies
# or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
import argparse
from typing import Any, Dict, List, Sequence, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from solo.losses.moco import moco_loss_func
from solo.methods.base_for_adversarial_training import BaseDistillationATMethod
from solo.utils.momentum import initialize_momentum_params
from solo.utils.misc import gather

from torchvision import models
from solo.utils.metrics import accuracy_at_k, weighted_mean


class MoCoV2KDAT(BaseDistillationATMethod):
    """Adversarial feature distillation from a teacher into a student backbone.

    The student is ``self.backbone`` (+ ``self.projector``). The teacher is
    ``self.momentum_backbone`` (+ ``self.momentum_projector``), which is used
    here only as a fixed feature extractor. Both projectors are ``nn.Identity``,
    so backbone features are compared directly.

    Training objective (see ``training_step``):
        -cos(student(x), teacher(x)) - trades_k * cos(student(x_adv), student(x))
    where ``x_adv`` is produced by ``generate_training_AE``.
    """

    # Declared for compatibility with the MoCo-style base API; not used in this class.
    queue: torch.Tensor

    def __init__(
        self,
        proj_output_dim: int,
        proj_hidden_dim: int,
        temperature: float,
        queue_size: int,
        teacher_logit_fix: bool,
        loss_type: str,
        projector_ablation: str,
        epsilon: int = 8,
        num_steps: int = 5,
        step_size: int = 2,
        trades_k: float = 3,
        aux_data: bool = False,
        augmentation_ablation: bool = False,
        expriment_code: str = "000",
        **kwargs
    ):
        """Configure the distillation / adversarial-training hyperparameters.

        Args:
            proj_output_dim (int): accepted for CLI compatibility; unused because
                the projector is ``nn.Identity``.
            proj_hidden_dim (int): accepted for CLI compatibility; unused.
            temperature (float): stored on ``self``; not read elsewhere in this class.
            queue_size (int): stored on ``self``; not read elsewhere in this class.
            teacher_logit_fix (bool): accepted for CLI compatibility; unused here.
            loss_type (str): name of the loss variant; stored on ``self``
                (``training_step`` below always uses the cosine objective).
            projector_ablation (str): projector variant; stored on ``self``.
            epsilon (int): L-inf attack radius in 0-255 pixel units.
            num_steps (int): number of attack iterations.
            step_size (int): attack step size in 0-255 pixel units.
            trades_k (float): weight of the adversarial-consistency term.
            aux_data (bool): selects the batch layout expected by ``training_step``.
            augmentation_ablation (bool): stored flag for ablation studies.
            expriment_code (str): stored tag for ablation studies.
            **kwargs: forwarded to ``BaseDistillationATMethod``.
        """
        super().__init__(**kwargs)

        self.temperature = temperature
        self.queue_size = queue_size
        self.loss_type = loss_type
        self.projector_ablation = projector_ablation
        self.trades_k = trades_k
        self.aux_data = aux_data

        self.augmentation_ablation = augmentation_ablation
        self.expriment_code = expriment_code

        # Budgets are given in 0-255 units; the attack clamps images to [0, 1],
        # so convert to that scale here.
        self.epsilon = epsilon / 255.0
        self.num_steps = num_steps
        self.step_size = step_size / 255.0

        # No projection heads: student and teacher features are compared directly.
        self.projector = nn.Identity()
        self.momentum_projector = nn.Identity()

    @staticmethod
    def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Register this method's CLI arguments on top of the parent's.

        Args:
            parent_parser (argparse.ArgumentParser): parser to extend.

        Returns:
            argparse.ArgumentParser: the extended parser.
        """
        parent_parser = super(MoCoV2KDAT, MoCoV2KDAT).add_model_specific_args(parent_parser)
        parser = parent_parser.add_argument_group("mocov2_kd_at")

        # projector
        parser.add_argument("--proj_output_dim", type=int, default=128)
        parser.add_argument("--proj_hidden_dim", type=int, default=2048)

        # parameters
        parser.add_argument("--temperature", type=float, default=0.1)

        # queue settings
        parser.add_argument("--queue_size", default=65536, type=int)

        parser.add_argument("--teacher_logit_fix", action="store_true")

        # choose loss to train the model
        SUPPORT_LOSS = [
            "ce2gt_ae2ce",
            "ae2gt_ae2ce",
            "ce2gt_ae2ce_infonce",
            "pgd",
            "ae2gt_infonce",
            "ae2gt_kl",
            "ce2gt_ae2ce_kl",
            "cs.ce2gt._kl.ae2ce.",
            "kl.ce2gt._cs.ae2ce.",
            "cs.ce2gt._infonce.ae2ce.",
            "infonce.ce2gt._cs.ae2ce.",
            "infonce.ce2gt._kl.ae2ce.",
            "kl.ce2gt._infonce.ae2ce.",
        ]
        parser.add_argument("--loss_type", default="ce2gt_ae2ce", type=str, choices=SUPPORT_LOSS)

        # projector exploration
        PROJECTOR_ABLATION = ["remove", "same", "correspond", 'only_student', 'only_teacher']
        parser.add_argument("--projector_ablation", default="remove", type=str, choices=PROJECTOR_ABLATION)

        # training adversarial hyper parameter
        parser.add_argument("--epsilon", type=int, default=8)
        parser.add_argument("--step_size", type=int, default=2)
        parser.add_argument("--num_steps", type=int, default=5)

        # for loss factor
        parser.add_argument("--trades_k", type=float, default=2)

        # for augmentation ablation study
        parser.add_argument("--augmentation_ablation", action="store_true")
        parser.add_argument("--expriment_code", default="00", type=str)

        return parent_parser

    @property
    def learnable_params(self) -> List[dict]:
        """Adds projector parameters together with parent's learnable parameters.

        Returns:
            List[dict]: list of learnable parameters.
        """
        extra_learnable_params = [{"params": self.projector.parameters()}]
        return super().learnable_params + extra_learnable_params

    @property
    def momentum_pairs(self) -> List[Tuple[Any, Any]]:
        """Adds (projector, momentum_projector) to the parent's momentum pairs.

        Returns:
            List[Tuple[Any, Any]]: list of momentum pairs.
        """
        extra_momentum_pairs = [(self.projector, self.momentum_projector)]
        return super().momentum_pairs + extra_momentum_pairs

    def forward(self, X: torch.Tensor):
        """Run the student backbone on a batch of images.

        Args:
            X (torch.Tensor): a batch of images in the tensor format.

        Returns:
            Whatever ``self.backbone(X)`` returns (the student features); the
            projector is not applied here.
        """
        return self.backbone(X)

    def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
        """One adversarial-distillation step plus online classifier training.

        Args:
            batch (Sequence[Any]): with ``aux_data`` the layout is
                ``[[X_tau1, X_weak], Y]``; otherwise ``[idx, [X_tau1, X_weak], Y]``.
                Only the weak view and the targets are used.
            batch_idx (int): index of the batch.

        Returns:
            torch.Tensor: cosine distillation loss + ``trades_k``-weighted
            adversarial-consistency loss + online classifier loss.
        """
        # Teacher is only used as a fixed feature extractor.
        self.momentum_backbone.eval()

        if self.aux_data:
            _, image_weak = batch[0]
            targets = batch[1]
        else:
            _, image_weak = batch[1]
            targets = batch[2]

        ############################################################################
        # Adversarial Training (CAT)
        ############################################################################
        # The teacher output is a constant target: no loss below should build a
        # graph through (or send gradients into) the teacher.
        with torch.no_grad():
            away_target = self.momentum_projector(self.momentum_backbone(image_weak))

        image_AE = self.generate_training_AE(image_weak, away_target)

        # Single student forward pass over clean and adversarial images.
        bs = image_weak.size(0)
        logits_all = self.projector(self.backbone(torch.cat([image_weak, image_AE])))
        student_logits_clean = logits_all[:bs]
        student_logits_AE = logits_all[bs:]

        # Cosine-similarity losses: match the teacher on clean inputs, and keep
        # adversarial features close to the clean student features.
        adv_loss = -F.cosine_similarity(student_logits_clean, away_target).mean()
        adv_loss += -self.trades_k * F.cosine_similarity(student_logits_AE, student_logits_clean).mean()

        ############################################################################
        # Online clean classifier training
        ############################################################################
        # Train the classifier with the backbone in eval mode, then restore the
        # previous mode even if the step raises.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            outs_image_weak = self._base_shared_step(image_weak, targets)
        finally:
            self.backbone.train(was_training)

        class_loss_clean = outs_image_weak["loss"]
        metrics = {
            "train_class_loss": class_loss_clean,
            "train_acc1": outs_image_weak["acc1"],
            "train_acc5": outs_image_weak["acc5"],
        }
        self.log_dict(metrics, on_epoch=True)

        # Per-dimension std of L2-normalized features, logged for monitoring
        # (presumably a feature-collapse indicator); no gradients needed.
        with torch.no_grad():
            ae_std = F.normalize(student_logits_AE, dim=-1).std(dim=0).mean()
            clean_std = F.normalize(student_logits_clean, dim=-1).std(dim=0).mean()
            teacher_std = F.normalize(away_target, dim=-1).std(dim=0).mean()

        metrics = {
            "adv_loss": adv_loss,
            "ae_std": ae_std,
            "clean_std": clean_std,
            "teacher_std": teacher_std,
        }
        self.log_dict(metrics, on_epoch=True, sync_dist=True)

        return adv_loss + class_loss_clean

    def generate_training_AE(self, image: torch.Tensor, away_target: torch.Tensor):
        """Craft L-inf PGD adversarial examples against the teacher target.

        Starting from a uniform random point in the ``epsilon`` ball around
        ``image``, take ``num_steps`` signed-gradient steps of size ``step_size``.
        Each step decreases the cosine similarity between the student features
        of the perturbed image and ``away_target``. After every step the result
        is projected back onto the ``epsilon`` ball and clamped to [0, 1].
        The student backbone runs in eval mode during the attack; its previous
        mode is restored afterwards.

        Args:
            image (torch.Tensor): clean (weakly augmented) images.
            away_target (torch.Tensor): teacher features of ``image``; treated
                as a constant.

        Returns:
            torch.Tensor: detached adversarial images, same shape as ``image``.
            NOTE(review): with ``num_steps == 0`` the random start is returned
            without the [0, 1] clamp.
        """
        x_cl = image.clone().detach()
        x_cl = x_cl + torch.zeros_like(image).uniform_(-self.epsilon, self.epsilon)

        was_training = self.backbone.training
        self.backbone.eval()
        try:
            for _ in range(self.num_steps):
                x_cl.requires_grad_()
                # Grad must be on even if the caller runs under no_grad (e.g. validation).
                with torch.enable_grad():
                    f_proj = self.projector(self.backbone(x_cl))
                    # The *256 scaling is for 16-bit training (avoids gradient
                    # underflow); it does not change torch.sign below.
                    loss = -F.cosine_similarity(f_proj, away_target, dim=1).sum() * 256

                # Gradient w.r.t. the input only; parameter .grad is untouched.
                grad_x_cl = torch.autograd.grad(loss, x_cl)[0]
                x_cl = x_cl.detach() + self.step_size * torch.sign(grad_x_cl.detach())

                # Project onto the epsilon ball around the clean image, then the pixel range.
                x_cl = torch.min(torch.max(x_cl, image - self.epsilon), image + self.epsilon)
                x_cl = torch.clamp(x_cl, 0, 1)
        finally:
            self.backbone.train(was_training)

        return x_cl