├── solo
│   ├── models
│   │   ├── __init__.py
│   │   ├── model_with_linear.py
│   │   └── wide_resnet.py
│   ├── args
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── setup.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── methods
│   │   ├── __init__.py
│   │   ├── dali.py
│   │   └── mocov2_distillation_AT.py
│   ├── losses
│   │   ├── wmse.py
│   │   ├── byol.py
│   │   ├── simsiam.py
│   │   ├── nnclr.py
│   │   ├── __init__.py
│   │   ├── moco.py
│   │   ├── swav.py
│   │   ├── deepclusterv2.py
│   │   ├── ressl.py
│   │   ├── barlow.py
│   │   ├── simclr.py
│   │   ├── vibcreg.py
│   │   ├── vicreg.py
│   │   └── dino.py
│   └── utils
│       ├── __init__.py
│       ├── metrics.py
│       ├── auto_resumer.py
│       ├── momentum.py
│       ├── sinkhorn_knopp.py
│       ├── lars.py
│       ├── checkpointer.py
│       ├── misc.py
│       ├── kmeans.py
│       ├── knn.py
│       ├── whitening.py
│       ├── classification_dataloader.py
│       ├── classification_dataloader_AdvTraining.py
│       └── auto_umap.py
├── TeacherCKPT
│   └── README.md
├── figure
│   └── DeACL.png
├── .gitignore
├── requirements.txt
├── bash_files
│   ├── eval_cifar10_resnet18.sh
│   ├── DeACL_cifar100_resnet18.sh
│   ├── DeACL_cifar10_resnet18.sh
│   └── DeACL_cifar10_resnet50.sh
├── trades
│   └── trades.py
├── README.md
└── main_pretrain_AdvTraining.py
/solo/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/TeacherCKPT/README.md:
--------------------------------------------------------------------------------
1 | Put the teacher model ckpt here.
--------------------------------------------------------------------------------
/figure/DeACL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pantheon5100/DeACL/HEAD/figure/DeACL.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
3 |
4 | wandb/
5 | code/
6 | trained_models/
 7 | checkpoint/
8 | logger/
9 | data/
10 |
11 | checkpoint/
12 |
13 | /TeacherCKPT/*.ckpt
14 |
15 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | pytorch-lightning==1.5.3
3 | torchmetrics==0.6.0
4 | lightning-bolts>=0.4.0
5 | tqdm
6 | wandb
7 | scipy
8 | timm==0.5.4
9 | setuptools==59.5.0
10 |
--------------------------------------------------------------------------------
/bash_files/eval_cifar10_resnet18.sh:
--------------------------------------------------------------------------------
1 | # For SLF
2 | # python adv_finetune.py \
3 | # --ckpt DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt \
4 | # --mode slf
5 |
6 | # For AFF
7 | # python adv_finetune.py \
8 | # --ckpt DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt \
9 | # --mode aff
10 |
11 | # For ALF
12 | # python adv_finetune.py \
13 | # --ckpt DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt \
14 | # --mode alf
15 |
16 |
17 | CKPT="trained_models/simclr/DEACL_WEIGHT_res18_simclr-cifar10-offline-x4h7cp45-ep=99.ckpt"
18 | # run the SLF 5 times
19 | for i in {1..5}
20 | do
21 | CUDA_VISIBLE_DEVICES=1 python adv_finetune.py \
22 | --ckpt $CKPT \
23 | --mode slf \
24 | --learning_rate 0.1
25 | done
26 |
--------------------------------------------------------------------------------
/bash_files/DeACL_cifar100_resnet18.sh:
--------------------------------------------------------------------------------
1 | python3 main_pretrain_AdvTraining.py \
2 | --dataset cifar100 \
3 | --backbone resnet18 \
4 | --data_dir ./data \
5 | --max_epochs 100 \
6 | --gpus 1 \
7 | --accelerator gpu \
8 | --precision 16 \
9 | --optimizer sgd \
10 | --scheduler warmup_cosine \
11 | --lr 0.5 \
12 | --classifier_lr 0.5 \
13 | --weight_decay 5e-4 \
14 | --batch_size 256 \
15 | --num_workers 4 \
16 | --brightness 0.4 \
17 | --contrast 0.4 \
18 | --saturation 0.4 \
19 | --hue 0.1 \
20 | --gaussian_prob 0.0 0.0 \
21 | --crop_size 32 \
22 | --num_crops_per_aug 1 1 \
23 |     --name "res18_simclr-cifar100" \
24 | --save_checkpoint \
25 | --method mocov2_kd_at \
26 | --limit_val_batches 0.2 \
27 | --distillation_teacher "simclr_cifar10" \
28 | --trades_k 2
29 |
--------------------------------------------------------------------------------
/solo/models/model_with_linear.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class ModelwithLinear(nn.Module):
5 | def __init__(self, model, inplanes, num_classes=10):
6 | super(ModelwithLinear, self).__init__()
7 | self.model = model
8 |
9 | self.classifier = nn.Linear(inplanes, num_classes)
10 |
11 | def forward(self, img):
12 | x = self.model(img)
13 | out = self.classifier(x)
14 | return out
15 |
16 | class LinearClassifier(nn.Module):
17 | """Linear classifier"""
18 |
19 | def __init__(self, name='resnet50', feat_dim=512, num_classes=10):
20 | super(LinearClassifier, self).__init__()
21 | # _, feat_dim = model_dict[name]
22 | self.classifier = nn.Linear(feat_dim, num_classes)
23 |
24 | def forward(self, features):
25 | return self.classifier(features)
26 |
--------------------------------------------------------------------------------
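
A usage sketch (illustrative, not part of the repo): ModelwithLinear can wrap any backbone whose classification head has been removed. The torchvision ResNet-18 and its 512-d feature size below are assumptions, not the repo's setup.

    # Illustrative only: backbone choice and feature size are assumptions.
    import torch
    import torchvision
    from solo.models.model_with_linear import ModelwithLinear

    backbone = torchvision.models.resnet18()
    backbone.fc = torch.nn.Identity()  # expose the 512-d pooled features
    model = ModelwithLinear(backbone, inplanes=512, num_classes=10)
    logits = model(torch.randn(4, 3, 32, 32))  # -> torch.Size([4, 10])
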
/bash_files/DeACL_cifar10_resnet18.sh:
--------------------------------------------------------------------------------
1 | export WANDB_API_KEY=""
2 |
3 | python3 main_pretrain_AdvTraining.py \
4 | --dataset cifar10 \
5 | --backbone resnet18 \
6 | --data_dir ./data \
7 | --max_epochs 100 \
8 | --gpus 1 \
9 | --accelerator gpu \
10 | --precision 32 \
11 | --optimizer sgd \
12 | --scheduler warmup_cosine \
13 | --lr 0.5 \
14 | --classifier_lr 0.5 \
15 | --weight_decay 5e-4 \
16 | --batch_size 256 \
17 | --num_workers 4 \
18 | --brightness 0.4 \
19 | --contrast 0.4 \
20 | --saturation 0.4 \
21 | --hue 0.1 \
22 | --gaussian_prob 0.0 0.0 \
23 | --crop_size 32 \
24 | --num_crops_per_aug 1 1 \
25 | --name "res18_simclr-cifar10-fp32" \
26 | --save_checkpoint \
27 | --method mocov2_kd_at \
28 | --limit_val_batches 0.2 \
29 | --distillation_teacher "simclr_cifar10" \
30 | --trades_k 2
31 |
--------------------------------------------------------------------------------
/bash_files/DeACL_cifar10_resnet50.sh:
--------------------------------------------------------------------------------
1 | python3 main_pretrain_AdvTraining.py \
2 | --dataset cifar10 \
3 | --backbone resnet50 \
4 | --data_dir ./data \
5 | --max_epochs 100 \
6 | --gpus 4 \
7 | --accelerator gpu \
8 | --precision 16 \
9 | --optimizer sgd \
10 | --scheduler warmup_cosine \
11 | --lr 0.5 \
12 | --classifier_lr 0.5 \
13 | --weight_decay 5e-4 \
14 | --batch_size 256 \
15 | --num_workers 4 \
16 | --brightness 0.4 \
17 | --contrast 0.4 \
18 | --saturation 0.4 \
19 | --hue 0.1 \
20 | --gaussian_prob 0.0 0.0 \
21 | --crop_size 32 \
22 | --num_crops_per_aug 1 1 \
23 | --name "res50_simclr-cifar10" \
24 | --save_checkpoint \
25 | --method mocov2_kd_at \
26 | --queue_size 32768 \
27 | --temperature 0.2 \
28 | --limit_val_batches 0.2 \
29 | --distillation_teacher "simclr_cifar10_resnet50" \
30 | --trades_k 2
31 |
--------------------------------------------------------------------------------
/solo/args/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.args import dataset, setup, utils
21 |
22 | __all__ = ["dataset", "setup", "utils"]
23 |
--------------------------------------------------------------------------------
/solo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from solo import args, losses, methods, utils
22 |
23 | __all__ = ["args", "losses", "methods", "utils"]
24 |
--------------------------------------------------------------------------------
/solo/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.methods.base_for_adversarial_training import BaseMethod
21 |
22 | from solo.methods.mocov2_distillation_AT import MoCoV2KDAT
23 |
24 |
25 | METHODS = {
26 | # base classes
27 | "base": BaseMethod,
28 | # methods
29 | "mocov2_kd_at": MoCoV2KDAT,
30 | }
31 | __all__ = [
32 | "BaseMethod",
33 | "MoCoV2KDAT",
34 | ]
35 |
36 | try:
37 | from solo.methods import dali # noqa: F401
38 | except ImportError:
39 | pass
40 | else:
41 | __all__.append("dali")
42 |
--------------------------------------------------------------------------------
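
A hedged sketch of how the METHODS registry is typically consumed; the surrounding training-script code is an assumption, since main_pretrain_AdvTraining.py is not shown in this dump.

    # Sketch: resolving a --method string through the METHODS registry.
    from solo.methods import METHODS

    ModelClass = METHODS["mocov2_kd_at"]  # -> MoCoV2KDAT
    # model = ModelClass(**vars(args))    # args: the parsed CLI Namespace
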
/solo/losses/wmse.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def wmse_loss_func(z1: torch.Tensor, z2: torch.Tensor, simplified: bool = True) -> torch.Tensor:
25 | """Computes W-MSE's loss given two batches of whitened features z1 and z2.
26 |
27 | Args:
28 | z1 (torch.Tensor): NxD Tensor containing whitened features from view 1.
29 | z2 (torch.Tensor): NxD Tensor containing whitened features from view 2.
30 | simplified (bool): faster computation, but with same result.
31 |
32 | Returns:
33 | torch.Tensor: W-MSE loss.
34 | """
35 |
36 |     if simplified:
37 |         return 2 - 2 * F.cosine_similarity(z1, z2, dim=-1).mean()
38 |
39 | z1 = F.normalize(z1, dim=-1)
40 | z2 = F.normalize(z2, dim=-1)
41 |
42 | return 2 - 2 * (z1 * z2).sum(dim=-1).mean()
43 |
--------------------------------------------------------------------------------
/solo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.utils import (
21 | backbones,
22 | checkpointer,
23 | classification_dataloader,
24 | knn,
25 | lars,
26 | metrics,
27 | misc,
28 | momentum,
29 | pretrain_dataloader,
30 | sinkhorn_knopp,
31 | )
32 |
33 | __all__ = [
34 | "backbones",
35 | "classification_dataloader",
36 | "pretrain_dataloader",
37 | "checkpointer",
38 | "knn",
39 | "misc",
40 | "lars",
41 | "metrics",
42 | "momentum",
43 | "sinkhorn_knopp",
44 | ]
45 |
46 | try:
47 | from solo.utils import dali_dataloader # noqa: F401
48 | except ImportError:
49 | pass
50 | else:
51 | __all__.append("dali_dataloader")
52 |
53 | try:
54 | from solo.utils import auto_umap # noqa: F401
55 | except ImportError:
56 | pass
57 | else:
58 | __all__.append("auto_umap")
59 |
--------------------------------------------------------------------------------
/solo/losses/byol.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def byol_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
25 | """Computes BYOL's loss given batch of predicted features p and projected momentum features z.
26 |
27 | Args:
28 | p (torch.Tensor): NxD Tensor containing predicted features from view 1
29 | z (torch.Tensor): NxD Tensor containing projected momentum features from view 2
30 | simplified (bool): faster computation, but with same result. Defaults to True.
31 |
32 | Returns:
33 | torch.Tensor: BYOL's loss.
34 | """
35 |
36 | if simplified:
37 | return 2 - 2 * F.cosine_similarity(p, z.detach(), dim=-1).mean()
38 |
39 | p = F.normalize(p, dim=-1)
40 | z = F.normalize(z, dim=-1)
41 |
42 | return 2 - 2 * (p * z.detach()).sum(dim=1).mean()
43 |
--------------------------------------------------------------------------------
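
A quick sanity check (illustrative, not from the repo): the simplified and explicit forms of byol_loss_func compute the same value, since cosine similarity equals the dot product of L2-normalized vectors.

    import torch
    from solo.losses.byol import byol_loss_func

    p = torch.randn(8, 128)
    z = torch.randn(8, 128)
    assert torch.allclose(
        byol_loss_func(p, z, simplified=True),
        byol_loss_func(p, z, simplified=False),
        atol=1e-6,
    )
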
/solo/losses/simsiam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def simsiam_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor:
25 | """Computes SimSiam's loss given batch of predicted features p from view 1 and
26 | a batch of projected features z from view 2.
27 |
28 | Args:
29 | p (torch.Tensor): Tensor containing predicted features from view 1.
30 | z (torch.Tensor): Tensor containing projected features from view 2.
31 | simplified (bool): faster computation, but with same result.
32 |
33 | Returns:
34 | torch.Tensor: SimSiam loss.
35 | """
36 |
37 | if simplified:
38 | return -F.cosine_similarity(p, z.detach(), dim=-1).mean()
39 |
40 | p = F.normalize(p, dim=-1)
41 | z = F.normalize(z, dim=-1)
42 |
43 | return -(p * z.detach()).sum(dim=1).mean()
44 |
--------------------------------------------------------------------------------
/solo/losses/nnclr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def nnclr_loss_func(nn: torch.Tensor, p: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
25 | """Computes NNCLR's loss given batch of nearest-neighbors nn from view 1 and
26 | predicted features p from view 2.
27 |
28 | Args:
29 | nn (torch.Tensor): NxD Tensor containing nearest neighbors' features from view 1.
30 | p (torch.Tensor): NxD Tensor containing predicted features from view 2
31 | temperature (float, optional): temperature of the softmax in the contrastive loss. Defaults
32 | to 0.1.
33 |
34 | Returns:
35 | torch.Tensor: NNCLR loss.
36 | """
37 |
38 | nn = F.normalize(nn, dim=-1)
39 | p = F.normalize(p, dim=-1)
40 |
41 | logits = nn @ p.T / temperature
42 |
43 | n = p.size(0)
44 | labels = torch.arange(n, device=p.device)
45 |
46 | loss = F.cross_entropy(logits, labels)
47 | return loss
48 |
--------------------------------------------------------------------------------
/solo/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.losses.barlow import barlow_loss_func
21 | from solo.losses.byol import byol_loss_func
22 | from solo.losses.deepclusterv2 import deepclusterv2_loss_func
23 | from solo.losses.dino import DINOLoss
24 | from solo.losses.moco import moco_loss_func
25 | from solo.losses.nnclr import nnclr_loss_func
26 | from solo.losses.ressl import ressl_loss_func
27 | from solo.losses.simclr import simclr_loss_func
28 | from solo.losses.simsiam import simsiam_loss_func
29 | from solo.losses.swav import swav_loss_func
30 | from solo.losses.vibcreg import vibcreg_loss_func
31 | from solo.losses.vicreg import vicreg_loss_func
32 | from solo.losses.wmse import wmse_loss_func
33 |
34 | __all__ = [
35 | "barlow_loss_func",
36 | "byol_loss_func",
37 | "deepclusterv2_loss_func",
38 | "DINOLoss",
39 | "moco_loss_func",
40 | "nnclr_loss_func",
41 | "ressl_loss_func",
42 | "simclr_loss_func",
43 | "simsiam_loss_func",
44 | "swav_loss_func",
45 | "vibcreg_loss_func",
46 | "vicreg_loss_func",
47 | "wmse_loss_func",
48 | ]
49 |
--------------------------------------------------------------------------------
/solo/losses/moco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def moco_loss_func(
25 | query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1
26 | ) -> torch.Tensor:
27 | """Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2 and a
28 | queue of past elements.
29 |
30 | Args:
31 | query (torch.Tensor): NxD Tensor containing the queries from view 1.
32 |         key (torch.Tensor): NxD Tensor containing the keys from view 2.
33 |         queue (torch.Tensor): a queue of negative samples for the contrastive loss.
34 |         temperature (float, optional): temperature of the softmax in the contrastive
35 |             loss. Defaults to 0.1.
36 |
37 | Returns:
38 | torch.Tensor: MoCo loss.
39 | """
40 |
41 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
42 | neg = torch.einsum("nc,ck->nk", [query, queue])
43 | logits = torch.cat([pos, neg], dim=1)
44 | logits /= temperature
45 | targets = torch.zeros(query.size(0), device=query.device, dtype=torch.long)
46 | return F.cross_entropy(logits, targets)
47 |
--------------------------------------------------------------------------------
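
An illustrative shape check (not from the repo): queries and keys are NxD, the queue is DxK, so the logits are Nx(1+K) with the positive always at index 0, matching the all-zeros targets.

    import torch
    import torch.nn.functional as F
    from solo.losses.moco import moco_loss_func

    query = F.normalize(torch.randn(8, 128), dim=-1)
    key = F.normalize(torch.randn(8, 128), dim=-1)
    queue = F.normalize(torch.randn(128, 4096), dim=0)  # D x K negatives
    loss = moco_loss_func(query, key, queue, temperature=0.1)
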
/solo/losses/swav.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import List
21 |
22 | import numpy as np
23 | import torch
24 |
25 |
26 | def swav_loss_func(
27 | preds: List[torch.Tensor], assignments: List[torch.Tensor], temperature: float = 0.1
28 | ) -> torch.Tensor:
29 | """Computes SwAV's loss given list of batch predictions from multiple views
30 | and a list of cluster assignments from the same multiple views.
31 |
32 | Args:
33 |         preds (List[torch.Tensor]): list of NxC Tensors containing the prediction logits
34 |             from the views.
35 |         assignments (List[torch.Tensor]): list of NxC Tensors containing the cluster assignments from the views.
36 |         temperature (float): softmax temperature for the loss. Defaults to 0.1.
37 |
38 | Returns:
39 | torch.Tensor: SwAV loss.
40 | """
41 |
42 | losses = []
43 | for v1 in range(len(preds)):
44 | for v2 in np.delete(np.arange(len(preds)), v1):
45 | a = assignments[v1]
46 | p = preds[v2] / temperature
47 | loss = -torch.mean(torch.sum(a * torch.log_softmax(p, dim=1), dim=1))
48 | losses.append(loss)
49 | return sum(losses) / len(losses)
50 |
--------------------------------------------------------------------------------
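
An illustrative call (not from the repo): with two views, the loss averages the two swapped-prediction terms, where assignments from one view supervise predictions from the other.

    import torch
    from solo.losses.swav import swav_loss_func

    preds = [torch.randn(8, 300) for _ in range(2)]               # logits per view
    assignments = [torch.rand(8, 300).softmax(dim=1) for _ in range(2)]
    loss = swav_loss_func(preds, assignments, temperature=0.1)
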
/solo/losses/deepclusterv2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def deepclusterv2_loss_func(
25 | outputs: torch.Tensor, assignments: torch.Tensor, temperature: float = 0.1
26 | ) -> torch.Tensor:
27 | """Computes DeepClusterV2's loss given a tensor containing logits from multiple views
28 | and a tensor containing cluster assignments from the same multiple views.
29 |
30 | Args:
31 | outputs (torch.Tensor): tensor of size PxVxNxC where P is the number of prototype
32 | layers and V is the number of views.
33 |         assignments (torch.Tensor): tensor of size PxN containing the cluster indices
34 |             assigned to each sample by k-means.
35 | temperature (float, optional): softmax temperature for the loss. Defaults to 0.1.
36 |
37 | Returns:
38 | torch.Tensor: DeepClusterV2 loss.
39 | """
40 | loss = 0
41 | for h in range(outputs.size(0)):
42 | scores = outputs[h].view(-1, outputs.size(-1)) / temperature
43 | targets = assignments[h].repeat(outputs.size(1)).to(outputs.device, non_blocking=True)
44 | loss += F.cross_entropy(scores, targets, ignore_index=-1)
45 | return loss / outputs.size(0)
46 |
--------------------------------------------------------------------------------
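
An illustrative shape check (not from the repo): P prototype heads, V views, N samples, C clusters; assignments hold one cluster index per head and per sample, and are tiled across views inside the loss.

    import torch
    from solo.losses.deepclusterv2 import deepclusterv2_loss_func

    P, V, N, C = 3, 2, 8, 300
    outputs = torch.randn(P, V, N, C)
    assignments = torch.randint(0, C, (P, N))
    loss = deepclusterv2_loss_func(outputs, assignments, temperature=0.1)
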
/solo/losses/ressl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def ressl_loss_func(
25 | q: torch.Tensor,
26 | k: torch.Tensor,
27 | queue: torch.Tensor,
28 | temperature_q: float = 0.1,
29 | temperature_k: float = 0.04,
30 | ) -> torch.Tensor:
31 | """Computes ReSSL's loss given a batch of queries from view 1, a batch of keys from view 2 and a
32 | queue of past elements.
33 |
34 |     Args:
35 |         q (torch.Tensor): NxD Tensor containing the queries from view 1.
36 |         k (torch.Tensor): NxD Tensor containing the keys from view 2.
37 |         queue (torch.Tensor): a queue of negative samples for the contrastive loss.
38 |         temperature_q (float, optional): temperature of the softmax for the query.
39 |             Defaults to 0.1.
40 |         temperature_k (float, optional): temperature of the softmax for the key.
41 |             Defaults to 0.04.
42 |
43 | Returns:
44 | torch.Tensor: ReSSL loss.
45 | """
46 |
47 | logits_q = torch.einsum("nc,kc->nk", [q, queue])
48 | logits_k = torch.einsum("nc,kc->nk", [k, queue])
49 |
50 | loss = -torch.sum(
51 | F.softmax(logits_k.detach() / temperature_k, dim=1)
52 | * F.log_softmax(logits_q / temperature_q, dim=1),
53 | dim=1,
54 | ).mean()
55 |
56 | return loss
57 |
--------------------------------------------------------------------------------
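
An illustrative call (not from the repo): note the queue here is K x D, since both einsums contract over the feature dimension ("nc,kc->nk"), unlike moco_loss_func where the queue is D x K.

    import torch
    import torch.nn.functional as F
    from solo.losses.ressl import ressl_loss_func

    q = F.normalize(torch.randn(8, 128), dim=-1)
    k = F.normalize(torch.randn(8, 128), dim=-1)
    queue = F.normalize(torch.randn(4096, 128), dim=-1)  # K x D
    loss = ressl_loss_func(q, k, queue)
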
/solo/losses/barlow.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 |
22 | import torch.distributed as dist
23 |
24 |
25 | def barlow_loss_func(
26 | z1: torch.Tensor, z2: torch.Tensor, lamb: float = 5e-3, scale_loss: float = 0.025
27 | ) -> torch.Tensor:
28 | """Computes Barlow Twins' loss given batch of projected features z1 from view 1 and
29 | projected features z2 from view 2.
30 |
31 | Args:
32 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
33 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
34 | lamb (float, optional): off-diagonal scaling factor for the cross-covariance matrix.
35 | Defaults to 5e-3.
36 | scale_loss (float, optional): final scaling factor of the loss. Defaults to 0.025.
37 |
38 | Returns:
39 | torch.Tensor: Barlow Twins' loss.
40 | """
41 |
42 | N, D = z1.size()
43 |
44 | # to match the original code
45 | bn = torch.nn.BatchNorm1d(D, affine=False).to(z1.device)
46 | z1 = bn(z1)
47 | z2 = bn(z2)
48 |
49 | corr = torch.einsum("bi, bj -> ij", z1, z2) / N
50 |
51 | if dist.is_available() and dist.is_initialized():
52 | dist.all_reduce(corr)
53 | world_size = dist.get_world_size()
54 | corr /= world_size
55 |
56 | diag = torch.eye(D, device=corr.device)
57 | cdif = (corr - diag).pow(2)
58 | cdif[~diag.bool()] *= lamb
59 | loss = scale_loss * cdif.sum()
60 | return loss
61 |
--------------------------------------------------------------------------------
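
An illustrative sanity check (not from the repo): near-identical views give a cross-correlation matrix close to the identity, so the Barlow Twins loss is small relative to that of unrelated views.

    import torch
    from solo.losses.barlow import barlow_loss_func

    z = torch.randn(64, 256)
    loss_same = barlow_loss_func(z, z + 0.01 * torch.randn(64, 256))
    loss_rand = barlow_loss_func(z, torch.randn(64, 256))
    # expect loss_same << loss_rand
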
/solo/losses/simclr.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 | from solo.utils.misc import gather, get_rank
23 |
24 |
25 | def simclr_loss_func(
26 | z: torch.Tensor, indexes: torch.Tensor, temperature: float = 0.1
27 | ) -> torch.Tensor:
28 | """Computes SimCLR's loss given batch of projected features z
29 | from different views, a positive boolean mask of all positives and
30 | a negative boolean mask of all negatives.
31 |
32 | Args:
33 | z (torch.Tensor): (N*views) x D Tensor containing projected features from the views.
34 | indexes (torch.Tensor): unique identifiers for each crop (unsupervised)
35 | or targets of each crop (supervised).
36 |
37 | Return:
38 | torch.Tensor: SimCLR loss.
39 | """
40 |
41 | z = F.normalize(z, dim=-1)
42 | gathered_z = gather(z)
43 |
44 | sim = torch.exp(torch.einsum("if, jf -> ij", z, gathered_z) / temperature)
45 |
46 | gathered_indexes = gather(indexes)
47 |
48 | indexes = indexes.unsqueeze(0)
49 | gathered_indexes = gathered_indexes.unsqueeze(0)
50 | # positives
51 | pos_mask = indexes.t() == gathered_indexes
52 | pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0)
53 | # negatives
54 | neg_mask = indexes.t() != gathered_indexes
55 |
56 | pos = torch.sum(sim * pos_mask, 1)
57 | neg = torch.sum(sim * neg_mask, 1)
58 | loss = -(torch.mean(torch.log(pos / (pos + neg))))
59 | return loss
60 |
--------------------------------------------------------------------------------
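
An illustrative shape sketch (not from the repo): with N samples and 2 views, z stacks both views and indexes repeats the sample ids so that crops of the same image are positives. The actual call depends on solo.utils.misc.gather; whether it runs without torch.distributed initialized is not shown in this dump, so the call is left commented.

    import torch

    N, D = 8, 128
    z = torch.randn(2 * N, D)            # [view1; view2] stacked
    indexes = torch.arange(N).repeat(2)  # same id -> positive pair
    # loss = simclr_loss_func(z, indexes, temperature=0.1)
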
/solo/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Dict, List, Sequence
21 |
22 | import torch
23 |
24 |
25 | def accuracy_at_k(
26 | outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
27 | ) -> List[torch.Tensor]:
28 | """Computes the accuracy over the k top predictions for the specified values of k.
29 |
30 | Args:
31 | outputs (torch.Tensor): output of a classifier (logits or probabilities).
32 | targets (torch.Tensor): ground truth labels.
33 | top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
34 | Defaults to (1, 5).
35 |
36 | Returns:
37 |         List[torch.Tensor]: accuracies at the desired k, as percentages.
38 | """
39 |
40 | with torch.no_grad():
41 | maxk = max(top_k)
42 | batch_size = targets.size(0)
43 |
44 | _, pred = outputs.topk(maxk, 1, True, True)
45 | pred = pred.t()
46 | correct = pred.eq(targets.view(1, -1).expand_as(pred))
47 |
48 | res = []
49 | for k in top_k:
50 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
51 | res.append(correct_k.mul_(100.0 / batch_size))
52 | return res
53 |
54 |
55 | def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
56 | """Computes the mean of the values of a key weighted by the batch size.
57 |
58 | Args:
59 | outputs (List[Dict]): list of dicts containing the outputs of a validation step.
60 | key (str): key of the metric of interest.
61 | batch_size_key (str): key of batch size values.
62 |
63 | Returns:
64 | float: weighted mean of the values of a key
65 | """
66 |
67 | value = 0
68 | n = 0
69 | for out in outputs:
70 | value += out[batch_size_key] * out[key]
71 | n += out[batch_size_key]
72 | value = value / n
73 | return value.squeeze(0)
74 |
--------------------------------------------------------------------------------
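
An illustrative use of accuracy_at_k (not from the repo) on a dummy batch; the returned values are percentages of the batch.

    import torch
    from solo.utils.metrics import accuracy_at_k

    logits = torch.randn(32, 10)
    targets = torch.randint(0, 10, (32,))
    acc1, acc5 = accuracy_at_k(logits, targets, top_k=(1, 5))
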
/solo/losses/vibcreg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | from solo.losses.vicreg import invariance_loss, variance_loss
22 | from torch import Tensor
23 | from torch.nn import functional as F
24 |
25 |
26 | def covariance_loss(z1: Tensor, z2: Tensor) -> Tensor:
27 | """Computes normalized covariance loss given batch of projected features z1 from view 1 and
28 | projected features z2 from view 2.
29 |
30 | Args:
31 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
32 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
33 |
34 | Returns:
35 | torch.Tensor: covariance regularization loss.
36 | """
37 |
38 | norm_z1 = z1 - z1.mean(dim=0)
39 | norm_z2 = z2 - z2.mean(dim=0)
40 | norm_z1 = F.normalize(norm_z1, p=2, dim=0) # (batch * feature); l2-norm
41 | norm_z2 = F.normalize(norm_z2, p=2, dim=0)
42 | fxf_cov_z1 = torch.mm(norm_z1.T, norm_z1) # (feature * feature)
43 | fxf_cov_z2 = torch.mm(norm_z2.T, norm_z2)
44 | fxf_cov_z1.fill_diagonal_(0.0)
45 | fxf_cov_z2.fill_diagonal_(0.0)
46 | cov_loss = (fxf_cov_z1 ** 2).mean() + (fxf_cov_z2 ** 2).mean()
47 | return cov_loss
48 |
49 |
50 | def vibcreg_loss_func(
51 | z1: torch.Tensor,
52 | z2: torch.Tensor,
53 | sim_loss_weight: float = 25.0,
54 | var_loss_weight: float = 25.0,
55 | cov_loss_weight: float = 200.0,
56 | ) -> torch.Tensor:
57 | """Computes VIbCReg's loss given batch of projected features z1 from view 1 and
58 | projected features z2 from view 2.
59 |
60 | Args:
61 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
62 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
63 | sim_loss_weight (float): invariance loss weight.
64 | var_loss_weight (float): variance loss weight.
65 | cov_loss_weight (float): covariance loss weight.
66 |
67 | Returns:
68 | torch.Tensor: VIbCReg loss.
69 | """
70 |
71 | sim_loss = invariance_loss(z1, z2)
72 | var_loss = variance_loss(z1, z2)
73 | cov_loss = covariance_loss(z1, z2)
74 |
75 | loss = sim_loss_weight * sim_loss + var_loss_weight * var_loss + cov_loss_weight * cov_loss
76 | return loss
77 |
--------------------------------------------------------------------------------
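
An illustrative call (not from the repo) with the default loss weights from the signature (25 / 25 / 200 for invariance / variance / covariance).

    import torch
    from solo.losses.vibcreg import vibcreg_loss_func

    z1, z2 = torch.randn(32, 128), torch.randn(32, 128)
    loss = vibcreg_loss_func(z1, z2)
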
/solo/utils/auto_resumer.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from argparse import ArgumentParser, Namespace
4 | from collections import namedtuple
5 | from datetime import datetime, timedelta
6 | from pathlib import Path
7 | from typing import Union
8 |
9 | Checkpoint = namedtuple("Checkpoint", ["creation_time", "args", "checkpoint"])
10 |
11 |
12 | class AutoResumer:
13 | SHOULD_MATCH = [
14 | "batch_size",
15 | "weight_decay",
16 | "lr",
17 | "dataset",
18 | "backbone",
19 | "max_epochs",
20 | "method",
21 | "name",
22 | "project",
23 | "entity",
24 | ]
25 |
26 | def __init__(
27 | self,
28 | checkpoint_dir: Union[str, Path] = Path("trained_models"),
29 | max_hours: int = 36,
30 | ):
31 | """Autoresumer object that automatically tries to find a checkpoint
32 | that is as old as max_time.
33 |
34 | Args:
35 | checkpoint_dir (Union[str, Path], optional): base directory to store checkpoints.
36 | Defaults to "trained_models".
37 | max_hours (int): maximum elapsed hours to consider checkpoint as valid.
38 | """
39 |
40 | self.checkpoint_dir = checkpoint_dir
41 | self.max_hours = timedelta(hours=max_hours)
42 |
43 | @staticmethod
44 | def add_autoresumer_args(parent_parser: ArgumentParser):
45 | """Adds user-required arguments to a parser.
46 |
47 | Args:
48 | parent_parser (ArgumentParser): parser to add new args to.
49 | """
50 |
51 | parser = parent_parser.add_argument_group("autoresumer")
52 | parser.add_argument("--auto_resumer_max_hours", default=36, type=int)
53 | return parent_parser
54 |
55 | def find_checkpoint(self, args: Namespace):
56 | """Finds a valid checkpoint that matches the arguments
57 |
58 | Args:
59 | args (Namespace): namespace object containing all settings of the model.
60 | """
61 |
62 | current_time = datetime.now()
63 |
64 | possible_checkpoints = []
65 | for rootdir, _, files in os.walk(self.checkpoint_dir):
66 | rootdir = Path(rootdir)
67 | if files:
68 |                 # skip directories that contain no checkpoint file
69 |                 try:
70 |                     checkpoint_file = [rootdir / f for f in files if f.endswith(".ckpt")][0]
71 |                 except IndexError:
72 |                     continue
73 |
74 | creation_time = datetime.fromtimestamp(os.path.getctime(checkpoint_file))
75 | if current_time - creation_time < self.max_hours:
76 | ck = Checkpoint(
77 | creation_time=creation_time,
78 | args=rootdir / "args.json",
79 | checkpoint=checkpoint_file,
80 | )
81 | possible_checkpoints.append(ck)
82 |
83 | if possible_checkpoints:
84 | # sort by most recent
85 | possible_checkpoints = sorted(
86 | possible_checkpoints, key=lambda ck: ck.creation_time, reverse=True
87 | )
88 |
89 | for checkpoint in possible_checkpoints:
90 | checkpoint_args = Namespace(**json.load(open(checkpoint.args)))
91 | if all(
92 | getattr(checkpoint_args, param) == getattr(args, param)
93 | for param in AutoResumer.SHOULD_MATCH
94 | ):
95 | return checkpoint.checkpoint
96 |
97 | return None
98 |
--------------------------------------------------------------------------------
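
A hedged sketch of wiring AutoResumer into a launcher; everything listed in SHOULD_MATCH must exist as a field on args for the matching in find_checkpoint to work, and the commented lines are assumptions about the calling script.

    from argparse import ArgumentParser
    from solo.utils.auto_resumer import AutoResumer

    parser = ArgumentParser()
    parser = AutoResumer.add_autoresumer_args(parser)
    # args = parser.parse_args()
    # resumer = AutoResumer(max_hours=args.auto_resumer_max_hours)
    # ckpt = resumer.find_checkpoint(args)  # path to resume from, or None
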
/solo/utils/momentum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 |
22 | import torch
23 | from torch import nn
24 |
25 |
26 | @torch.no_grad()
27 | def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module):
28 | """Copies the parameters of the online network to the momentum network.
29 |
30 | Args:
31 | online_net (nn.Module): online network (e.g. online backbone, online projection, etc...).
32 | momentum_net (nn.Module): momentum network (e.g. momentum backbone,
33 | momentum projection, etc...).
34 | """
35 |
36 | params_online = online_net.parameters()
37 | params_momentum = momentum_net.parameters()
38 | for po, pm in zip(params_online, params_momentum):
39 | pm.data.copy_(po.data)
40 | pm.requires_grad = False
41 |
42 |
43 | class MomentumUpdater:
44 | def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0):
45 | """Updates momentum parameters using exponential moving average.
46 |
47 | Args:
48 | base_tau (float, optional): base value of the weight decrease coefficient
49 | (should be in [0,1]). Defaults to 0.996.
50 | final_tau (float, optional): final value of the weight decrease coefficient
51 | (should be in [0,1]). Defaults to 1.0.
52 | """
53 |
54 | super().__init__()
55 |
56 | assert 0 <= base_tau <= 1
57 | assert 0 <= final_tau <= 1 and base_tau <= final_tau
58 |
59 | self.base_tau = base_tau
60 | self.cur_tau = base_tau
61 | self.final_tau = final_tau
62 |
63 | @torch.no_grad()
64 | def update(self, online_net: nn.Module, momentum_net: nn.Module):
65 | """Performs the momentum update for each param group.
66 |
67 | Args:
68 | online_net (nn.Module): online network (e.g. online backbone, online projection, etc...).
69 | momentum_net (nn.Module): momentum network (e.g. momentum backbone,
70 | momentum projection, etc...).
71 | """
72 |
73 | for op, mp in zip(online_net.parameters(), momentum_net.parameters()):
74 | mp.data = self.cur_tau * mp.data + (1 - self.cur_tau) * op.data
75 |
76 | def update_tau(self, cur_step: int, max_steps: int):
77 | """Computes the next value for the weighting decrease coefficient tau using cosine annealing.
78 |
79 | Args:
80 | cur_step (int): number of gradient steps so far.
81 | max_steps (int): overall number of gradient steps in the whole training.
82 | """
83 |
84 | self.cur_tau = (
85 | self.final_tau
86 | - (self.final_tau - self.base_tau) * (math.cos(math.pi * cur_step / max_steps) + 1) / 2
87 | )
88 |
--------------------------------------------------------------------------------
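
An illustrative tau schedule (not from the repo): cosine annealing takes cur_tau from base_tau at step 0 up to final_tau at the last step.

    from solo.utils.momentum import MomentumUpdater

    updater = MomentumUpdater(base_tau=0.996, final_tau=1.0)
    for step in (0, 50, 100):
        updater.update_tau(cur_step=step, max_steps=100)
        print(step, round(updater.cur_tau, 4))  # 0.996, 0.998, 1.0
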
/solo/utils/sinkhorn_knopp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | # Adapted from https://github.com/facebookresearch/swav.
21 |
22 | import torch
23 | import torch.distributed as dist
24 |
25 |
26 | class SinkhornKnopp(torch.nn.Module):
27 | def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1):
28 | """Approximates optimal transport using the Sinkhorn-Knopp algorithm.
29 |
30 | A simple iterative method to approach the double stochastic matrix is to alternately rescale
31 | rows and columns of the matrix to sum to 1.
32 |
33 | Args:
34 | num_iters (int, optional): number of times to perform row and column normalization.
35 | Defaults to 3.
36 | epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05.
37 | world_size (int, optional): number of nodes for distributed training. Defaults to 1.
38 | """
39 |
40 | super().__init__()
41 | self.num_iters = num_iters
42 | self.epsilon = epsilon
43 | self.world_size = world_size
44 |
45 | @torch.no_grad()
46 | def forward(self, Q: torch.Tensor) -> torch.Tensor:
47 | """Produces assignments using Sinkhorn-Knopp algorithm.
48 |
49 | Applies the entropy regularization, normalizes the Q matrix and then normalizes rows and
50 |         columns in an alternating fashion for num_iters iterations. Before returning, it
51 |         normalizes the columns again so that the output is an assignment of samples to prototypes.
52 |
53 | Args:
54 | Q (torch.Tensor): cosine similarities between the features of the
55 | samples and the prototypes.
56 |
57 | Returns:
58 | torch.Tensor: assignment of samples to prototypes according to optimal transport.
59 | """
60 |
61 | Q = torch.exp(Q / self.epsilon).t()
62 | B = Q.shape[1] * self.world_size
63 | K = Q.shape[0] # num prototypes
64 |
65 |         # make the matrix sum to 1
66 | sum_Q = torch.sum(Q)
67 | if dist.is_available() and dist.is_initialized():
68 | dist.all_reduce(sum_Q)
69 | Q /= sum_Q
70 |
71 | for _ in range(self.num_iters):
72 | # normalize each row: total weight per prototype must be 1/K
73 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
74 | if dist.is_available() and dist.is_initialized():
75 | dist.all_reduce(sum_of_rows)
76 | Q /= sum_of_rows
77 | Q /= K
78 |
79 | # normalize each column: total weight per sample must be 1/B
80 | Q /= torch.sum(Q, dim=0, keepdim=True)
81 | Q /= B
82 |
83 |         Q *= B  # the columns must sum to 1 so that Q is an assignment
84 | return Q.t()
85 |
--------------------------------------------------------------------------------
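As a quick sanity check of `SinkhornKnopp` above, this minimal sketch runs it single-process (`world_size=1`) on random similarities; after the final column rescaling, each sample's row in the returned matrix sums to roughly 1 (the sizes here are illustrative):

```python
import torch

from solo.utils.sinkhorn_knopp import SinkhornKnopp

# Cosine similarities between 8 samples and 4 prototypes (random stand-ins).
sk = SinkhornKnopp(num_iters=3, epsilon=0.05, world_size=1)
similarities = torch.randn(8, 4)

assignments = sk(similarities)  # shape: (8, 4), samples x prototypes

# Each sample's assignment distribution over prototypes sums to ~1.
print(assignments.sum(dim=1))
```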
/solo/args/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from argparse import ArgumentParser
21 | from pathlib import Path
22 |
23 |
24 | def dataset_args(parser: ArgumentParser):
25 | """Adds dataset-related arguments to a parser.
26 |
27 | Args:
28 | parser (ArgumentParser): parser to add dataset args to.
29 | """
30 |
31 | SUPPORTED_DATASETS = [
32 | "cifar10",
33 | "cifar100",
34 | "stl10",
35 | "imagenet",
36 | "imagenet100",
37 | "custom",
38 | ]
39 |
40 | parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, type=str, required=True)
41 |
42 | # dataset path
43 | parser.add_argument("--data_dir", type=Path, required=True)
44 | parser.add_argument("--train_dir", type=Path, default=None)
45 | parser.add_argument("--val_dir", type=Path, default=None)
46 |
47 | # dali (imagenet-100/imagenet/custom only)
48 | parser.add_argument("--dali", action="store_true")
49 | parser.add_argument("--dali_device", type=str, default="gpu")
50 |
51 |
52 | def augmentations_args(parser: ArgumentParser):
53 | """Adds augmentation-related arguments to a parser.
54 |
55 | Args:
56 | parser (ArgumentParser): parser to add augmentation args to.
57 | """
58 |
59 | # cropping
60 | parser.add_argument("--num_crops_per_aug", type=int, default=[2], nargs="+")
61 |
62 | # color jitter
63 | parser.add_argument("--brightness", type=float, required=True, nargs="+")
64 | parser.add_argument("--contrast", type=float, required=True, nargs="+")
65 | parser.add_argument("--saturation", type=float, required=True, nargs="+")
66 | parser.add_argument("--hue", type=float, required=True, nargs="+")
67 | parser.add_argument("--color_jitter_prob", type=float, default=[0.8], nargs="+")
68 |
69 | # other augmentation probabilities
70 | parser.add_argument("--gray_scale_prob", type=float, default=[0.2], nargs="+")
71 | parser.add_argument("--horizontal_flip_prob", type=float, default=[0.5], nargs="+")
72 | parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+")
73 | parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+")
74 |
75 | # cropping
76 | parser.add_argument("--crop_size", type=int, default=[224], nargs="+")
77 | parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+")
78 | parser.add_argument("--max_scale", type=float, default=[1.0], nargs="+")
79 |
80 | # debug
81 | parser.add_argument("--debug_augmentations", action="store_true")
82 |
83 |
84 | def linear_augmentations_args(parser: ArgumentParser):
85 | parser.add_argument("--crop_size", type=int, default=[224], nargs="+")
86 |
87 |
88 | def custom_dataset_args(parser: ArgumentParser):
89 | """Adds custom data-related arguments to a parser.
90 |
91 | Args:
92 |         parser (ArgumentParser): parser to add custom dataset args to.
93 | """
94 |
95 | # custom dataset only
96 | parser.add_argument("--no_labels", action="store_true")
97 |
98 | # for custom dataset
99 | parser.add_argument("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+")
100 | parser.add_argument("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+")
101 |
--------------------------------------------------------------------------------
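To see how these helpers compose, here is a minimal sketch that builds a parser from `dataset_args` and `augmentations_args` and parses an illustrative command line (the argument values are examples, not recommended settings):

```python
from argparse import ArgumentParser

from solo.args.dataset import augmentations_args, dataset_args

parser = ArgumentParser()
dataset_args(parser)
augmentations_args(parser)

# Only the required args need explicit values; everything else uses defaults.
args = parser.parse_args(
    [
        "--dataset", "cifar10",
        "--data_dir", "./data",
        "--brightness", "0.4",
        "--contrast", "0.4",
        "--saturation", "0.2",
        "--hue", "0.1",
    ]
)
print(args.dataset, args.crop_size, args.num_crops_per_aug)  # cifar10 [224] [2]
```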
/solo/losses/vicreg.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def invariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
25 | """Computes mse loss given batch of projected features z1 from view 1 and
26 | projected features z2 from view 2.
27 |
28 | Args:
29 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
30 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
31 |
32 | Returns:
33 | torch.Tensor: invariance loss (mean squared error).
34 | """
35 |
36 | return F.mse_loss(z1, z2)
37 |
38 |
39 | def variance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
40 | """Computes variance loss given batch of projected features z1 from view 1 and
41 | projected features z2 from view 2.
42 |
43 | Args:
44 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
45 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
46 |
47 | Returns:
48 | torch.Tensor: variance regularization loss.
49 | """
50 |
51 | eps = 1e-4
52 | std_z1 = torch.sqrt(z1.var(dim=0) + eps)
53 | std_z2 = torch.sqrt(z2.var(dim=0) + eps)
54 | std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))
55 | return std_loss
56 |
57 |
58 | def covariance_loss(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
59 | """Computes covariance loss given batch of projected features z1 from view 1 and
60 | projected features z2 from view 2.
61 |
62 | Args:
63 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
64 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
65 |
66 | Returns:
67 | torch.Tensor: covariance regularization loss.
68 | """
69 |
70 | N, D = z1.size()
71 |
72 | z1 = z1 - z1.mean(dim=0)
73 | z2 = z2 - z2.mean(dim=0)
74 | cov_z1 = (z1.T @ z1) / (N - 1)
75 | cov_z2 = (z2.T @ z2) / (N - 1)
76 |
77 | diag = torch.eye(D, device=z1.device)
78 | cov_loss = cov_z1[~diag.bool()].pow_(2).sum() / D + cov_z2[~diag.bool()].pow_(2).sum() / D
79 | return cov_loss
80 |
81 |
82 | def vicreg_loss_func(
83 | z1: torch.Tensor,
84 | z2: torch.Tensor,
85 | sim_loss_weight: float = 25.0,
86 | var_loss_weight: float = 25.0,
87 | cov_loss_weight: float = 1.0,
88 | ) -> torch.Tensor:
89 | """Computes VICReg's loss given batch of projected features z1 from view 1 and
90 | projected features z2 from view 2.
91 |
92 | Args:
93 | z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
94 | z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
95 | sim_loss_weight (float): invariance loss weight.
96 | var_loss_weight (float): variance loss weight.
97 | cov_loss_weight (float): covariance loss weight.
98 |
99 | Returns:
100 | torch.Tensor: VICReg loss.
101 | """
102 |
103 | sim_loss = invariance_loss(z1, z2)
104 | var_loss = variance_loss(z1, z2)
105 | cov_loss = covariance_loss(z1, z2)
106 |
107 | loss = sim_loss_weight * sim_loss + var_loss_weight * var_loss + cov_loss_weight * cov_loss
108 | return loss
109 |
--------------------------------------------------------------------------------
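A minimal sketch exercising `vicreg_loss_func` above on random projected features (batch and feature sizes are illustrative; the weights are the defaults):

```python
import torch

from solo.losses.vicreg import vicreg_loss_func

# N=256 samples, D=128 projected dims; requires_grad so we can backprop.
z1 = torch.randn(256, 128, requires_grad=True)
z2 = torch.randn(256, 128, requires_grad=True)

loss = vicreg_loss_func(z1, z2, sim_loss_weight=25.0, var_loss_weight=25.0, cov_loss_weight=1.0)
loss.backward()

print(loss.item(), z1.grad.shape)  # scalar loss, gradients of shape (256, 128)
```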
/trades/trades.py:
--------------------------------------------------------------------------------
1 | # This code is adapted from https://github.com/yaodongyu/TRADES/blob/master/trades.py
2 |
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import torch.optim as optim
9 |
10 |
11 | def squared_l2_norm(x):
12 | flattened = x.view(x.unsqueeze(0).shape[0], -1)
13 | return (flattened ** 2).sum(1)
14 |
15 |
16 | def l2_norm(x):
17 | return squared_l2_norm(x).sqrt()
18 |
19 |
20 | def trades_loss(model,
21 | model_backbone,
22 | model_linear,
23 | x_natural,
24 | y,
25 | optimizer,
26 | step_size=0.003,
27 | epsilon=0.031,
28 | perturb_steps=10,
29 | beta=1.0,
30 | distance='l_inf'):
31 | # define KL-loss
32 |     criterion_kl = nn.KLDivLoss(reduction="sum")  # equivalent to the deprecated size_average=False
33 | model_backbone.eval()
34 | model_linear.eval()
35 | batch_size = len(x_natural)
36 | # generate adversarial example
37 | x_adv = x_natural.detach() + 0.001 * torch.randn(x_natural.shape).cuda().detach()
38 | if distance == 'l_inf':
39 | for _ in range(perturb_steps):
40 | x_adv.requires_grad_()
41 | with torch.enable_grad():
42 | loss_kl = criterion_kl(F.log_softmax(model(x_adv), dim=1),
43 | F.softmax(model(x_natural), dim=1))
44 | grad = torch.autograd.grad(loss_kl, [x_adv])[0]
45 | x_adv = x_adv.detach() + step_size * torch.sign(grad.detach())
46 | x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
47 | x_adv = torch.clamp(x_adv, 0.0, 1.0)
48 | elif distance == 'l_2':
49 | delta = 0.001 * torch.randn(x_natural.shape).cuda().detach()
50 | delta = Variable(delta.data, requires_grad=True)
51 |
52 | # Setup optimizers
53 | optimizer_delta = optim.SGD([delta], lr=epsilon / perturb_steps * 2)
54 |
55 | for _ in range(perturb_steps):
56 | adv = x_natural + delta
57 |
58 | # optimize
59 | optimizer_delta.zero_grad()
60 | with torch.enable_grad():
61 | loss = (-1) * criterion_kl(F.log_softmax(model(adv), dim=1),
62 | F.softmax(model(x_natural), dim=1))
63 | loss.backward()
64 | # renorming gradient
65 | grad_norms = delta.grad.view(batch_size, -1).norm(p=2, dim=1)
66 | delta.grad.div_(grad_norms.view(-1, 1, 1, 1))
67 | # avoid nan or inf if gradient is 0
68 | if (grad_norms == 0).any():
69 | delta.grad[grad_norms == 0] = torch.randn_like(delta.grad[grad_norms == 0])
70 | optimizer_delta.step()
71 |
72 | # projection
73 | delta.data.add_(x_natural)
74 | delta.data.clamp_(0, 1).sub_(x_natural)
75 | delta.data.renorm_(p=2, dim=0, maxnorm=epsilon)
76 | x_adv = Variable(x_natural + delta, requires_grad=False)
77 | else:
78 | x_adv = torch.clamp(x_adv, 0.0, 1.0)
79 | # model.train()
80 | model_backbone.train()
81 | model_linear.train()
82 |
83 | x_adv = Variable(torch.clamp(x_adv, 0.0, 1.0), requires_grad=False)
84 | # zero gradient
85 | optimizer.zero_grad()
86 | # calculate robust loss
87 | logits = model(x_natural)
88 | loss_natural = F.cross_entropy(logits, y)
89 | loss_robust = (1.0 / batch_size) * criterion_kl(F.log_softmax(model(x_adv), dim=1),
90 | F.softmax(model(x_natural), dim=1))
91 | loss = loss_natural + beta * loss_robust
92 | return loss, logits
93 |
94 |
95 | # This function is adapted from https://github.com/yaodongyu/TRADES/blob/master/pgd_attack_cifar10.py#L54
96 | def _pgd_whitebox(model,
97 | X,
98 | y,
99 | epsilon,
100 | num_steps,
101 | step_size,
102 | random,
103 | device,
104 | ):
105 | out_clean = model(X)
106 | X_pgd = Variable(X.data, requires_grad=True)
107 | if random:
108 | random_noise = torch.FloatTensor(*X_pgd.shape).uniform_(-epsilon, epsilon).to(device)
109 | X_pgd = Variable(X_pgd.data + random_noise, requires_grad=True)
110 |
111 | for _ in range(num_steps):
112 | opt = optim.SGD([X_pgd], lr=1e-3)
113 | opt.zero_grad()
114 |
115 | with torch.enable_grad():
116 | loss = nn.CrossEntropyLoss()(model(X_pgd), y)
117 | loss.backward()
118 | eta = step_size * X_pgd.grad.data.sign()
119 | X_pgd = Variable(X_pgd.data + eta, requires_grad=True)
120 | eta = torch.clamp(X_pgd.data - X.data, -epsilon, epsilon)
121 | X_pgd = Variable(X.data + eta, requires_grad=True)
122 | X_pgd = Variable(torch.clamp(X_pgd, 0, 1.0), requires_grad=True)
123 |
124 | out_pgd = model(X_pgd)
125 | return out_clean, out_pgd
126 |
127 |
--------------------------------------------------------------------------------
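A minimal sketch of one TRADES training step using `trades_loss` above. The model here is a hypothetical toy backbone + linear head, and a CUDA device is required because the helper calls `.cuda()` on its random initialization; the attack hyperparameters follow the common 8/255 L-inf setting:

```python
import torch
import torch.nn as nn
import torch.optim as optim

from trades.trades import trades_loss

# Toy stand-ins for the real backbone and linear head.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU()).cuda()
linear = nn.Linear(64, 10).cuda()
model = nn.Sequential(backbone, linear)
optimizer = optim.SGD(model.parameters(), lr=0.1)

x = torch.rand(8, 3, 32, 32).cuda()    # inputs must live in [0, 1]
y = torch.randint(0, 10, (8,)).cuda()

loss, logits = trades_loss(
    model, backbone, linear, x, y, optimizer,
    step_size=2 / 255, epsilon=8 / 255, perturb_steps=10, beta=6.0,
)
loss.backward()
optimizer.step()
```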
/solo/utils/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | # Copied from Pytorch Lightning (https://github.com/PyTorchLightning/pytorch-lightning/)
21 | # with extra documentations.
22 |
23 |
24 | import torch
25 | from torch.optim import Optimizer
26 |
27 |
28 | class LARSWrapper:
29 | def __init__(
30 | self,
31 | optimizer: Optimizer,
32 | eta: float = 1e-3,
33 | clip: bool = False,
34 | eps: float = 1e-8,
35 | exclude_bias_n_norm: bool = False,
36 | ):
37 | """Wrapper that adds LARS scheduling to any optimizer.
38 | This helps stability with huge batch sizes.
39 |
40 | Args:
41 | optimizer (Optimizer): torch optimizer.
42 | eta (float, optional): trust coefficient. Defaults to 1e-3.
43 | clip (bool, optional): clip gradient values. Defaults to False.
44 | eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
45 | exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
46 | Defaults to False.
47 | """
48 |
49 | self.optim = optimizer
50 | self.eta = eta
51 | self.eps = eps
52 | self.clip = clip
53 | self.exclude_bias_n_norm = exclude_bias_n_norm
54 |
55 | # transfer optim methods
56 | self.state_dict = self.optim.state_dict
57 | self.load_state_dict = self.optim.load_state_dict
58 | self.zero_grad = self.optim.zero_grad
59 | self.add_param_group = self.optim.add_param_group
60 |
61 | self.__setstate__ = self.optim.__setstate__ # type: ignore
62 | self.__getstate__ = self.optim.__getstate__ # type: ignore
63 | self.__repr__ = self.optim.__repr__ # type: ignore
64 |
65 | @property
66 | def defaults(self):
67 | return self.optim.defaults
68 |
69 | @defaults.setter
70 | def defaults(self, defaults):
71 | self.optim.defaults = defaults
72 |
73 | @property # type: ignore
74 | def __class__(self):
75 | return Optimizer
76 |
77 | @property
78 | def state(self):
79 | return self.optim.state
80 |
81 | @state.setter
82 | def state(self, state):
83 | self.optim.state = state
84 |
85 | @property
86 | def param_groups(self):
87 | return self.optim.param_groups
88 |
89 | @param_groups.setter
90 | def param_groups(self, value):
91 | self.optim.param_groups = value
92 |
93 | @torch.no_grad()
94 | def step(self, closure=None):
95 | weight_decays = []
96 |
97 | for group in self.optim.param_groups:
98 | weight_decay = group.get("weight_decay", 0)
99 | weight_decays.append(weight_decay)
100 |
101 | # reset weight decay
102 | group["weight_decay"] = 0
103 |
104 | # update the parameters
105 | for p in group["params"]:
106 | if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
107 | self.update_p(p, group, weight_decay)
108 |
109 | # update the optimizer
110 | self.optim.step(closure=closure)
111 |
112 | # return weight decay control to optimizer
113 | for group_idx, group in enumerate(self.optim.param_groups):
114 | group["weight_decay"] = weight_decays[group_idx]
115 |
116 | def update_p(self, p, group, weight_decay):
117 | # calculate new norms
118 | p_norm = torch.norm(p.data)
119 | g_norm = torch.norm(p.grad.data)
120 |
121 | if p_norm != 0 and g_norm != 0:
122 | # calculate new lr
123 | new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)
124 |
125 | # clip lr
126 | if self.clip:
127 | new_lr = min(new_lr / group["lr"], 1)
128 |
129 | # update params with clipped lr
130 | p.grad.data += weight_decay * p.data
131 | p.grad.data *= new_lr
132 |
--------------------------------------------------------------------------------
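A minimal sketch wrapping a plain SGD optimizer in `LARSWrapper` above (the model and hyperparameters are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from solo.utils.lars import LARSWrapper

model = nn.Linear(32, 10)
base_optimizer = torch.optim.SGD(model.parameters(), lr=0.5, weight_decay=1e-4)
# Exclude biases and other 1-d (normalization) parameters from LARS scaling.
optimizer = LARSWrapper(base_optimizer, eta=1e-3, exclude_bias_n_norm=True)

x, y = torch.randn(16, 32), torch.randint(0, 10, (16,))
loss = F.cross_entropy(model(x), y)

optimizer.zero_grad()
loss.backward()
optimizer.step()  # LARS rescales each parameter's gradient, then SGD updates
```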
/solo/losses/dino.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import numpy as np
21 | import torch
22 | import torch.distributed as dist
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 |
26 |
27 | class DINOLoss(nn.Module):
28 | def __init__(
29 | self,
30 | num_prototypes: int,
31 | warmup_teacher_temp: float,
32 | teacher_temp: float,
33 | warmup_teacher_temp_epochs: float,
34 | num_epochs: int,
35 | student_temp: float = 0.1,
36 | num_large_crops: int = 2,
37 | center_momentum: float = 0.9,
38 | ):
39 | """Auxiliary module to compute DINO's loss.
40 |
41 | Args:
42 | num_prototypes (int): number of prototypes.
43 | warmup_teacher_temp (float): base temperature for the temperature schedule
44 | of the teacher.
45 | teacher_temp (float): final temperature for the teacher.
46 | warmup_teacher_temp_epochs (float): number of epochs for the cosine annealing schedule.
47 | num_epochs (int): total number of epochs.
48 | student_temp (float, optional): temperature for the student. Defaults to 0.1.
49 | num_large_crops (int, optional): number of crops/views. Defaults to 2.
50 | center_momentum (float, optional): momentum for the EMA update of the center of
51 | mass of the teacher. Defaults to 0.9.
52 | """
53 |
54 | super().__init__()
55 | self.epoch = 0
56 | self.student_temp = student_temp
57 | self.center_momentum = center_momentum
58 | self.num_large_crops = num_large_crops
59 | self.register_buffer("center", torch.zeros(1, num_prototypes))
60 | # we apply a warm up for the teacher temperature because
61 |         # too high a temperature makes training unstable at the beginning
62 | self.teacher_temp_schedule = np.concatenate(
63 | (
64 | np.linspace(warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs),
65 | np.ones(num_epochs - warmup_teacher_temp_epochs) * teacher_temp,
66 | )
67 | )
68 |
69 | def forward(self, student_output: torch.Tensor, teacher_output: torch.Tensor) -> torch.Tensor:
70 | """Computes DINO's loss given a batch of logits of the student and a batch of logits of the
71 | teacher.
72 |
73 | Args:
74 | student_output (torch.Tensor): NxP Tensor containing student logits for all views.
75 | teacher_output (torch.Tensor): NxP Tensor containing teacher logits for all views.
76 |
77 | Returns:
78 | torch.Tensor: DINO loss.
79 | """
80 |
81 | student_out = student_output / self.student_temp
82 | student_out = student_out.chunk(self.num_large_crops)
83 |
84 | # teacher centering and sharpening
85 | temp = self.teacher_temp_schedule[self.epoch]
86 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
87 | teacher_out = teacher_out.detach().chunk(2)
88 |
89 | total_loss = 0
90 | n_loss_terms = 0
91 | for iq, q in enumerate(teacher_out):
92 | for iv, v in enumerate(student_out):
93 | if iv == iq:
94 | # we skip cases where student and teacher operate on the same view
95 | continue
96 | loss = torch.sum(-q * F.log_softmax(v, dim=-1), dim=-1)
97 | total_loss += loss.mean()
98 | n_loss_terms += 1
99 | total_loss /= n_loss_terms
100 | self.update_center(teacher_output)
101 | return total_loss
102 |
103 | @torch.no_grad()
104 | def update_center(self, teacher_output: torch.Tensor):
105 | """Updates the center for DINO's loss using exponential moving average.
106 |
107 | Args:
108 | teacher_output (torch.Tensor): NxP Tensor containing teacher logits of all views.
109 | """
110 |
111 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
112 | if dist.is_available() and dist.is_initialized():
113 | dist.all_reduce(batch_center)
114 | batch_center = batch_center / dist.get_world_size()
115 | batch_center = batch_center / len(teacher_output)
116 |
117 | # ema update
118 | self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
119 |
--------------------------------------------------------------------------------
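A minimal sketch of `DINOLoss` above on random logits for two views (the batch size, prototype count, and temperatures are illustrative):

```python
import torch

from solo.losses.dino import DINOLoss

loss_fn = DINOLoss(
    num_prototypes=64,
    warmup_teacher_temp=0.04,
    teacher_temp=0.07,
    warmup_teacher_temp_epochs=10,
    num_epochs=100,
)
loss_fn.epoch = 0  # selects the teacher temperature from the warmup schedule

# Two views of a batch of 16 samples, stacked along dim 0.
student_logits = torch.randn(2 * 16, 64, requires_grad=True)
teacher_logits = torch.randn(2 * 16, 64)

loss = loss_fn(student_logits, teacher_logits)
loss.backward()
print(loss.item())
```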
/README.md:
--------------------------------------------------------------------------------
1 | # Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness
2 |
3 | Chaoning Zhang*, Kang Zhang*, Chenshuang Zhang, Axi Niu, Jiu Feng, Chang D. Yoo, In So Kweon
4 | (*Equal contribution)
5 |
6 | This is the official implementation of the paper "Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness," which was accepted for an oral presentation at ECCV 2022.
7 |
8 | The DeACL framework consists of two stages. In the first stage, DeACL performs standard self-supervised learning (SSL) to obtain a non-robust encoder. In the second stage, the pretrained encoder acts as a teacher model, generating pseudo-targets to guide supervised adversarial training (AT) on a student model. The student model, trained through these two stages, is the final model of interest. DeACL is a general framework that can be applied to any SSL method and AT method.
9 |
10 |
11 | Paper link: [arXiv](https://arxiv.org/abs/2207.10899), [ECCV2022](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136900716.pdf), [supplementary material](https://www.ecva.net/papers/eccv_2022/papers_ECCV/papers/136900716-supp.pdf)
12 |
13 |
14 | # Change log
15 | *2024.7.14* We have rewritten the SLF and AFF code to enhance its readability and usability. The new code is more modular, making it easier to extend to other datasets and models. Using this updated code, we conducted experiments with SimCLR and ResNet18 on the CIFAR10 dataset (ckpt [here](https://drive.google.com/file/d/1yc38miWGY57sHS6W6aY_k5t69Gt5v5fm/view?usp=sharing)). The code was executed five times, and the average results are reported below. The updated code can be found in the `adv_finetune.py` file.
16 |
17 | *2023.3.2* The ResNet model was defined differently during pre-training and during SLF, which made the forward and backward passes differ, so our previous code could produce results different from those given in the paper. We fixed the bug by changing the ResNet used in the SLF setting and released a pre-trained model built with the new code, which performs slightly differently from the one reported in the paper (SLF on CIFAR10 (AA, RA, SA) reported in the paper: `45.31, 53.95, 80.17` -> with the current code: `45.57, 55.43, 79.53`). (We apologize for not providing the model used in the paper; we accidentally deleted the original file.) We have also updated the environment configuration to help you reproduce our results.
18 |
19 | # 🔧 Environment
20 | We use [conda](https://docs.conda.io/en/latest/miniconda.html) for Python environment management. After installing conda:
21 |
22 | 1. conda create -n deacl python=3.8
23 |
24 | 2. conda activate deacl
25 |
26 | 3. pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
27 |
28 | 4. pip install -r requirements.txt
29 |
30 |
31 | # ⚡ Training and Evaluation
32 |
33 | ## 1. Prepare the pretrained self-supervised teacher model
34 | You can download a pretrained checkpoint from [solo-learn](https://github.com/vturrisi/solo-learn#cifar-10) or train one yourself.
35 |
36 | A SimCLR model pretrained by solo-learn is available at this [link](https://drive.google.com/drive/folders/1mcvWr8P2WNJZ7TVpdLHA_Q91q4VK3y8O?usp=sharing).
37 |
38 | Put the downloaded model into the folder `TeacherCKPT`.
39 |
40 | ## 2. Train DeACL with ResNet18 on CIFAR10 dataset
41 |
42 | Use the script `bash_files/DeACL_cifar10_resnet18.sh`.
43 |
44 | You need to specify `--project xxx`, provide your wandb API key, and add `--wandb` to enable wandb logging.
45 |
46 | ## 3. Test robustness against PGD and AutoAttack under standard linear fine-tuning (SLF) and adversarial full fine-tuning (AFF)
47 | First install the [autoattack](https://github.com/fra31/auto-attack) package: `pip install git+https://github.com/fra31/auto-attack`
48 |
49 | ### a. Eval the trained model from step 2
50 | Use the following command line, replacing `CKPT_PATH` with the path to the trained model.
51 |
52 | ```bash
53 | # SLF
54 | python adv_finetune.py --ckpt CKPT_PATH --mode SLF --learning_rate 0.1
55 | # AFF
56 | python adv_finetune.py --ckpt CKPT_PATH --mode AFF --learning_rate 0.01
57 | ```
58 |
59 | ### b. Eval the pretrained model provided by us
60 | We provide our pretrained model on CIFAR10 (with a SimCLR teacher) [here](https://drive.google.com/file/d/1yc38miWGY57sHS6W6aY_k5t69Gt5v5fm/view?usp=sharing); you can download it and use the command line in a. to evaluate the model.
61 |
62 | The results obtained with the pretrained model and the above code are as follows (average of 5 runs). For SLF, the initial learning rate is 0.1; for AFF and ALF, it is 0.01 with beta 6 in the TRADES loss. All three modes train for 25 epochs, and the learning rate decays by a factor of 10 at epochs 15 and 20.
63 | | Mode | AA | RA | SA |
64 | | --- | --- | --- | --- |
65 | | SLF | 46.14 ± 0.054 | 53.45 ± 0.095 | 80.82 ± 0.090 |
66 | | AFF | 50.75 ± 0.150 | 54.23 ± 0.238 | 83.64 ± 0.139 |
67 | | ALF | 45.55 ± 0.134 | 55.30 ± 0.142 | 79.39 ± 0.140 |
68 |
69 |
70 | # Acknowledgement
71 | This code is developed based on [solo-learn](https://github.com/vturrisi/solo-learn) for training and [AdvCL](https://github.com/LijieFan/AdvCL.git), [AutoAttack](https://github.com/fra31/auto-attack) and [TRADES](https://github.com/yaodongyu/TRADES) for testing.
72 |
73 |
82 |
83 | # See also our other works
84 |
85 | Dual Temperature Helps Contrastive Learning Without Many Negative Samples: Towards Understanding and Simplifying MoCo (Accepted by CVPR2022) [GitHub](https://github.com/ChaoningZhang/Dual-temperature.git), [arXiv](https://arxiv.org/abs/2203.17248)
86 |
87 |
88 |
89 | # Citation
90 | ```
91 | @inproceedings{zhang2022decoupled,
92 | title={Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness},
93 | author={Zhang, Chaoning and Zhang, Kang and Zhang, Chenshuang and Niu, Axi and Feng, Jiu and Yoo, Chang D and Kweon, In So},
94 | booktitle={ECCV 2022},
95 | pages={725--742},
96 | year={2022},
97 | organization={Springer}
98 | }
99 | ```
100 |
--------------------------------------------------------------------------------
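One note on the README's evaluation commands: ALF (adversarial linear fine-tuning) also appears in the results table. Assuming `adv_finetune.py` exposes it through the same interface as SLF and AFF, and using the learning rate the README states for ALF, the invocation would be:

```bash
# ALF (adversarial linear fine-tuning); per the README, AFF and ALF use lr 0.01
python adv_finetune.py --ckpt CKPT_PATH --mode ALF --learning_rate 0.01
```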
/solo/utils/checkpointer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import json
21 | import os
22 | import random
23 | import string
24 | import time
25 | from argparse import ArgumentParser, Namespace
26 | from pathlib import Path
27 | from typing import Optional, Union
28 |
29 | import pytorch_lightning as pl
30 | from pytorch_lightning.callbacks import Callback
31 |
32 |
33 | def random_string(letter_count=4, digit_count=4):
34 | tmp_random = random.Random(time.time())
35 | rand_str = "".join((tmp_random.choice(string.ascii_lowercase) for x in range(letter_count)))
36 | rand_str += "".join((tmp_random.choice(string.digits) for x in range(digit_count)))
37 | rand_str = list(rand_str)
38 | tmp_random.shuffle(rand_str)
39 | return "".join(rand_str)
40 |
41 |
42 | class Checkpointer(Callback):
43 | def __init__(
44 | self,
45 | args: Namespace,
46 | logdir: Union[str, Path] = Path("trained_models"),
47 | frequency: int = 1,
48 | keep_previous_checkpoints: bool = False,
49 | ):
50 |         """Custom checkpointer callback that stores checkpoints in an easier-to-access way.
51 |
52 | Args:
53 | args (Namespace): namespace object containing at least an attribute name.
54 | logdir (Union[str, Path], optional): base directory to store checkpoints.
55 | Defaults to "trained_models".
56 | frequency (int, optional): number of epochs between each checkpoint. Defaults to 1.
57 | keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not.
58 | Defaults to False.
59 | """
60 |
61 | super().__init__()
62 |
63 | self.args = args
64 | self.logdir = Path(logdir)
65 | self.frequency = frequency
66 | self.keep_previous_checkpoints = keep_previous_checkpoints
67 |
68 | @staticmethod
69 | def add_checkpointer_args(parent_parser: ArgumentParser):
70 | """Adds user-required arguments to a parser.
71 |
72 | Args:
73 | parent_parser (ArgumentParser): parser to add new args to.
74 | """
75 |
76 | parser = parent_parser.add_argument_group("checkpointer")
77 | parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path)
78 | parser.add_argument("--checkpoint_frequency", default=1, type=int)
79 | return parent_parser
80 |
81 | def initial_setup(self, trainer: pl.Trainer):
82 | """Creates the directories and does the initial setup needed.
83 |
84 | Args:
85 | trainer (pl.Trainer): pytorch lightning trainer object.
86 | """
87 |
88 | if trainer.logger is None:
89 | if self.logdir.exists():
90 | existing_versions = set(os.listdir(self.logdir))
91 | else:
92 | existing_versions = []
93 | version = "offline-" + random_string()
94 | while version in existing_versions:
95 | version = "offline-" + random_string()
96 | else:
97 | version = str(trainer.logger.version)
98 | if version is not None:
99 | self.path = self.logdir / version
100 | self.ckpt_placeholder = f"{self.args.name}-{version}" + "-ep={}.ckpt"
101 | else:
102 | self.path = self.logdir
103 | self.ckpt_placeholder = f"{self.args.name}" + "-ep={}.ckpt"
104 | self.last_ckpt: Optional[str] = None
105 |
106 | # create logging dirs
107 | if trainer.is_global_zero:
108 | os.makedirs(self.path, exist_ok=True)
109 |
110 | def save_args(self, trainer: pl.Trainer):
111 | """Stores arguments into a json file.
112 |
113 | Args:
114 | trainer (pl.Trainer): pytorch lightning trainer object.
115 | """
116 |
117 | if trainer.is_global_zero:
118 | args = vars(self.args)
119 | json_path = self.path / "args.json"
120 | json.dump(args, open(json_path, "w"), default=lambda o: "")
121 |
122 | def save(self, trainer: pl.Trainer):
123 | """Saves current checkpoint.
124 |
125 | Args:
126 | trainer (pl.Trainer): pytorch lightning trainer object.
127 | """
128 |
129 | if trainer.is_global_zero and not trainer.sanity_checking:
130 | epoch = trainer.current_epoch # type: ignore
131 | ckpt = self.path / self.ckpt_placeholder.format(epoch)
132 | trainer.save_checkpoint(ckpt)
133 |
134 | if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints:
135 | os.remove(self.last_ckpt)
136 | self.last_ckpt = ckpt
137 |
138 | def on_train_start(self, trainer: pl.Trainer, _):
139 | """Executes initial setup and saves arguments.
140 |
141 | Args:
142 | trainer (pl.Trainer): pytorch lightning trainer object.
143 | """
144 |
145 | self.initial_setup(trainer)
146 | self.save_args(trainer)
147 |
148 | def on_train_epoch_end(self, trainer: pl.Trainer, _):
149 | """Tries to save current checkpoint at the end of each train epoch.
150 |
151 | Args:
152 | trainer (pl.Trainer): pytorch lightning trainer object.
153 | """
154 |
155 | epoch = trainer.current_epoch # type: ignore
156 | if epoch % self.frequency == 0:
157 | self.save(trainer)
158 |
--------------------------------------------------------------------------------
/solo/args/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 |
22 | import pytorch_lightning as pl
23 | from solo.args.dataset import (
24 | augmentations_args,
25 | custom_dataset_args,
26 | dataset_args,
27 | linear_augmentations_args,
28 | )
29 | from solo.args.utils import additional_setup_linear, additional_setup_pretrain
30 | from solo.methods import METHODS
31 | from solo.utils.auto_resumer import AutoResumer
32 | from solo.utils.checkpointer import Checkpointer
33 |
34 | try:
35 | from solo.utils.auto_umap import AutoUMAP
36 | except ImportError:
37 | _umap_available = False
38 | else:
39 | _umap_available = True
40 |
41 |
42 | def parse_args_pretrain() -> argparse.Namespace:
43 | """Parses dataset, augmentation, pytorch lightning, model specific and additional args.
44 |
45 | First adds shared args such as dataset, augmentation and pytorch lightning args, then pulls the
46 | model name from the command and proceeds to add model specific args from the desired class. If
47 | wandb is enabled, it adds checkpointer args. Finally, adds additional non-user given parameters.
48 |
49 | Returns:
50 | argparse.Namespace: a namespace containing all args needed for pretraining.
51 | """
52 |
53 | parser = argparse.ArgumentParser()
54 |
55 | parser.add_argument("--aux_data", action='store_true')
56 |
57 |
58 | # add shared arguments
59 | dataset_args(parser)
60 | augmentations_args(parser)
61 | custom_dataset_args(parser)
62 |
63 | # add pytorch lightning trainer args
64 | parser = pl.Trainer.add_argparse_args(parser)
65 |
66 | # add method-specific arguments
67 | parser.add_argument("--method", type=str)
68 |
69 | # THIS LINE IS KEY TO PULL THE MODEL NAME
70 | temp_args, _ = parser.parse_known_args()
71 |
72 | # add model specific args
73 | parser = METHODS[temp_args.method].add_model_specific_args(parser)
74 |
75 | # add auto checkpoint/umap args
76 | parser.add_argument("--save_checkpoint", action="store_true")
77 | parser.add_argument("--auto_umap", action="store_true")
78 | parser.add_argument("--auto_resume", action="store_true")
79 | temp_args, _ = parser.parse_known_args()
80 |
81 | # optionally add checkpointer and AutoUMAP args
82 | if temp_args.save_checkpoint:
83 | parser = Checkpointer.add_checkpointer_args(parser)
84 |
85 | if _umap_available and temp_args.auto_umap:
86 | parser = AutoUMAP.add_auto_umap_args(parser)
87 |
88 | if temp_args.auto_resume:
89 | parser = AutoResumer.add_autoresumer_args(parser)
90 |
91 | # parse args
92 | args = parser.parse_args()
93 |
94 | # prepare arguments with additional setup
95 | additional_setup_pretrain(args)
96 |
97 | return args
98 |
99 |
100 | def parse_args_linear() -> argparse.Namespace:
101 | """Parses feature extractor, dataset, pytorch lightning, linear eval specific and additional args.
102 |
103 | First adds an arg for the pretrained feature extractor, then adds dataset, pytorch lightning
104 | and linear eval specific args. If wandb is enabled, it adds checkpointer args. Finally, adds
105 | additional non-user given parameters.
106 |
107 | Returns:
108 |         argparse.Namespace: a namespace containing all args needed for linear evaluation.
109 | """
110 |
111 | parser = argparse.ArgumentParser()
112 |
113 | parser.add_argument("--pretrained_feature_extractor", type=str)
114 |
115 | # add shared arguments
116 | dataset_args(parser)
117 | linear_augmentations_args(parser)
118 | custom_dataset_args(parser)
119 |
120 | # add pytorch lightning trainer args
121 | parser = pl.Trainer.add_argparse_args(parser)
122 |
123 | # linear model
124 | parser = METHODS["linear"].add_model_specific_args(parser)
125 |
126 | # THIS LINE IS KEY TO PULL WANDB AND SAVE_CHECKPOINT
127 | parser.add_argument("--save_checkpoint", action="store_true")
128 | temp_args, _ = parser.parse_known_args()
129 |
130 | # optionally add checkpointer
131 | if temp_args.save_checkpoint:
132 | parser = Checkpointer.add_checkpointer_args(parser)
133 |
134 | # parse args
135 | args = parser.parse_args()
136 | additional_setup_linear(args)
137 |
138 | return args
139 |
140 |
141 | def parse_args_knn() -> argparse.Namespace:
142 | """Parses arguments for offline K-NN.
143 |
144 | Returns:
145 |         argparse.Namespace: a namespace containing all args needed for offline K-NN.
146 | """
147 |
148 | parser = argparse.ArgumentParser()
149 |
150 | # add knn args
151 | parser.add_argument("--pretrained_checkpoint_dir", type=str)
152 | parser.add_argument("--batch_size", type=int, default=16)
153 | parser.add_argument("--num_workers", type=int, default=10)
154 | parser.add_argument("--k", type=int, nargs="+")
155 | parser.add_argument("--temperature", type=float, nargs="+")
156 | parser.add_argument("--distance_function", type=str, nargs="+")
157 | parser.add_argument("--feature_type", type=str, nargs="+")
158 |
159 | # add shared arguments
160 | dataset_args(parser)
161 | custom_dataset_args(parser)
162 |
163 | # parse args
164 | args = parser.parse_args()
165 |
166 | return args
167 |
168 |
169 | def parse_args_umap() -> argparse.Namespace:
170 | """Parses arguments for offline UMAP.
171 |
172 | Returns:
173 |         argparse.Namespace: a namespace containing all args needed for offline UMAP.
174 | """
175 |
176 | parser = argparse.ArgumentParser()
177 |
178 | # add knn args
179 | parser.add_argument("--pretrained_checkpoint_dir", type=str)
180 | parser.add_argument("--batch_size", type=int, default=16)
181 | parser.add_argument("--num_workers", type=int, default=10)
182 |
183 | # add shared arguments
184 | dataset_args(parser)
185 | custom_dataset_args(parser)
186 |
187 | # parse args
188 | args = parser.parse_args()
189 |
190 | return args
191 |
--------------------------------------------------------------------------------
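The pretrain parser above relies on a two-stage pattern: it first pulls `--method` with `parse_known_args`, then asks the chosen method class to register its own arguments before the final `parse_args`. A minimal self-contained sketch of that pattern (with a made-up `--lr` flag standing in for the method-specific args):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--method", type=str)

cli = ["--method", "mocov2_distillation_AT", "--lr", "0.3"]

# Stage 1: peek at --method, ignoring args we haven't registered yet.
temp_args, _ = parser.parse_known_args(cli)

# Stage 2: in parse_args_pretrain, METHODS[temp_args.method] adds its
# model-specific args here; this sketch adds a single made-up flag.
parser.add_argument("--lr", type=float)

args = parser.parse_args(cli)
print(args.method, args.lr)  # mocov2_distillation_AT 0.3
```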
/solo/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 | import warnings
22 | from typing import List, Tuple
23 |
24 | import torch
25 | import torch.distributed as dist
26 | import torch.nn as nn
27 |
28 |
29 | def _1d_filter(tensor: torch.Tensor) -> torch.Tensor:
30 | return tensor.isfinite()
31 |
32 |
33 | def _2d_filter(tensor: torch.Tensor) -> torch.Tensor:
34 | return tensor.isfinite().all(dim=1)
35 |
36 |
37 | def _single_input_filter(tensor: torch.Tensor) -> Tuple[torch.Tensor]:
38 | if len(tensor.size()) == 1:
39 | filter_func = _1d_filter
40 | elif len(tensor.size()) == 2:
41 | filter_func = _2d_filter
42 | else:
43 | raise RuntimeError("Only 1d and 2d tensors are supported.")
44 |
45 | selected = filter_func(tensor)
46 | tensor = tensor[selected]
47 |
48 | return tensor, selected
49 |
50 |
51 | def _multi_input_filter(tensors: List[torch.Tensor]) -> Tuple[torch.Tensor]:
52 | if len(tensors[0].size()) == 1:
53 | filter_func = _1d_filter
54 | elif len(tensors[0].size()) == 2:
55 | filter_func = _2d_filter
56 | else:
57 | raise RuntimeError("Only 1d and 2d tensors are supported.")
58 |
59 | selected = filter_func(tensors[0])
60 | for tensor in tensors[1:]:
61 | selected = torch.logical_and(selected, filter_func(tensor))
62 | tensors = [tensor[selected] for tensor in tensors]
63 |
64 | return tensors, selected
65 |
66 |
67 | def filter_inf_n_nan(tensors: List[torch.Tensor], return_indexes: bool = False):
68 |     """Filters out infs and nans from any tensor.
69 |     This is useful when there are instability issues,
70 | which cause a small number of values to go bad.
71 |
72 | Args:
73 |         tensors (List[torch.Tensor]): tensor(s) to remove nans and infs from.
74 |
75 | Returns:
76 | torch.Tensor: filtered view of the tensor without nans or infs.
77 | """
78 |
79 | if isinstance(tensors, torch.Tensor):
80 | tensors, selected = _single_input_filter(tensors)
81 | else:
82 | tensors, selected = _multi_input_filter(tensors)
83 |
84 | if return_indexes:
85 | return tensors, selected
86 | return tensors
87 |
88 |
89 | class FilterInfNNan(nn.Module):
90 | def __init__(self, module):
91 | """Layer that filters out inf and nans from any tensor.
92 |         This is useful when there are instability issues,
93 | which cause a small number of values to go bad.
94 |
95 | Args:
96 |             module (nn.Module): module whose output should be filtered.
97 |
98 | Returns:
99 | torch.Tensor: filtered view of the tensor without nans or infs.
100 | """
101 | super().__init__()
102 |
103 | self.module = module
104 |
105 | def forward(self, x: torch.Tensor) -> torch.Tensor:
106 | out = self.module(x)
107 | out = filter_inf_n_nan(out)
108 | return out
109 |
110 | def __getattr__(self, name):
111 | try:
112 | return super().__getattr__(name)
113 | except AttributeError:
114 | if name == "module":
115 | raise AttributeError()
116 | return getattr(self.module, name)
117 |
118 |
119 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
120 | """Copy & paste from PyTorch official master until it's in a few official releases - RW
121 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
122 | """
123 |
124 | def norm_cdf(x):
125 | """Computes standard normal cumulative distribution function"""
126 |
127 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
128 |
129 | if (mean < a - 2 * std) or (mean > b + 2 * std):
130 | warnings.warn(
131 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
132 | "The distribution of values may be incorrect.",
133 | stacklevel=2,
134 | )
135 |
136 | with torch.no_grad():
137 | # Values are generated by using a truncated uniform distribution and
138 | # then using the inverse CDF for the normal distribution.
139 | # Get upper and lower cdf values
140 | l = norm_cdf((a - mean) / std)
141 | u = norm_cdf((b - mean) / std)
142 |
143 | # Uniformly fill tensor with values from [l, u], then translate to
144 | # [2l-1, 2u-1].
145 | tensor.uniform_(2 * l - 1, 2 * u - 1)
146 |
147 | # Use inverse cdf transform for normal distribution to get truncated
148 | # standard normal
149 | tensor.erfinv_()
150 |
151 | # Transform to proper mean, std
152 | tensor.mul_(std * math.sqrt(2.0))
153 | tensor.add_(mean)
154 |
155 | # Clamp to ensure it's in the proper range
156 | tensor.clamp_(min=a, max=b)
157 | return tensor
158 |
159 |
160 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
161 | """Copy & paste from PyTorch official master until it's in a few official releases - RW
162 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
163 | """
164 |
165 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
166 |
167 |
168 | class GatherLayer(torch.autograd.Function):
169 | """Gathers tensors from all processes, supporting backward propagation."""
170 |
171 | @staticmethod
172 | def forward(ctx, inp):
173 | ctx.save_for_backward(inp)
174 | if dist.is_available() and dist.is_initialized():
175 | output = [torch.zeros_like(inp) for _ in range(dist.get_world_size())]
176 | dist.all_gather(output, inp)
177 | else:
178 | output = [inp]
179 | return tuple(output)
180 |
181 | @staticmethod
182 | def backward(ctx, *grads):
183 | (inp,) = ctx.saved_tensors
184 | if dist.is_available() and dist.is_initialized():
185 | grad_out = torch.zeros_like(inp)
186 | grad_out[:] = grads[dist.get_rank()]
187 | else:
188 | grad_out = grads[0]
189 | return grad_out
190 |
191 |
192 | def gather(X, dim=0):
193 | """Gathers tensors from all processes, supporting backward propagation."""
194 | return torch.cat(GatherLayer.apply(X), dim=dim)
195 |
196 |
197 | def get_rank():
198 | if dist.is_available() and dist.is_initialized():
199 | return dist.get_rank()
200 | return 0
201 |
--------------------------------------------------------------------------------
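Two of the standalone helpers above in action; a minimal sketch with illustrative shapes:

```python
import torch

from solo.utils.misc import filter_inf_n_nan, trunc_normal_

# 1) Drop rows of a 2-d tensor that contain inf or nan.
feats = torch.randn(4, 3)
feats[1, 0] = float("inf")
feats[2, 1] = float("nan")
clean, kept = filter_inf_n_nan(feats, return_indexes=True)
print(clean.shape, kept)  # torch.Size([2, 3]) tensor([True, False, False, True])

# 2) Initialize a weight tensor from a truncated normal, clamped to [a, b].
w = torch.empty(8, 8)
trunc_normal_(w, mean=0.0, std=0.02, a=-2.0, b=2.0)
print(w.abs().max() <= 2.0)  # tensor(True)
```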
/solo/utils/kmeans.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Any, Sequence
21 |
22 | import numpy as np
23 | import torch
24 | import torch.distributed as dist
25 | import torch.nn.functional as F
26 | from scipy.sparse import csr_matrix
27 |
28 |
29 | class KMeans:
30 | def __init__(
31 | self,
32 | world_size: int,
33 | rank: int,
34 | num_large_crops: int,
35 | dataset_size: int,
36 | proj_features_dim: int,
37 | num_prototypes: int,
38 | kmeans_iters: int = 10,
39 | ):
40 | """Class that performs K-Means on the hypersphere.
41 |
42 | Args:
43 | world_size (int): world size.
44 | rank (int): rank of the current process.
45 | num_large_crops (int): number of crops.
46 | dataset_size (int): total size of the dataset (number of samples).
47 | proj_features_dim (int): number of dimensions of the projected features.
48 | num_prototypes (int): number of prototypes.
49 | kmeans_iters (int, optional): number of iterations for the k-means clustering.
50 | Defaults to 10.
51 | """
52 | self.world_size = world_size
53 | self.rank = rank
54 | self.num_large_crops = num_large_crops
55 | self.dataset_size = dataset_size
56 | self.proj_features_dim = proj_features_dim
57 | self.num_prototypes = num_prototypes
58 | self.kmeans_iters = kmeans_iters
59 |
60 | @staticmethod
61 | def get_indices_sparse(data: np.ndarray):
62 | cols = np.arange(data.size)
63 | M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size))
64 | return [np.unravel_index(row.data, data.shape) for row in M]
65 |
66 | def cluster_memory(
67 | self,
68 | local_memory_index: torch.Tensor,
69 | local_memory_embeddings: torch.Tensor,
70 | ) -> Sequence[Any]:
71 | """Performs K-Means clustering on the hypersphere and returns centroids and
72 | assignments for each sample.
73 |
74 | Args:
75 |             local_memory_index (torch.Tensor): memory bank containing indices of the
76 | samples.
77 |             local_memory_embeddings (torch.Tensor): memory bank containing embeddings
78 | of the samples.
79 |
80 | Returns:
81 | Sequence[Any]: assignments and centroids.
82 | """
83 | j = 0
84 | device = local_memory_embeddings.device
85 | assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long()
86 | centroids_list = []
87 | with torch.no_grad():
88 | for i_K, K in enumerate(self.num_prototypes):
89 | # run distributed k-means
90 |
91 | # init centroids with elements from memory bank of rank 0
92 | centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True)
93 | if self.rank == 0:
94 | random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K]
95 | assert len(random_idx) >= K, "please reduce the number of centroids"
96 | centroids = local_memory_embeddings[j][random_idx]
97 | if dist.is_available() and dist.is_initialized():
98 | dist.broadcast(centroids, 0)
99 |
100 | for n_iter in range(self.kmeans_iters + 1):
101 |
102 | # E step
103 | dot_products = torch.mm(local_memory_embeddings[j], centroids.t())
104 | _, local_assignments = dot_products.max(dim=1)
105 |
106 | # finish
107 | if n_iter == self.kmeans_iters:
108 | break
109 |
110 | # M step
111 | where_helper = self.get_indices_sparse(local_assignments.cpu().numpy())
112 | counts = torch.zeros(K).to(device, non_blocking=True).int()
113 | emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True)
114 | for k in range(len(where_helper)):
115 | if len(where_helper[k][0]) > 0:
116 | emb_sums[k] = torch.sum(
117 | local_memory_embeddings[j][where_helper[k][0]],
118 | dim=0,
119 | )
120 | counts[k] = len(where_helper[k][0])
121 | if dist.is_available() and dist.is_initialized():
122 | dist.all_reduce(counts)
123 | dist.all_reduce(emb_sums)
124 | mask = counts > 0
125 | centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)
126 |
127 | # normalize centroids
128 | centroids = F.normalize(centroids, dim=1, p=2)
129 |
130 | centroids_list.append(centroids)
131 |
132 | if dist.is_available() and dist.is_initialized():
133 | # gather the assignments
134 | assignments_all = torch.empty(
135 | self.world_size,
136 | local_assignments.size(0),
137 | dtype=local_assignments.dtype,
138 | device=local_assignments.device,
139 | )
140 | assignments_all = list(assignments_all.unbind(0))
141 |
142 | dist_process = dist.all_gather(
143 | assignments_all, local_assignments, async_op=True
144 | )
145 | dist_process.wait()
146 | assignments_all = torch.cat(assignments_all).cpu()
147 |
148 | # gather the indexes
149 | indexes_all = torch.empty(
150 | self.world_size,
151 | local_memory_index.size(0),
152 | dtype=local_memory_index.dtype,
153 | device=local_memory_index.device,
154 | )
155 | indexes_all = list(indexes_all.unbind(0))
156 | dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True)
157 | dist_process.wait()
158 | indexes_all = torch.cat(indexes_all).cpu()
159 |
160 | else:
161 | assignments_all = local_assignments
162 | indexes_all = local_memory_index
163 |
164 | # log assignments
165 | assignments[i_K][indexes_all] = assignments_all
166 |
167 | # next memory bank to use
168 | j = (j + 1) % self.num_large_crops
169 |
170 | return assignments, centroids_list
171 |
--------------------------------------------------------------------------------
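A minimal single-process sketch of the `KMeans` helper above (`world_size=1`, `rank=0`). The memory bank holds L2-normalized embeddings for 2 crops of 100 samples in 16 dims, and we request two clusterings, with 5 and 10 prototypes (all sizes illustrative):

```python
import torch
import torch.nn.functional as F

from solo.utils.kmeans import KMeans

kmeans = KMeans(
    world_size=1,
    rank=0,
    num_large_crops=2,
    dataset_size=100,
    proj_features_dim=16,
    num_prototypes=[5, 10],
    kmeans_iters=10,
)

local_memory_index = torch.arange(100)
local_memory_embeddings = F.normalize(torch.randn(2, 100, 16), dim=-1)

assignments, centroids = kmeans.cluster_memory(local_memory_index, local_memory_embeddings)
print(assignments.shape)             # torch.Size([2, 100]): one row per prototype count
print([c.shape for c in centroids])  # [torch.Size([5, 16]), torch.Size([10, 16])]
```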
/main_pretrain_AdvTraining.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from pprint import pprint
22 |
23 | from pytorch_lightning import Trainer, seed_everything
24 | from pytorch_lightning.callbacks import LearningRateMonitor
25 | from pytorch_lightning.loggers import WandbLogger
26 |
27 | from solo.args.setup import parse_args_pretrain
28 | from solo.methods import METHODS
29 | from solo.utils.auto_resumer import AutoResumer
30 | import wandb
31 | try:
32 | from solo.methods.dali import PretrainABC
33 | except ImportError as e:
34 | print(e)
35 |     _dali_available = False
36 | else:
37 |     _dali_available = True
38 |
39 | try:
40 | from solo.utils.auto_umap import AutoUMAP
41 | except ImportError:
42 | _umap_available = False
43 | else:
44 | _umap_available = True
45 |
46 | import shutil
47 | import types
48 |
49 | from torchvision import transforms
50 |
51 | from solo.utils.checkpointer import Checkpointer
52 | # from solo.utils.classification_dataloader import prepare_data as prepare_data_classification
53 | from solo.utils.classification_dataloader_AdvTraining import \
54 | prepare_data as prepare_data_classification
55 | from solo.utils.pretrain_dataloader_AdvTraining import (
56 | prepare_dataloader, prepare_datasets, prepare_n_crop_transform,
57 | prepare_transform)
58 |
59 |
60 | def main():
61 | seed_everything(5)
62 |
63 | args = parse_args_pretrain()
64 |
65 | print(args)
66 |
67 | assert args.method in METHODS, f"Choose from {METHODS.keys()}"
68 |
69 | if args.num_large_crops != 2:
70 | assert args.method == "wmse"
71 |
72 | MethodClass = METHODS[args.method]
73 | if args.dali:
74 | assert (
75 |             _dali_available
76 |         ), "DALI is not currently available, please install it first with [dali]."
77 | MethodClass = types.new_class(f"Dali{MethodClass.__name__}", (PretrainABC, MethodClass))
78 |
79 | model = MethodClass(**args.__dict__)
80 |
81 | # pretrain dataloader
82 | if not args.dali:
83 | # asymmetric augmentations
84 | if args.unique_augs > 1:
85 |
86 | if args.dataset in ["cifar10", "cifar100", "svhn"]:
87 | crop_size = 32
88 | elif args.dataset in ["tinyimagenet"]:
89 | crop_size = 64
90 | else:
91 |                 raise ValueError(f"Unsupported dataset for asymmetric augmentations: {args.dataset}")
92 |
93 | weak_transforms = transforms.Compose([
94 | transforms.RandomCrop(crop_size, padding=4),
95 | transforms.RandomHorizontalFlip(),
96 | transforms.ToTensor(),
97 | ])
98 | transform = [
99 | weak_transforms
100 | ]
101 |
102 | else:
103 | transform = [prepare_transform(args.dataset, **args.transform_kwargs)]
104 |
105 | transform = prepare_n_crop_transform(transform, num_crops_per_aug=args.num_crops_per_aug)
106 | if args.debug_augmentations:
107 | print("Transforms:")
108 | pprint(transform)
109 |
110 | train_dataset = prepare_datasets(
111 | args.dataset,
112 | transform,
113 | data_dir=args.data_dir,
114 | train_dir=args.train_dir,
115 | no_labels=args.no_labels,
116 | )
117 | train_loader = prepare_dataloader(
118 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers
119 | )
120 |
121 |
122 |
123 |     # standard labeled validation dataloader, when available
124 | if args.dataset == "custom" and (args.no_labels or args.val_dir is None):
125 | val_loader = None
126 | elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None:
127 | val_loader = None
128 | else:
129 | _, val_loader = prepare_data_classification(
130 | args.dataset,
131 | data_dir=args.data_dir,
132 | train_dir=args.train_dir,
133 | val_dir=args.val_dir,
134 | batch_size=args.batch_size,
135 | num_workers=args.num_workers,
136 | )
137 |
138 | callbacks = []
139 |
140 | # wandb logging
141 | if args.wandb:
142 | wandb_logger = WandbLogger(
143 | name=args.name,
144 | project=args.project,
145 | offline=args.offline,
146 | settings=wandb.Settings(start_method="fork")
147 | )
148 | wandb_logger.watch(model, log="gradients", log_freq=100)
149 | wandb_logger.log_hyperparams(args)
150 |
151 | # lr logging
152 | lr_monitor = LearningRateMonitor(logging_interval="epoch")
153 | callbacks.append(lr_monitor)
154 |
155 | if args.save_checkpoint:
156 | # save checkpoint on last epoch only
157 | ckpt = Checkpointer(
158 | args,
159 | logdir=os.path.join(args.checkpoint_dir, args.method),
160 | frequency=args.checkpoint_frequency,
161 | )
162 | callbacks.append(ckpt)
163 |
164 | if args.auto_umap:
165 | assert (
166 | _umap_available
167 |         ), "UMAP is not currently available, please install it first with [umap]."
168 | auto_umap = AutoUMAP(
169 | args,
170 | logdir=os.path.join(args.auto_umap_dir, args.method),
171 | frequency=args.auto_umap_frequency,
172 | )
173 | callbacks.append(auto_umap)
174 |
175 |     # PyTorch Lightning 1.7 deprecates resume_from_checkpoint; instead,
176 |     # the checkpoint path is passed as ckpt_path to trainer.fit
177 | ckpt_path = None
178 |
179 |
180 | trainer = Trainer.from_argparse_args(
181 | args,
182 | logger=wandb_logger if args.wandb else None,
183 | callbacks=callbacks,
184 | enable_checkpointing=False,
185 | )
186 |
187 | # File backup
188 | if args.wandb:
189 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_{trainer.logger.version}"
190 | args.codepath = experimentdir
191 | else:
192 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_test"
193 |
194 | if not os.path.exists("code"):
195 | os.mkdir("code")
196 |
197 | if os.path.exists(experimentdir):
198 |         print(experimentdir + ' already exists; overwriting it.')
199 | shutil.rmtree(experimentdir)
200 | os.mkdir(experimentdir)
201 | else:
202 | os.mkdir(experimentdir)
203 |
204 |     shutil.copytree("solo", os.path.join(experimentdir, 'solo'))
205 |     shutil.copytree("bash_files", os.path.join(experimentdir, 'bash_files'))
206 |     shutil.copyfile("main_pretrain_AdvTraining.py", os.path.join(experimentdir, 'main_pretrain_AdvTraining.py'))
207 |     shutil.copyfile("adv_slf.py", os.path.join(experimentdir, 'adv_slf.py'))
208 |
209 |
210 | if args.dali:
211 | trainer.fit(model, val_dataloaders=val_loader, ckpt_path=ckpt_path)
212 | else:
213 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
214 |
215 |
216 |
217 | if __name__ == "__main__":
218 | main()
219 |
--------------------------------------------------------------------------------
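The DALI path above builds the training class at runtime: types.new_class creates a new subclass whose MRO puts the DALI dataloading mixin in front of the chosen method class. A toy sketch of the same pattern (the class names here are illustrative, not the repo's):

import types

class LoaderMixin:
    def train_dataloader(self):
        return "dali pipeline"

class BaseMethod:
    pass

# equivalent to: class DaliBaseMethod(LoaderMixin, BaseMethod): pass
DaliBaseMethod = types.new_class("DaliBaseMethod", (LoaderMixin, BaseMethod))
print(DaliBaseMethod.__mro__)  # LoaderMixin resolves first, overriding the dataloader

--------------------------------------------------------------------------------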
/solo/models/wide_resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | # Adapted from timm https://github.com/rwightman/pytorch-image-models/blob/master/timm/
21 |
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from timm.models.registry import register_model
27 |
28 |
29 | def normalize_fn(tensor, mean, std):
30 | """Differentiable version of torchvision.functional.normalize"""
31 |     # here we assume the color channel is at dim=1
32 | mean = mean[None, :, None, None]
33 | std = std[None, :, None, None]
34 | 
35 | return tensor.sub(mean).div(std)
36 |
37 |
38 | class NormalizeByChannelMeanStd(nn.Module):
39 | def __init__(self, mean, std):
40 | super(NormalizeByChannelMeanStd, self).__init__()
41 | if not isinstance(mean, torch.Tensor):
42 | mean = torch.tensor(mean)
43 | if not isinstance(std, torch.Tensor):
44 | std = torch.tensor(std)
45 | self.register_buffer("mean", mean)
46 | self.register_buffer("std", std)
47 |
48 | def forward(self, tensor):
49 |         # mean and std are registered buffers, so they already follow
50 |         # the module's device; no manual .to("cuda") is needed
51 |
52 | return normalize_fn(tensor, self.mean, self.std)
53 |
54 | def extra_repr(self):
55 | return 'mean={}, std={}'.format(self.mean, self.std)
56 |
57 | class WideResnetBasicBlock(nn.Module):
58 | def __init__(
59 | self, in_planes, out_planes, stride, drop_rate=0.0, activate_before_residual=False
60 | ):
61 | super(WideResnetBasicBlock, self).__init__()
62 | self.bn1 = nn.BatchNorm2d(in_planes, momentum=0.001, eps=0.001)
63 | self.relu1 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
64 | self.conv1 = nn.Conv2d(
65 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True
66 | )
67 | self.bn2 = nn.BatchNorm2d(out_planes, momentum=0.001, eps=0.001)
68 | self.relu2 = nn.LeakyReLU(negative_slope=0.1, inplace=False)
69 | self.conv2 = nn.Conv2d(
70 | out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True
71 | )
72 | self.drop_rate = drop_rate
73 | self.equalInOut = in_planes == out_planes
74 |         self.convShortcut = (
75 |             nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=True)
76 |             if not self.equalInOut
77 |             else None
78 |         )
79 | self.activate_before_residual = activate_before_residual
80 |
81 | def forward(self, x):
82 | if not self.equalInOut and self.activate_before_residual:
83 | x = self.relu1(self.bn1(x))
84 | else:
85 | out = self.relu1(self.bn1(x))
86 | out = self.relu2(self.bn2(self.conv1(out if self.equalInOut else x)))
87 | if self.drop_rate > 0:
88 | out = F.dropout(out, p=self.drop_rate, training=self.training)
89 | out = self.conv2(out)
90 | return torch.add(x if self.equalInOut else self.convShortcut(x), out)
91 |
92 |
93 | class WideResnetNetworkBlock(nn.Module):
94 | def __init__(
95 | self,
96 | nb_layers,
97 | in_planes,
98 | out_planes,
99 | block,
100 | stride,
101 | drop_rate=0.0,
102 | activate_before_residual=False,
103 | ):
104 | super(WideResnetNetworkBlock, self).__init__()
105 | self.layer = self._make_layer(
106 | block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual
107 | )
108 |
109 | def _make_layer(
110 | self, block, in_planes, out_planes, nb_layers, stride, drop_rate, activate_before_residual
111 | ):
112 | layers = []
113 | for i in range(int(nb_layers)):
114 | layers.append(
115 | block(
116 |                     in_planes if i == 0 else out_planes,
117 |                     out_planes,
118 |                     stride if i == 0 else 1,
119 | drop_rate,
120 | activate_before_residual,
121 | )
122 | )
123 | return nn.Sequential(*layers)
124 |
125 | def forward(self, x):
126 | return self.layer(x)
127 |
128 |
129 | class WideResNet(nn.Module):
130 | def __init__(self, first_stride=1, depth=28, widen_factor=2, drop_rate=0.0, **kwargs):
131 | super(WideResNet, self).__init__()
132 | channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
133 | self.num_features = channels[-1]
134 | self.inplanes = self.num_features
135 | assert (depth - 4) % 6 == 0
136 |         n = (depth - 4) // 6
137 |
138 | self.normalize = NormalizeByChannelMeanStd(
139 | mean=[0.4914, 0.4822, 0.4465], std=[0.247, 0.243, 0.261])
140 |
141 | block = WideResnetBasicBlock
142 | # 1st conv before any network block
143 | self.conv1 = nn.Conv2d(3, channels[0], kernel_size=3, stride=1, padding=1, bias=True)
144 | 
145 | # 1st block
146 | self.block1 = WideResnetNetworkBlock(
147 | n,
148 | channels[0],
149 | channels[1],
150 | block,
151 | first_stride,
152 | drop_rate,
153 | activate_before_residual=True,
154 | )
155 | # 2nd block
156 | self.block2 = WideResnetNetworkBlock(n, channels[1], channels[2], block, 2, drop_rate)
157 | # 3rd block
158 | self.block3 = WideResnetNetworkBlock(n, channels[2], channels[3], block, 2, drop_rate)
159 | # global average pooling
160 | self.bn1 = nn.BatchNorm2d(channels[3], momentum=0.001, eps=0.001)
161 | self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=False)
162 |
163 | for m in self.modules():
164 | if isinstance(m, nn.Conv2d):
165 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="leaky_relu")
166 | elif isinstance(m, nn.BatchNorm2d):
167 | m.weight.data.fill_(1)
168 | m.bias.data.zero_()
169 | elif isinstance(m, nn.Linear):
170 | nn.init.xavier_normal_(m.weight.data)
171 | if m.bias is not None:
172 | m.bias.data.zero_()
173 |
174 | def forward(self, x):
175 | x = self.normalize(x)
176 | out = self.conv1(x)
177 | out = self.block1(out)
178 | out = self.block2(out)
179 | out = self.block3(out)
180 | out = self.relu(self.bn1(out))
181 | out = F.adaptive_avg_pool2d(out, 1)
182 | x = out.view(-1, self.num_features)
183 | return x
184 |
185 |
186 | @register_model
187 | def wide_resnet28w2(**kwargs):
188 | encoder = WideResNet(depth=28, widen_factor=2, **kwargs)
189 | return encoder
190 |
191 |
192 | @register_model
193 | def wide_resnet28w8(**kwargs):
194 | encoder = WideResNet(depth=28, widen_factor=8, **kwargs)
195 | return encoder
196 |
197 |
198 | @register_model
199 | def wide_resnet28w10(**kwargs):
200 | encoder = WideResNet(depth=28, widen_factor=10, **kwargs)
201 | return encoder
202 |
--------------------------------------------------------------------------------
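A usage sketch for the wide_resnet28w2 factory above (shapes assume the file as written): the encoder normalizes internally via NormalizeByChannelMeanStd, so inputs are expected in [0, 1], and it returns pooled features rather than logits.

import torch
from solo.models.wide_resnet import wide_resnet28w2

encoder = wide_resnet28w2()
x = torch.rand(4, 3, 32, 32)  # CIFAR-sized batch in [0, 1]
feats = encoder(x)
print(feats.shape)  # torch.Size([4, 128]); num_features = 64 * widen_factor

--------------------------------------------------------------------------------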
/solo/utils/knn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Tuple
21 |
22 | import torch
23 | import torch.nn.functional as F
24 | from torchmetrics.metric import Metric
25 |
26 |
27 | class WeightedKNNClassifier(Metric):
28 | def __init__(
29 | self,
30 | k: int = 20,
31 | T: float = 0.07,
32 | max_distance_matrix_size: int = int(5e6),
33 | distance_fx: str = "cosine",
34 | epsilon: float = 0.00001,
35 | dist_sync_on_step: bool = False,
36 | ):
37 | """Implements the weighted k-NN classifier used for evaluation.
38 |
39 | Args:
40 | k (int, optional): number of neighbors. Defaults to 20.
41 | T (float, optional): temperature for the exponential. Only used with cosine
42 | distance. Defaults to 0.07.
43 | max_distance_matrix_size (int, optional): maximum number of elements in the
44 | distance matrix. Defaults to 5e6.
45 | distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
46 | "euclidean". Defaults to "cosine".
47 | epsilon (float, optional): Small value for numerical stability. Only used with
48 | euclidean distance. Defaults to 0.00001.
49 | dist_sync_on_step (bool, optional): whether to sync distributed values at every
50 | step. Defaults to False.
51 | """
52 |
53 | super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
54 |
55 | self.k = k
56 | self.T = T
57 | self.max_distance_matrix_size = max_distance_matrix_size
58 | self.distance_fx = distance_fx
59 | self.epsilon = epsilon
60 |
61 | self.add_state("train_features", default=[], persistent=False)
62 | self.add_state("train_targets", default=[], persistent=False)
63 | self.add_state("test_features", default=[], persistent=False)
64 | self.add_state("test_targets", default=[], persistent=False)
65 |
66 | def update(
67 | self,
68 | train_features: torch.Tensor = None,
69 | train_targets: torch.Tensor = None,
70 | test_features: torch.Tensor = None,
71 | test_targets: torch.Tensor = None,
72 | ):
73 | """Updates the memory banks. If train (test) features are passed as input, the
74 | corresponding train (test) targets must be passed as well.
75 |
76 | Args:
77 | train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
78 | train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
79 | test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
80 | test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
81 | """
82 | assert (train_features is None) == (train_targets is None)
83 | assert (test_features is None) == (test_targets is None)
84 |
85 | if train_features is not None:
86 | assert train_features.size(0) == train_targets.size(0)
87 | self.train_features.append(train_features.detach())
88 | self.train_targets.append(train_targets.detach())
89 |
90 | if test_features is not None:
91 | assert test_features.size(0) == test_targets.size(0)
92 | self.test_features.append(test_features.detach())
93 | self.test_targets.append(test_targets.detach())
94 |
95 | @torch.no_grad()
96 |     def compute(self) -> Tuple[float, float]:
97 | """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
98 | the weight is computed using the exponential of the temperature scaled cosine
99 | distance of the samples. If euclidean distance is selected, the weight corresponds
100 | to the inverse of the euclidean distance.
101 |
102 | Returns:
103 | Tuple[float]: k-NN accuracy @1 and @5.
104 | """
105 |
106 | train_features = torch.cat(self.train_features)
107 | train_targets = torch.cat(self.train_targets)
108 | test_features = torch.cat(self.test_features)
109 | test_targets = torch.cat(self.test_targets)
110 |
111 | if self.distance_fx == "cosine":
112 | train_features = F.normalize(train_features)
113 | test_features = F.normalize(test_features)
114 |
115 | num_classes = torch.unique(test_targets).numel()
116 | num_train_images = train_targets.size(0)
117 | num_test_images = test_targets.size(0)
119 | chunk_size = min(
120 | max(1, self.max_distance_matrix_size // num_train_images),
121 | num_test_images,
122 | )
123 | k = min(self.k, num_train_images)
124 |
125 | top1, top5, total = 0.0, 0.0, 0
126 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
127 | for idx in range(0, num_test_images, chunk_size):
128 | # get the features for test images
129 | features = test_features[idx : min((idx + chunk_size), num_test_images), :]
130 | targets = test_targets[idx : min((idx + chunk_size), num_test_images)]
131 | batch_size = targets.size(0)
132 |
133 | # calculate the dot product and compute top-k neighbors
134 | if self.distance_fx == "cosine":
135 | similarities = torch.mm(features, train_features.t())
136 | elif self.distance_fx == "euclidean":
137 | similarities = 1 / (torch.cdist(features, train_features) + self.epsilon)
138 | else:
139 | raise NotImplementedError
140 |
141 | similarities, indices = similarities.topk(k, largest=True, sorted=True)
142 | candidates = train_targets.view(1, -1).expand(batch_size, -1)
143 | retrieved_neighbors = torch.gather(candidates, 1, indices)
144 |
145 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
146 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
147 |
148 | if self.distance_fx == "cosine":
149 | similarities = similarities.clone().div_(self.T).exp_()
150 |
151 | probs = torch.sum(
152 | torch.mul(
153 | retrieval_one_hot.view(batch_size, -1, num_classes),
154 | similarities.view(batch_size, -1, 1),
155 | ),
156 | 1,
157 | )
158 | _, predictions = probs.sort(1, True)
159 |
160 | # find the predictions that match the target
161 | correct = predictions.eq(targets.data.view(-1, 1))
162 | top1 = top1 + correct.narrow(1, 0, 1).sum().item()
163 | top5 = (
164 | top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item()
165 | ) # top5 does not make sense if k < 5
166 | total += targets.size(0)
167 |
168 | top1 = top1 * 100.0 / total
169 | top5 = top5 * 100.0 / total
170 |
171 | self.reset()
172 |
173 | return top1, top5
174 |
--------------------------------------------------------------------------------
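A usage sketch for the evaluator above with random features (the numbers are illustrative; real usage feeds backbone features accumulated over the train and test loaders):

import torch
from solo.utils.knn import WeightedKNNClassifier

knn = WeightedKNNClassifier(k=20, T=0.07, distance_fx="cosine")
train_feats = torch.randn(1000, 128)
train_targets = torch.randint(0, 10, (1000,))
test_feats = torch.randn(200, 128)
test_targets = torch.randint(0, 10, (200,))

knn.update(train_features=train_feats, train_targets=train_targets)
knn.update(test_features=test_feats, test_targets=test_targets)
top1, top5 = knn.compute()  # accuracies in percent; internal state is reset afterwards

--------------------------------------------------------------------------------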
/solo/utils/whitening.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from typing import Optional
22 |
23 | import torch
24 | import torch.nn as nn
25 | from torch.cuda.amp import custom_fwd
26 | from torch.nn.functional import conv2d
27 |
28 |
29 | class Whitening2d(nn.Module):
30 | def __init__(self, output_dim: int, eps: float = 0.0):
31 | """Layer that computes hard whitening for W-MSE using the Cholesky decomposition.
32 |
33 | Args:
34 | output_dim (int): number of dimension of projected features.
35 | eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults
36 | to 0.0.
37 | """
38 |
39 | super(Whitening2d, self).__init__()
40 | self.output_dim = output_dim
41 | self.eps = eps
42 |
43 | @custom_fwd(cast_inputs=torch.float32)
44 | def forward(self, x: torch.Tensor) -> torch.Tensor:
45 | """Performs whitening using the Cholesky decomposition.
46 |
47 | Args:
48 | x (torch.Tensor): a batch or slice of projected features.
49 |
50 | Returns:
51 | torch.Tensor: a batch or slice of whitened features.
52 | """
53 |
54 | x = x.unsqueeze(2).unsqueeze(3)
55 | m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1)
56 | xn = x - m
57 |
58 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1)
59 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)
60 |
61 | eye = torch.eye(self.output_dim).type(f_cov.type())
62 |
63 |         f_cov_shrunk = (1 - self.eps) * f_cov + self.eps * eye
64 | 
65 |         inv_sqrt = torch.triangular_solve(eye, torch.cholesky(f_cov_shrunk), upper=False)[0]
66 | inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1)
67 |
68 | decorrelated = conv2d(xn, inv_sqrt)
69 |
70 | return decorrelated.squeeze(2).squeeze(2)
71 |
72 |
73 | class iterative_normalization_py(torch.autograd.Function):
74 | @staticmethod
75 | def forward(ctx, *args) -> torch.Tensor:
76 | X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args
77 |
78 |         # reshape (N, C, H, W) into (g, d, m): groups, channels per group, N*H*W
79 | ctx.g = X.size(1) // nc
80 | x = X.transpose(0, 1).contiguous().view(ctx.g, nc, -1)
81 | _, d, m = x.size()
82 | saved = []
83 | if training:
84 |             # calculate centered activation by subtracting the mini-batch mean
85 | mean = x.mean(-1, keepdim=True)
86 | xc = x - mean
87 | saved.append(xc)
88 | # calculate covariance matrix
89 | P = [None] * (ctx.T + 1)
90 | P[0] = torch.eye(d).to(X).expand(ctx.g, d, d)
91 | Sigma = torch.baddbmm(
92 | beta=eps,
93 | input=P[0],
94 | alpha=1.0 / m,
95 | batch1=xc,
96 | batch2=xc.transpose(1, 2),
97 | )
98 | # reciprocal of trace of Sigma: shape [g, 1, 1]
99 | rTr = (Sigma * P[0]).sum((1, 2), keepdim=True).reciprocal_()
100 | saved.append(rTr)
101 | Sigma_N = Sigma * rTr
102 | saved.append(Sigma_N)
103 | for k in range(ctx.T):
104 | P[k + 1] = torch.baddbmm(
105 | beta=1.5,
106 | input=P[k],
107 | alpha=-0.5,
108 | batch1=torch.matrix_power(P[k], 3),
109 | batch2=Sigma_N,
110 | )
111 | saved.extend(P)
112 | wm = P[ctx.T].mul_(
113 | rTr.sqrt()
114 |             )  # whitening matrix: the inverse square root of Sigma, i.e., Sigma^{-1/2}
115 |
116 | running_mean.copy_(momentum * mean + (1.0 - momentum) * running_mean)
117 | running_wmat.copy_(momentum * wm + (1.0 - momentum) * running_wmat)
118 | else:
119 | xc = x - running_mean
120 | wm = running_wmat
121 | xn = wm.matmul(xc)
122 | Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous()
123 | ctx.save_for_backward(*saved)
124 | return Xn
125 |
126 | @staticmethod
127 | def backward(ctx, *grad_outputs):
128 | (grad,) = grad_outputs
129 | saved = ctx.saved_tensors
130 | if len(saved) == 0:
131 | return None, None, None, None, None, None, None, None
132 |
133 | xc = saved[0] # centered input
134 |         rTr = saved[1]  # reciprocal of the trace of Sigma
135 |         sn = saved[2].transpose(-2, -1)  # normalized Sigma
136 |         P = saved[3:]  # intermediate matrices from the Newton iterations
137 | g, d, m = xc.size()
138 |
139 | g_ = grad.transpose(0, 1).contiguous().view_as(xc)
140 | g_wm = g_.matmul(xc.transpose(-2, -1))
141 | g_P = g_wm * rTr.sqrt()
142 | wm = P[ctx.T]
143 | g_sn = 0
144 | for k in range(ctx.T, 1, -1):
145 | P[k - 1].transpose_(-2, -1)
146 | P2 = P[k - 1].matmul(P[k - 1])
147 | g_sn += P2.matmul(P[k - 1]).matmul(g_P)
148 | g_tmp = g_P.matmul(sn)
149 | g_P.baddbmm_(beta=1.5, alpha=-0.5, batch1=g_tmp, batch2=P2)
150 | g_P.baddbmm_(beta=1, alpha=-0.5, batch1=P2, batch2=g_tmp)
151 | g_P.baddbmm_(beta=1, alpha=-0.5, batch1=P[k - 1].matmul(g_tmp), batch2=P[k - 1])
152 | g_sn += g_P
153 | g_tr = ((-sn.matmul(g_sn) + g_wm.transpose(-2, -1).matmul(wm)) * P[0]).sum(
154 | (1, 2), keepdim=True
155 | ) * P[0]
156 | g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2.0 * g_tr) * (-0.5 / m * rTr)
157 | g_x = torch.baddbmm(wm.matmul(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc)
158 | grad_input = (
159 | g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous()
160 | )
161 | return grad_input, None, None, None, None, None, None, None
162 |
163 |
164 | class IterNorm(torch.nn.Module):
165 | def __init__(
166 | self,
167 | num_features: int,
168 | num_groups: int = 64,
169 | num_channels: Optional[int] = None,
170 | T: int = 5,
171 | dim: int = 2,
172 | eps: float = 1.0e-5,
173 | momentum: float = 0.1,
174 | affine: bool = True,
175 | ):
176 | super(IterNorm, self).__init__()
177 | # assert dim == 4, 'IterNorm does not support 2D'
178 | self.T = T
179 | self.eps = eps
180 | self.momentum = momentum
181 | self.num_features = num_features
182 | self.affine = affine
183 | self.dim = dim
184 | if num_channels is None:
185 | num_channels = (num_features - 1) // num_groups + 1
186 | num_groups = num_features // num_channels
187 | while num_features % num_channels != 0:
188 | num_channels //= 2
189 | num_groups = num_features // num_channels
190 | assert (
191 | num_groups > 0 and num_features % num_groups == 0
192 | ), f"num features={num_features}, num groups={num_groups}"
193 | self.num_groups = num_groups
194 | self.num_channels = num_channels
195 | shape = [1] * dim
196 | shape[1] = self.num_features
197 | if self.affine:
198 | self.weight = nn.Parameter(torch.Tensor(*shape))
199 | self.bias = nn.Parameter(torch.Tensor(*shape))
200 | else:
201 | self.register_parameter("weight", None)
202 | self.register_parameter("bias", None)
203 |
204 | self.register_buffer("running_mean", torch.zeros(num_groups, num_channels, 1))
205 | # running whiten matrix
206 | self.register_buffer(
207 | "running_wm",
208 | torch.eye(num_channels).expand(num_groups, num_channels, num_channels).clone(),
209 | )
210 |
211 | self.reset_parameters()
212 |
213 | def reset_parameters(self):
214 | if self.affine:
215 | torch.nn.init.ones_(self.weight)
216 | torch.nn.init.zeros_(self.bias)
217 |
218 | @custom_fwd(cast_inputs=torch.float32)
219 | def forward(self, X: torch.Tensor) -> torch.Tensor:
220 | X_hat = iterative_normalization_py.apply(
221 | X,
222 | self.running_mean,
223 | self.running_wm,
224 | self.num_channels,
225 | self.T,
226 | self.eps,
227 | self.momentum,
228 | self.training,
229 | )
230 | # affine
231 | if self.affine:
232 | return X_hat * self.weight + self.bias
233 |
234 | return X_hat
235 |
236 | def extra_repr(self):
237 | return (
238 | f"{self.num_features}, num_channels={self.num_channels}, T={self.T}, eps={self.eps}, "
239 |             f"momentum={self.momentum}, affine={self.affine}"
240 | )
241 |
--------------------------------------------------------------------------------
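A quick sanity-check sketch for Whitening2d (assuming the pinned torch version, where torch.cholesky and torch.triangular_solve are still available): whitened features should come out with near-zero mean and near-identity covariance.

import torch
from solo.utils.whitening import Whitening2d

whitening = Whitening2d(output_dim=32, eps=0.0)
z = torch.randn(256, 32)  # batch of projected features
w = whitening(z)
centered = w - w.mean(0)
cov = centered.t() @ centered / (w.size(0) - 1)
print(w.mean(0).abs().max().item())              # ~0
print((cov - torch.eye(32)).abs().max().item())  # ~0

--------------------------------------------------------------------------------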
/solo/args/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from argparse import Namespace
22 | from contextlib import suppress
23 |
24 |
25 | N_CLASSES_PER_DATASET = {
26 | "cifar10": 10,
27 | "cifar100": 100,
28 | "stl10": 10,
29 | "imagenet": 1000,
30 | "imagenet100": 100,
31 | }
32 |
33 |
34 | def additional_setup_pretrain(args: Namespace):
35 |     """Completes the pretraining setup by filling in non-user-provided parameters on args.
36 | 
37 |     Parses the arguments to extract the number of classes of the dataset, create
38 |     the transformation kwargs, correctly parse the gpus, identify whether a cifar
39 |     dataset is being used, and adjust the lr.
40 |
41 | Args:
42 | args (Namespace): object that needs to contain, at least:
43 | - dataset: dataset name.
44 | - brightness, contrast, saturation, hue, min_scale: required augmentations
45 | settings.
46 | - dali: flag to use dali.
47 | - optimizer: optimizer name being used.
48 | - gpus: list of gpus to use.
49 | - lr: learning rate.
50 |
51 | [optional]
52 | - gaussian_prob, solarization_prob: optional augmentations settings.
53 | """
54 |
55 | if args.dataset in N_CLASSES_PER_DATASET:
56 | args.num_classes = N_CLASSES_PER_DATASET[args.dataset]
57 | else:
58 | # hack to maintain the current pipeline
59 | # even if the custom dataset doesn't have any labels
60 | dir_path = args.data_dir / args.train_dir
61 | args.num_classes = max(
62 | 1,
63 |             len([entry.name for entry in os.scandir(dir_path) if entry.is_dir()]),
64 | )
65 |
66 | unique_augs = max(
67 | len(p)
68 | for p in [
69 | args.brightness,
70 | args.contrast,
71 | args.saturation,
72 | args.hue,
73 | args.color_jitter_prob,
74 | args.gray_scale_prob,
75 | args.horizontal_flip_prob,
76 | args.gaussian_prob,
77 | args.solarization_prob,
78 | args.crop_size,
79 | args.min_scale,
80 | args.max_scale,
81 | ]
82 | )
83 | assert len(args.num_crops_per_aug) == unique_augs
84 |
85 | # assert that either all unique augmentation pipelines have a unique
86 | # parameter or that a single parameter is replicated to all pipelines
87 | for p in [
88 | "brightness",
89 | "contrast",
90 | "saturation",
91 | "hue",
92 | "color_jitter_prob",
93 | "gray_scale_prob",
94 | "horizontal_flip_prob",
95 | "gaussian_prob",
96 | "solarization_prob",
97 | "crop_size",
98 | "min_scale",
99 | "max_scale",
100 | ]:
101 | values = getattr(args, p)
102 | n = len(values)
103 | assert n == unique_augs or n == 1
104 |
105 | if n == 1:
106 | setattr(args, p, getattr(args, p) * unique_augs)
107 |
108 | args.unique_augs = unique_augs
109 |
110 | if unique_augs > 1:
111 | args.transform_kwargs = [
112 | dict(
113 | brightness=brightness,
114 | contrast=contrast,
115 | saturation=saturation,
116 | hue=hue,
117 | color_jitter_prob=color_jitter_prob,
118 | gray_scale_prob=gray_scale_prob,
119 | horizontal_flip_prob=horizontal_flip_prob,
120 | gaussian_prob=gaussian_prob,
121 | solarization_prob=solarization_prob,
122 | crop_size=crop_size,
123 | min_scale=min_scale,
124 | max_scale=max_scale,
125 | )
126 | for (
127 | brightness,
128 | contrast,
129 | saturation,
130 | hue,
131 | color_jitter_prob,
132 | gray_scale_prob,
133 | horizontal_flip_prob,
134 | gaussian_prob,
135 | solarization_prob,
136 | crop_size,
137 | min_scale,
138 | max_scale,
139 | ) in zip(
140 | args.brightness,
141 | args.contrast,
142 | args.saturation,
143 | args.hue,
144 | args.color_jitter_prob,
145 | args.gray_scale_prob,
146 | args.horizontal_flip_prob,
147 | args.gaussian_prob,
148 | args.solarization_prob,
149 | args.crop_size,
150 | args.min_scale,
151 | args.max_scale,
152 | )
153 | ]
154 |
155 | # find number of big/small crops
156 | big_size = args.crop_size[0]
157 | num_large_crops = num_small_crops = 0
158 | for size, n_crops in zip(args.crop_size, args.num_crops_per_aug):
159 | if big_size == size:
160 | num_large_crops += n_crops
161 | else:
162 | num_small_crops += n_crops
163 | args.num_large_crops = num_large_crops
164 | args.num_small_crops = num_small_crops
165 | else:
166 | args.transform_kwargs = dict(
167 | brightness=args.brightness[0],
168 | contrast=args.contrast[0],
169 | saturation=args.saturation[0],
170 | hue=args.hue[0],
171 | color_jitter_prob=args.color_jitter_prob[0],
172 | gray_scale_prob=args.gray_scale_prob[0],
173 | horizontal_flip_prob=args.horizontal_flip_prob[0],
174 | gaussian_prob=args.gaussian_prob[0],
175 | solarization_prob=args.solarization_prob[0],
176 | crop_size=args.crop_size[0],
177 | min_scale=args.min_scale[0],
178 | max_scale=args.max_scale[0],
179 | )
180 |
181 | # find number of big/small crops
182 | args.num_large_crops = args.num_crops_per_aug[0]
183 | args.num_small_crops = 0
184 |
185 | # add support for custom mean and std
186 | if args.dataset == "custom":
187 | if isinstance(args.transform_kwargs, dict):
188 | args.transform_kwargs["mean"] = args.mean
189 | args.transform_kwargs["std"] = args.std
190 | else:
191 | for kwargs in args.transform_kwargs:
192 | kwargs["mean"] = args.mean
193 | kwargs["std"] = args.std
194 |
195 | # create backbone-specific arguments
196 | args.backbone_args = {"cifar": args.dataset in ["cifar10", "cifar100"]}
197 | if "resnet" in args.backbone:
198 | args.backbone_args["zero_init_residual"] = args.zero_init_residual
199 | else:
200 | # dataset related for all transformers
201 | crop_size = args.crop_size[0]
202 | args.backbone_args["img_size"] = crop_size
203 | if "vit" in args.backbone:
204 | args.backbone_args["patch_size"] = args.patch_size
205 |
206 | with suppress(AttributeError):
207 | del args.zero_init_residual
208 | with suppress(AttributeError):
209 | del args.patch_size
210 |
211 | if args.dali:
212 | assert args.dataset in ["imagenet100", "imagenet", "custom"]
213 |
214 | args.extra_optimizer_args = {}
215 | if args.optimizer == "sgd":
216 | args.extra_optimizer_args["momentum"] = 0.9
217 |
218 | if isinstance(args.gpus, int):
219 | args.gpus = [args.gpus]
220 | elif isinstance(args.gpus, str):
221 | args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu]
222 |
223 | # adjust lr according to batch size
224 | args.lr = args.lr * args.batch_size * len(args.gpus) / 256
225 |
226 |
227 | def additional_setup_linear(args: Namespace):
228 |     """Completes the linear-evaluation setup by filling in non-user-provided parameters on args.
229 | 
230 |     Parses the arguments to extract the number of classes of the dataset, correctly parse the
231 |     gpus, identify whether a cifar dataset is being used, and adjust the lr.
232 |
233 | Args:
234 | args: Namespace object that needs to contain, at least:
235 | - dataset: dataset name.
236 | - optimizer: optimizer name being used.
237 | - gpus: list of gpus to use.
238 | - lr: learning rate.
239 | """
240 |
241 | if args.dataset in N_CLASSES_PER_DATASET:
242 | args.num_classes = N_CLASSES_PER_DATASET[args.dataset]
243 | else:
244 | # hack to maintain the current pipeline
245 | # even if the custom dataset doesn't have any labels
246 | dir_path = args.data_dir / args.train_dir
247 | args.num_classes = max(
248 | 1,
249 |             len([entry.name for entry in os.scandir(dir_path) if entry.is_dir()]),
250 | )
251 |
252 | # create backbone-specific arguments
253 | args.backbone_args = {"cifar": args.dataset in ["cifar10", "cifar100"]}
254 |
255 | if "resnet" not in args.backbone:
256 | # dataset related for all transformers
257 | crop_size = args.crop_size[0]
258 | args.backbone_args["img_size"] = crop_size
259 |
260 | if "vit" in args.backbone:
261 | args.backbone_args["patch_size"] = args.patch_size
262 |
263 | with suppress(AttributeError):
264 | del args.patch_size
265 |
266 | if args.dali:
267 | assert args.dataset in ["imagenet100", "imagenet", "custom"]
268 |
269 | args.extra_optimizer_args = {}
270 | if args.optimizer == "sgd":
271 | args.extra_optimizer_args["momentum"] = 0.9
272 |
273 | if isinstance(args.gpus, int):
274 | args.gpus = [args.gpus]
275 | elif isinstance(args.gpus, str):
276 | args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu]
277 |
--------------------------------------------------------------------------------
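The last line of additional_setup_pretrain applies the linear lr scaling rule: the user-specified lr is treated as a base rate for a global batch size of 256. A worked example of the arithmetic:

# base lr 0.5, per-gpu batch size 256, 2 gpus -> global batch size 512
base_lr, batch_size, num_gpus = 0.5, 256, 2
lr = base_lr * batch_size * num_gpus / 256
print(lr)  # 1.0: doubling the global batch size doubles the effective lr

--------------------------------------------------------------------------------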
/solo/utils/classification_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from pathlib import Path
22 | from typing import Callable, Optional, Tuple, Union
23 |
24 | import torchvision
25 | from torch import nn
26 | from torch.utils.data import DataLoader, Dataset
27 | from torchvision import transforms
28 | from torchvision.datasets import STL10, ImageFolder
29 |
30 |
31 | def build_custom_pipeline():
32 | """Builds augmentation pipelines for custom data.
33 |     If you want to do exotic augmentations, you can simply rewrite this function.
34 | Needs to return a dict with the same structure.
35 | """
36 |
37 | pipeline = {
38 | "T_train": transforms.Compose(
39 | [
40 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
44 | ]
45 | ),
46 | "T_val": transforms.Compose(
47 | [
48 | transforms.Resize(256), # resize shorter
49 | transforms.CenterCrop(224), # take center crop
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
52 | ]
53 | ),
54 | }
55 | return pipeline
56 |
57 |
58 | def prepare_transforms(dataset: str) -> Tuple[nn.Module, nn.Module]:
59 | """Prepares pre-defined train and test transformation pipelines for some datasets.
60 |
61 | Args:
62 | dataset (str): dataset name.
63 |
64 | Returns:
65 | Tuple[nn.Module, nn.Module]: training and validation transformation pipelines.
66 | """
67 |
68 | cifar_pipeline = {
69 | "T_train": transforms.Compose(
70 | [
71 | transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0)),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
75 | ]
76 | ),
77 | "T_val": transforms.Compose(
78 | [
79 | transforms.ToTensor(),
80 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
81 | ]
82 | ),
83 | }
84 |
85 | stl_pipeline = {
86 | "T_train": transforms.Compose(
87 | [
88 | transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0)),
89 | transforms.RandomHorizontalFlip(),
90 | transforms.ToTensor(),
91 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
92 | ]
93 | ),
94 | "T_val": transforms.Compose(
95 | [
96 | transforms.Resize((96, 96)),
97 | transforms.ToTensor(),
98 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
99 | ]
100 | ),
101 | }
102 |
103 | imagenet_pipeline = {
104 | "T_train": transforms.Compose(
105 | [
106 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
107 | transforms.RandomHorizontalFlip(),
108 | transforms.ToTensor(),
109 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
110 | ]
111 | ),
112 | "T_val": transforms.Compose(
113 | [
114 | transforms.Resize(256), # resize shorter
115 | transforms.CenterCrop(224), # take center crop
116 | transforms.ToTensor(),
117 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
118 | ]
119 | ),
120 | }
121 |
122 | custom_pipeline = build_custom_pipeline()
123 |
124 | pipelines = {
125 | "cifar10": cifar_pipeline,
126 | "cifar100": cifar_pipeline,
127 | "stl10": stl_pipeline,
128 | "imagenet100": imagenet_pipeline,
129 | "imagenet": imagenet_pipeline,
130 | "custom": custom_pipeline,
131 | }
132 |
133 | assert dataset in pipelines
134 |
135 | pipeline = pipelines[dataset]
136 | T_train = pipeline["T_train"]
137 | T_val = pipeline["T_val"]
138 |
139 | return T_train, T_val
140 |
141 |
142 | def prepare_datasets(
143 | dataset: str,
144 | T_train: Callable,
145 | T_val: Callable,
146 | data_dir: Optional[Union[str, Path]] = None,
147 | train_dir: Optional[Union[str, Path]] = None,
148 | val_dir: Optional[Union[str, Path]] = None,
149 | download: bool = True,
150 | ) -> Tuple[Dataset, Dataset]:
151 | """Prepares train and val datasets.
152 |
153 | Args:
154 | dataset (str): dataset name.
155 | T_train (Callable): pipeline of transformations for training dataset.
156 | T_val (Callable): pipeline of transformations for validation dataset.
157 |         data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
158 |         train_dir (Optional[Union[str, Path]], optional): subpath where the training data is located.
159 |         val_dir (Optional[Union[str, Path]], optional): subpath where the validation data is located.
160 |
161 | Returns:
162 | Tuple[Dataset, Dataset]: training dataset and validation dataset.
163 | """
164 |
165 | if data_dir is None:
166 | sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
167 | data_dir = sandbox_dir / "datasets"
168 | else:
169 | data_dir = Path(data_dir)
170 |
171 | if train_dir is None:
172 | train_dir = Path(f"{dataset}/train")
173 | else:
174 | train_dir = Path(train_dir)
175 |
176 | if val_dir is None:
177 | val_dir = Path(f"{dataset}/val")
178 | else:
179 | val_dir = Path(val_dir)
180 |
181 | assert dataset in ["cifar10", "cifar100", "stl10", "imagenet", "imagenet100", "custom"]
182 |
183 | if dataset in ["cifar10", "cifar100"]:
184 | DatasetClass = vars(torchvision.datasets)[dataset.upper()]
185 | train_dataset = DatasetClass(
186 | data_dir / train_dir,
187 | train=True,
188 | download=download,
189 | transform=T_train,
190 | )
191 |
192 | val_dataset = DatasetClass(
193 | data_dir / val_dir,
194 | train=False,
195 | download=download,
196 | transform=T_val,
197 | )
198 |
199 | elif dataset == "stl10":
200 | train_dataset = STL10(
201 | data_dir / train_dir,
202 | split="train",
203 |             download=download,
204 | transform=T_train,
205 | )
206 | val_dataset = STL10(
207 | data_dir / val_dir,
208 | split="test",
209 | download=download,
210 | transform=T_val,
211 | )
212 |
213 | elif dataset in ["imagenet", "imagenet100", "custom"]:
214 | train_dir = data_dir / train_dir
215 | val_dir = data_dir / val_dir
216 |
217 | train_dataset = ImageFolder(train_dir, T_train)
218 | val_dataset = ImageFolder(val_dir, T_val)
219 |
220 | return train_dataset, val_dataset
221 |
222 |
223 | def prepare_dataloaders(
224 | train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
225 | ) -> Tuple[DataLoader, DataLoader]:
226 | """Wraps a train and a validation dataset with a DataLoader.
227 |
228 | Args:
229 | train_dataset (Dataset): object containing training data.
230 | val_dataset (Dataset): object containing validation data.
231 | batch_size (int): batch size.
232 | num_workers (int): number of parallel workers.
233 | Returns:
234 | Tuple[DataLoader, DataLoader]: training dataloader and validation dataloader.
235 | """
236 |
237 | train_loader = DataLoader(
238 | train_dataset,
239 | batch_size=batch_size,
240 | shuffle=True,
241 | num_workers=num_workers,
242 | pin_memory=True,
243 | drop_last=True,
244 | )
245 | val_loader = DataLoader(
246 | val_dataset,
247 | batch_size=batch_size,
248 | num_workers=num_workers,
249 | pin_memory=True,
250 | drop_last=False,
251 | )
252 | return train_loader, val_loader
253 |
254 |
255 | def prepare_data(
256 | dataset: str,
257 | data_dir: Optional[Union[str, Path]] = None,
258 | train_dir: Optional[Union[str, Path]] = None,
259 | val_dir: Optional[Union[str, Path]] = None,
260 | batch_size: int = 64,
261 | num_workers: int = 4,
262 | download: bool = True,
263 | ) -> Tuple[DataLoader, DataLoader]:
264 | """Prepares transformations, creates dataset objects and wraps them in dataloaders.
265 |
266 | Args:
267 | dataset (str): dataset name.
268 | data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
269 | Defaults to None.
270 | train_dir (Optional[Union[str, Path]], optional): subpath where the
271 | training data is located. Defaults to None.
272 | val_dir (Optional[Union[str, Path]], optional): subpath where the
273 | validation data is located. Defaults to None.
274 | batch_size (int, optional): batch size. Defaults to 64.
275 | num_workers (int, optional): number of parallel workers. Defaults to 4.
276 |
277 | Returns:
278 |         Tuple[DataLoader, DataLoader]: prepared training and validation dataloaders.
279 | """
280 |
281 | T_train, T_val = prepare_transforms(dataset)
282 | train_dataset, val_dataset = prepare_datasets(
283 | dataset,
284 | T_train,
285 | T_val,
286 | data_dir=data_dir,
287 | train_dir=train_dir,
288 | val_dir=val_dir,
289 | download=download,
290 | )
291 | train_loader, val_loader = prepare_dataloaders(
292 | train_dataset,
293 | val_dataset,
294 | batch_size=batch_size,
295 | num_workers=num_workers,
296 | )
297 | return train_loader, val_loader
298 |
--------------------------------------------------------------------------------
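A usage sketch for prepare_data above (the data_dir path is illustrative):

from solo.utils.classification_dataloader import prepare_data

train_loader, val_loader = prepare_data(
    "cifar10",
    data_dir="./data",
    batch_size=64,
    num_workers=4,
    download=True,
)
x, y = next(iter(train_loader))  # x: (64, 3, 32, 32) normalized images, y: labels

--------------------------------------------------------------------------------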
/solo/utils/classification_dataloader_AdvTraining.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from pathlib import Path
22 | from typing import Callable, Optional, Tuple, Union
23 |
24 | import torchvision
25 | from torch import nn
26 | from torch.utils.data import DataLoader, Dataset
27 | from torchvision import transforms
28 | from torchvision.datasets import STL10, ImageFolder
29 |
30 |
31 | def build_custom_pipeline():
32 | """Builds augmentation pipelines for custom data.
33 |     If you want to do exotic augmentations, you can simply rewrite this function.
34 | Needs to return a dict with the same structure.
35 | """
36 |
37 | pipeline = {
38 | "T_train": transforms.Compose(
39 | [
40 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
44 | ]
45 | ),
46 | "T_val": transforms.Compose(
47 | [
48 | transforms.Resize(256), # resize shorter
49 | transforms.CenterCrop(224), # take center crop
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
52 | ]
53 | ),
54 | }
55 | return pipeline
56 |
57 |
58 | def prepare_transforms(dataset: str) -> Tuple[nn.Module, nn.Module]:
59 | """Prepares pre-defined train and test transformation pipelines for some datasets.
60 |
61 | Args:
62 | dataset (str): dataset name.
63 |
64 | Returns:
65 | Tuple[nn.Module, nn.Module]: training and validation transformation pipelines.
66 | """
67 |
68 | cifar_pipeline = {
69 | "T_train": transforms.Compose(
70 | [
71 | # transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0)),
72 | transforms.RandomCrop(32, padding=4),
73 | transforms.RandomHorizontalFlip(),
74 | transforms.ToTensor(),
75 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
76 | ]
77 | ),
78 | "T_val": transforms.Compose(
79 | [
80 | transforms.ToTensor(),
81 | # transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
82 | ]
83 | ),
84 | }
85 |
86 | stl_pipeline = {
87 | "T_train": transforms.Compose(
88 | [
89 | transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0)),
90 | transforms.RandomHorizontalFlip(),
91 | transforms.ToTensor(),
92 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
93 | ]
94 | ),
95 | "T_val": transforms.Compose(
96 | [
97 | transforms.Resize((96, 96)),
98 | transforms.ToTensor(),
99 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
100 | ]
101 | ),
102 | }
103 |
104 | imagenet_pipeline = {
105 | "T_train": transforms.Compose(
106 | [
107 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
108 | transforms.RandomHorizontalFlip(),
109 | transforms.ToTensor(),
110 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
111 | ]
112 | ),
113 | "T_val": transforms.Compose(
114 | [
115 | transforms.Resize(256), # resize shorter
116 | transforms.CenterCrop(224), # take center crop
117 | transforms.ToTensor(),
118 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
119 | ]
120 | ),
121 | }
122 |
123 | custom_pipeline = build_custom_pipeline()
124 |
125 | pipelines = {
126 | "cifar10": cifar_pipeline,
127 | "cifar100": cifar_pipeline,
128 | "stl10": stl_pipeline,
129 | "imagenet100": imagenet_pipeline,
130 | "imagenet": imagenet_pipeline,
131 | "custom": custom_pipeline,
132 | }
133 |
134 | assert dataset in pipelines
135 |
136 | pipeline = pipelines[dataset]
137 | T_train = pipeline["T_train"]
138 | T_val = pipeline["T_val"]
139 |
140 | return T_train, T_val
141 |
142 |
143 | def prepare_datasets(
144 | dataset: str,
145 | T_train: Callable,
146 | T_val: Callable,
147 | data_dir: Optional[Union[str, Path]] = None,
148 | train_dir: Optional[Union[str, Path]] = None,
149 | val_dir: Optional[Union[str, Path]] = None,
150 | download: bool = True,
151 | ) -> Tuple[Dataset, Dataset]:
152 | """Prepares train and val datasets.
153 |
154 | Args:
155 | dataset (str): dataset name.
156 | T_train (Callable): pipeline of transformations for training dataset.
157 | T_val (Callable): pipeline of transformations for validation dataset.
158 |         data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
159 |         train_dir (Optional[Union[str, Path]], optional): subpath where the training data is located.
160 |         val_dir (Optional[Union[str, Path]], optional): subpath where the validation data is located.
161 |
162 | Returns:
163 | Tuple[Dataset, Dataset]: training dataset and validation dataset.
164 | """
165 |
166 | if data_dir is None:
167 | sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
168 | data_dir = sandbox_dir / "datasets"
169 | else:
170 | data_dir = Path(data_dir)
171 |
172 | if train_dir is None:
173 | train_dir = Path(f"{dataset}/train")
174 | else:
175 | train_dir = Path(train_dir)
176 |
177 | if val_dir is None:
178 | val_dir = Path(f"{dataset}/val")
179 | else:
180 | val_dir = Path(val_dir)
181 |
182 | assert dataset in ["cifar10", "cifar100", "stl10", "imagenet", "imagenet100", "custom"]
183 |
184 | if dataset in ["cifar10", "cifar100"]:
185 | DatasetClass = vars(torchvision.datasets)[dataset.upper()]
186 | train_dataset = DatasetClass(
187 | data_dir / train_dir,
188 | train=True,
189 | download=download,
190 | transform=T_train,
191 | )
192 |
193 | val_dataset = DatasetClass(
194 | data_dir / val_dir,
195 | train=False,
196 | download=download,
197 | transform=T_val,
198 | )
199 |
200 | elif dataset == "stl10":
201 | train_dataset = STL10(
202 | data_dir / train_dir,
203 | split="train",
204 |             download=download,
205 | transform=T_train,
206 | )
207 | val_dataset = STL10(
208 | data_dir / val_dir,
209 | split="test",
210 | download=download,
211 | transform=T_val,
212 | )
213 |
214 | elif dataset in ["imagenet", "imagenet100", "custom"]:
215 | train_dir = data_dir / train_dir
216 | val_dir = data_dir / val_dir
217 |
218 | train_dataset = ImageFolder(train_dir, T_train)
219 | val_dataset = ImageFolder(val_dir, T_val)
220 |
221 | return train_dataset, val_dataset
222 |
223 |
224 | def prepare_dataloaders(
225 | train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
226 | ) -> Tuple[DataLoader, DataLoader]:
227 | """Wraps a train and a validation dataset with a DataLoader.
228 |
229 | Args:
230 | train_dataset (Dataset): object containing training data.
231 | val_dataset (Dataset): object containing validation data.
232 | batch_size (int): batch size.
233 | num_workers (int): number of parallel workers.
234 | Returns:
235 | Tuple[DataLoader, DataLoader]: training dataloader and validation dataloader.
236 | """
237 |
238 | train_loader = DataLoader(
239 | train_dataset,
240 | batch_size=batch_size,
241 | shuffle=True,
242 | num_workers=num_workers,
243 | pin_memory=True,
244 | drop_last=True,
245 | )
246 | val_loader = DataLoader(
247 | val_dataset,
248 | batch_size=batch_size,
249 | num_workers=num_workers,
250 | pin_memory=True,
251 | drop_last=False,
252 | )
253 | return train_loader, val_loader
254 |
255 |
256 | def prepare_data(
257 | dataset: str,
258 | data_dir: Optional[Union[str, Path]] = None,
259 | train_dir: Optional[Union[str, Path]] = None,
260 | val_dir: Optional[Union[str, Path]] = None,
261 | batch_size: int = 64,
262 | num_workers: int = 4,
263 | download: bool = True,
264 | ) -> Tuple[DataLoader, DataLoader]:
265 | """Prepares transformations, creates dataset objects and wraps them in dataloaders.
266 |
267 | Args:
268 | dataset (str): dataset name.
269 | data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
270 | Defaults to None.
271 | train_dir (Optional[Union[str, Path]], optional): subpath where the
272 | training data is located. Defaults to None.
273 | val_dir (Optional[Union[str, Path]], optional): subpath where the
274 | validation data is located. Defaults to None.
275 | batch_size (int, optional): batch size. Defaults to 64.
276 |         num_workers (int, optional): number of parallel workers. Defaults to 4.
277 |         download (bool, optional): whether to download the dataset if needed. Defaults to True.
278 | Returns:
280 |         Tuple[DataLoader, DataLoader]: prepared training and validation dataloaders.
280 | """
281 |
282 | T_train, T_val = prepare_transforms(dataset)
283 | train_dataset, val_dataset = prepare_datasets(
284 | dataset,
285 | T_train,
286 | T_val,
287 | data_dir=data_dir,
288 | train_dir=train_dir,
289 | val_dir=val_dir,
290 | download=download,
291 | )
292 | train_loader, val_loader = prepare_dataloaders(
293 | train_dataset,
294 | val_dataset,
295 | batch_size=batch_size,
296 | num_workers=num_workers,
297 | )
298 | return train_loader, val_loader
299 |
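300 | 
301 | # Example usage (a minimal sketch, not one of the original entry points): builds
302 | # the CIFAR-10 dataloaders with the transforms defined above; with data_dir left
303 | # as None the data goes to the repository-level "datasets" directory.
304 | if __name__ == "__main__":
305 |     train_loader, val_loader = prepare_data(
306 |         "cifar10",
307 |         batch_size=64,
308 |         num_workers=4,
309 |         download=True,
310 |     )
311 |     print(f"train batches: {len(train_loader)}, val batches: {len(val_loader)}")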
--------------------------------------------------------------------------------
/solo/utils/auto_umap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 | import os
22 | import random
23 | import string
24 | import time
25 | from argparse import ArgumentParser, Namespace
26 | from pathlib import Path
27 | from typing import Optional, Union
28 |
29 | import pandas as pd
30 | import pytorch_lightning as pl
31 | import seaborn as sns
32 | import torch
33 | import torch.nn as nn
34 | import umap
35 | import wandb
36 | from matplotlib import pyplot as plt
37 | from pytorch_lightning.callbacks import Callback
38 | from tqdm import tqdm
39 |
40 | from .misc import gather
41 |
42 |
43 | def random_string(letter_count=4, digit_count=4):
44 | tmp_random = random.Random(time.time())
45 | rand_str = "".join((tmp_random.choice(string.ascii_lowercase) for x in range(letter_count)))
46 | rand_str += "".join((tmp_random.choice(string.digits) for x in range(digit_count)))
47 | rand_str = list(rand_str)
48 | tmp_random.shuffle(rand_str)
49 | return "".join(rand_str)
50 |
51 |
52 | class AutoUMAP(Callback):
53 | def __init__(
54 | self,
55 | args: Namespace,
56 | logdir: Union[str, Path] = Path("auto_umap"),
57 | frequency: int = 1,
58 | keep_previous: bool = False,
59 | color_palette: str = "hls",
60 | ):
61 | """UMAP callback that automatically runs UMAP on the validation dataset and uploads the
62 | figure to wandb.
63 |
64 | Args:
65 | args (Namespace): namespace object containing at least an attribute name.
66 | logdir (Union[str, Path], optional): base directory to store checkpoints.
67 | Defaults to Path("auto_umap").
68 | frequency (int, optional): number of epochs between each UMAP. Defaults to 1.
69 | color_palette (str, optional): color scheme for the classes. Defaults to "hls".
70 | keep_previous (bool, optional): whether to keep previous plots or not.
71 | Defaults to False.
72 | """
73 |
74 | super().__init__()
75 |
76 | self.args = args
77 | self.logdir = Path(logdir)
78 | self.frequency = frequency
79 | self.color_palette = color_palette
80 | self.keep_previous = keep_previous
81 |
82 | @staticmethod
83 | def add_auto_umap_args(parent_parser: ArgumentParser):
84 | """Adds user-required arguments to a parser.
85 |
86 | Args:
87 | parent_parser (ArgumentParser): parser to add new args to.
88 | """
89 |
90 | parser = parent_parser.add_argument_group("auto_umap")
91 | parser.add_argument("--auto_umap_dir", default=Path("auto_umap"), type=Path)
92 | parser.add_argument("--auto_umap_frequency", default=1, type=int)
93 | return parent_parser
94 |
95 | def initial_setup(self, trainer: pl.Trainer):
96 | """Creates the directories and does the initial setup needed.
97 |
98 | Args:
99 | trainer (pl.Trainer): pytorch lightning trainer object.
100 | """
101 |
102 | if trainer.logger is None:
103 | if self.logdir.exists():
104 | existing_versions = set(os.listdir(self.logdir))
105 | else:
106 |                 existing_versions = set()
107 | version = "offline-" + random_string()
108 | while version in existing_versions:
109 | version = "offline-" + random_string()
110 | else:
111 | version = str(trainer.logger.version)
112 | if version is not None:
113 | self.path = self.logdir / version
114 | self.umap_placeholder = f"{self.args.name}-{version}" + "-ep={}.pdf"
115 | else:
116 | self.path = self.logdir
117 | self.umap_placeholder = f"{self.args.name}" + "-ep={}.pdf"
118 | self.last_ckpt: Optional[str] = None
119 |
120 | # create logging dirs
121 | if trainer.is_global_zero:
122 | os.makedirs(self.path, exist_ok=True)
123 |
124 | def on_train_start(self, trainer: pl.Trainer, _):
125 | """Performs initial setup on training start.
126 |
127 | Args:
128 | trainer (pl.Trainer): pytorch lightning trainer object.
129 | """
130 |
131 | self.initial_setup(trainer)
132 |
133 | def plot(self, trainer: pl.Trainer, module: pl.LightningModule):
134 | """Produces a UMAP visualization by forwarding all data of the
135 | first validation dataloader through the module.
136 |
137 | Args:
138 | trainer (pl.Trainer): pytorch lightning trainer object.
139 | module (pl.LightningModule): current module object.
140 | """
141 |
142 | device = module.device
143 | data = []
144 | Y = []
145 |
146 |         # set module to eval mode and collect all feature representations
147 | module.eval()
148 | with torch.no_grad():
149 | for x, y in trainer.val_dataloaders[0]:
150 | x = x.to(device, non_blocking=True)
151 | y = y.to(device, non_blocking=True)
152 |
153 | feats = module(x)["feats"]
154 |
155 | feats = gather(feats)
156 | y = gather(y)
157 |
158 | data.append(feats.cpu())
159 | Y.append(y.cpu())
160 | module.train()
161 |
162 | if trainer.is_global_zero and len(data):
163 | data = torch.cat(data, dim=0).numpy()
164 | Y = torch.cat(Y, dim=0)
165 | num_classes = len(torch.unique(Y))
166 | Y = Y.numpy()
167 |
168 | data = umap.UMAP(n_components=2).fit_transform(data)
169 |
170 | # passing to dataframe
171 | df = pd.DataFrame()
172 | df["feat_1"] = data[:, 0]
173 | df["feat_2"] = data[:, 1]
174 | df["Y"] = Y
175 | plt.figure(figsize=(9, 9))
176 | ax = sns.scatterplot(
177 | x="feat_1",
178 | y="feat_2",
179 | hue="Y",
180 | palette=sns.color_palette(self.color_palette, num_classes),
181 | data=df,
182 | legend="full",
183 | alpha=0.3,
184 | )
185 | ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[])
186 | ax.tick_params(left=False, right=False, bottom=False, top=False)
187 |
188 | # manually improve quality of imagenet umaps
189 | if num_classes > 100:
190 | anchor = (0.5, 1.8)
191 | else:
192 | anchor = (0.5, 1.35)
193 |
194 | plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10))
195 | plt.tight_layout()
196 |
197 | if isinstance(trainer.logger, pl.loggers.WandbLogger):
198 | wandb.log(
199 | {"validation_umap": wandb.Image(ax)},
200 | commit=False,
201 | )
202 |
203 | # save plot locally as well
204 | epoch = trainer.current_epoch # type: ignore
205 | plt.savefig(self.path / self.umap_placeholder.format(epoch))
206 | plt.close()
207 |
208 | def on_validation_end(self, trainer: pl.Trainer, module: pl.LightningModule):
209 | """Tries to generate an up-to-date UMAP visualization of the features
210 | at the end of each validation epoch.
211 |
212 | Args:
213 | trainer (pl.Trainer): pytorch lightning trainer object.
214 | """
215 |
216 | epoch = trainer.current_epoch # type: ignore
217 | if epoch % self.frequency == 0 and not trainer.sanity_checking:
218 | self.plot(trainer, module)
219 |
220 |
221 | class OfflineUMAP:
222 | def __init__(self, color_palette: str = "hls"):
223 | """Offline UMAP helper.
224 |
225 | Args:
226 | color_palette (str, optional): color scheme for the classes. Defaults to "hls".
227 | """
228 |
229 | self.color_palette = color_palette
230 |
231 | def plot(
232 | self,
233 | device: str,
234 | model: nn.Module,
235 | dataloader: torch.utils.data.DataLoader,
236 | plot_path: str,
237 | ):
238 | """Produces a UMAP visualization by forwarding all data of the
239 | first validation dataloader through the model.
240 | **Note: the model should produce features for the forward() function.
241 |
242 | Args:
243 | device (str): gpu/cpu device.
244 | model (nn.Module): current model.
245 |             dataloader (torch.utils.data.DataLoader): current dataloader containing data.
246 | plot_path (str): path to save the figure.
247 | """
248 |
249 | data = []
250 | Y = []
251 |
252 |         # set model to eval mode and collect all feature representations
253 | model.eval()
254 | with torch.no_grad():
255 | for x, y in tqdm(dataloader, desc="Collecting features"):
256 | x = x.to(device, non_blocking=True)
257 | y = y.to(device, non_blocking=True)
258 |
259 | feats = model(x)
260 | data.append(feats.cpu())
261 | Y.append(y.cpu())
262 | model.train()
263 |
264 | data = torch.cat(data, dim=0).numpy()
265 | Y = torch.cat(Y, dim=0)
266 | num_classes = len(torch.unique(Y))
267 | Y = Y.numpy()
268 |
269 | print("Creating UMAP")
270 | data = umap.UMAP(n_components=2).fit_transform(data)
271 |
272 | # passing to dataframe
273 | df = pd.DataFrame()
274 | df["feat_1"] = data[:, 0]
275 | df["feat_2"] = data[:, 1]
276 | df["Y"] = Y
277 | plt.figure(figsize=(9, 9))
278 | ax = sns.scatterplot(
279 | x="feat_1",
280 | y="feat_2",
281 | hue="Y",
282 | palette=sns.color_palette(self.color_palette, num_classes),
283 | data=df,
284 | legend="full",
285 | alpha=0.3,
286 | )
287 | ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[])
288 | ax.tick_params(left=False, right=False, bottom=False, top=False)
289 |
290 | # manually improve quality of imagenet umaps
291 | if num_classes > 100:
292 | anchor = (0.5, 1.8)
293 | else:
294 | anchor = (0.5, 1.35)
295 |
296 | plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10))
297 | plt.tight_layout()
298 |
299 | # save plot locally as well
300 | plt.savefig(plot_path)
301 | plt.close()
302 |
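303 | # Example usage of OfflineUMAP (a minimal sketch; the ResNet-18 backbone and the
304 | # CIFAR-10 loader below are illustrative stand-ins, not part of this repository):
305 | if __name__ == "__main__":
306 |     import torchvision
307 |     from torchvision import transforms as T
308 | 
309 |     # backbone whose forward() returns feature representations, as plot() expects
310 |     backbone = torchvision.models.resnet18()
311 |     backbone.fc = nn.Identity()
312 |     backbone = backbone.to("cuda")
313 | 
314 |     val_set = torchvision.datasets.CIFAR10(
315 |         "datasets", train=False, download=True, transform=T.ToTensor()
316 |     )
317 |     val_loader = torch.utils.data.DataLoader(val_set, batch_size=256, num_workers=4)
318 | 
319 |     OfflineUMAP().plot("cuda", backbone, val_loader, "umap_cifar10.pdf")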
--------------------------------------------------------------------------------
/solo/methods/dali.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 | from abc import ABC
22 | from pathlib import Path
23 | from typing import List, Optional
24 |
25 | import torch
26 | import torch.nn as nn
27 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
28 | from solo.utils.dali_dataloader import (
29 | CustomNormalPipeline,
30 | CustomTransform,
31 | ImagenetTransform,
32 | NormalPipeline,
33 | PretrainPipeline,
34 | )
35 |
36 |
37 | class BaseWrapper(DALIGenericIterator):
38 | """Temporary fix to handle LastBatchPolicy.DROP."""
39 |
40 | def __len__(self):
41 | size = (
42 | self._size_no_pad // self._shards_num
43 | if self._last_batch_policy == LastBatchPolicy.DROP
44 | else self.size
45 | )
46 | if self._reader_name:
47 | if self._last_batch_policy != LastBatchPolicy.DROP:
48 | return math.ceil(size / self.batch_size)
49 |
50 | return size // self.batch_size
51 | else:
52 | if self._last_batch_policy != LastBatchPolicy.DROP:
53 | return math.ceil(size / (self._num_gpus * self.batch_size))
54 |
55 | return size // (self._num_gpus * self.batch_size)
56 |
57 |
58 | class PretrainWrapper(BaseWrapper):
59 | def __init__(
60 | self,
61 | model_batch_size: int,
62 | model_rank: int,
63 | model_device: str,
64 |         conversion_map: Optional[List[int]] = None,
65 | *args,
66 | **kwargs,
67 | ):
68 | """Adds indices to a batch fetched from the parent.
69 |
70 | Args:
71 | model_batch_size (int): batch size.
72 | model_rank (int): rank of the current process.
73 | model_device (str): id of the current device.
74 |             conversion_map (List[int], optional): list of integers that map each index
75 | to a class label. If nothing is passed, no label mapping needs to be done.
76 | Defaults to None.
77 | """
78 |
79 | super().__init__(*args, **kwargs)
80 | self.model_batch_size = model_batch_size
81 | self.model_rank = model_rank
82 | self.model_device = model_device
83 | self.conversion_map = conversion_map
84 | if self.conversion_map is not None:
85 | self.conversion_map = torch.tensor(
86 | self.conversion_map, dtype=torch.float32, device=self.model_device
87 | ).reshape(-1, 1)
88 | self.conversion_map = nn.Embedding.from_pretrained(self.conversion_map)
89 |
90 | def __next__(self):
91 | batch = super().__next__()[0]
92 | # PyTorch Lightning does double buffering
93 |         # (https://github.com/PyTorchLightning/pytorch-lightning/issues/1316),
94 |         # and since DALI owns the tensors it returns, their contents get trashed,
95 |         # so the copy needs to be made before returning.
96 |
97 | if self.conversion_map is not None:
98 | *all_X, indexes = [batch[v] for v in self.output_map]
99 | targets = self.conversion_map(indexes).flatten().long().detach().clone()
100 | indexes = indexes.flatten().long().detach().clone()
101 | else:
102 | *all_X, targets = [batch[v] for v in self.output_map]
103 | targets = targets.squeeze(-1).long().detach().clone()
104 | # creates dummy indexes
105 | indexes = (
106 | (
107 | torch.arange(self.model_batch_size, device=self.model_device)
108 | + (self.model_rank * self.model_batch_size)
109 | )
110 | .detach()
111 | .clone()
112 | )
113 |
114 | all_X = [x.detach().clone() for x in all_X]
115 | return [indexes, all_X, targets]
116 |
117 |
118 | class Wrapper(BaseWrapper):
119 | def __next__(self):
120 | batch = super().__next__()
121 | x, target = batch[0]["x"], batch[0]["label"]
122 | target = target.squeeze(-1).long()
123 | # PyTorch Lightning does double buffering
124 |         # (https://github.com/PyTorchLightning/pytorch-lightning/issues/1316),
125 |         # and since DALI owns the tensors it returns, their contents get trashed,
126 |         # so the copy needs to be made before returning.
127 | x = x.detach().clone()
128 | target = target.detach().clone()
129 | return x, target
130 |
131 |
132 | class PretrainABC(ABC):
133 | """Abstract pretrain class that returns a train_dataloader using dali."""
134 |
135 | def train_dataloader(self) -> DALIGenericIterator:
136 | """Returns a train dataloader using dali. Supports multi-crop and asymmetric augmentations.
137 |
138 | Returns:
139 | DALIGenericIterator: a train dataloader in the form of a dali pipeline object wrapped
140 | with PretrainWrapper.
141 | """
142 |
143 | device_id = self.local_rank
144 | shard_id = self.global_rank
145 | num_shards = self.trainer.world_size
146 |
147 | # get data arguments from model
148 | dali_device = self.extra_args["dali_device"]
149 |
150 | # data augmentations
151 | unique_augs = self.extra_args["unique_augs"]
152 | transform_kwargs = self.extra_args["transform_kwargs"]
153 | num_crops_per_aug = self.extra_args["num_crops_per_aug"]
154 |
155 | num_workers = self.extra_args["num_workers"]
156 | data_dir = Path(self.extra_args["data_dir"])
157 | train_dir = Path(self.extra_args["train_dir"])
158 |
159 | # hack to encode image indexes into the labels
160 | self.encode_indexes_into_labels = self.extra_args["encode_indexes_into_labels"]
161 |
162 | # handle custom data by creating the needed pipeline
163 | dataset = self.extra_args["dataset"]
164 | if dataset in ["imagenet100", "imagenet"]:
165 | transform_pipeline = ImagenetTransform
166 | elif dataset == "custom":
167 | transform_pipeline = CustomTransform
168 | else:
169 |             raise ValueError(f"{dataset} is not supported, use one of [imagenet, imagenet100, custom]")
170 |
171 | if unique_augs > 1:
172 | transforms = [
173 | transform_pipeline(
174 | device=dali_device,
175 | **kwargs,
176 | )
177 | for kwargs in transform_kwargs
178 | ]
179 | else:
180 | transforms = [transform_pipeline(device=dali_device, **transform_kwargs)]
181 |
182 | train_pipeline = PretrainPipeline(
183 | data_dir / train_dir,
184 | batch_size=self.batch_size,
185 | transforms=transforms,
186 | num_crops_per_aug=num_crops_per_aug,
187 | device=dali_device,
188 | device_id=device_id,
189 | shard_id=shard_id,
190 | num_shards=num_shards,
191 | num_threads=num_workers,
192 | no_labels=self.extra_args["no_labels"],
193 | encode_indexes_into_labels=self.encode_indexes_into_labels,
194 | )
195 | output_map = (
196 | [f"large{i}" for i in range(self.num_large_crops)]
197 | + [f"small{i}" for i in range(self.num_small_crops)]
198 | + ["label"]
199 | )
200 |
201 | policy = LastBatchPolicy.DROP
202 | conversion_map = train_pipeline.conversion_map if self.encode_indexes_into_labels else None
203 | train_loader = PretrainWrapper(
204 | model_batch_size=self.batch_size,
205 | model_rank=device_id,
206 | model_device=self.device,
207 | conversion_map=conversion_map,
208 | pipelines=train_pipeline,
209 | output_map=output_map,
210 | reader_name="Reader",
211 | last_batch_policy=policy,
212 | auto_reset=True,
213 | )
214 |
215 | self.dali_epoch_size = train_pipeline.epoch_size("Reader")
216 |
217 | return train_loader
218 |
219 |
220 | class ClassificationABC(ABC):
221 | """Abstract classification class that returns a train_dataloader and val_dataloader using
222 | dali."""
223 |
224 | def train_dataloader(self) -> DALIGenericIterator:
225 | device_id = self.local_rank
226 | shard_id = self.global_rank
227 | num_shards = self.trainer.world_size
228 |
229 | num_workers = self.extra_args["num_workers"]
230 | dali_device = self.extra_args["dali_device"]
231 | data_dir = Path(self.extra_args["data_dir"])
232 | train_dir = Path(self.extra_args["train_dir"])
233 |
234 | # handle custom data by creating the needed pipeline
235 | dataset = self.extra_args["dataset"]
236 | if dataset in ["imagenet100", "imagenet"]:
237 | pipeline_class = NormalPipeline
238 | elif dataset == "custom":
239 | pipeline_class = CustomNormalPipeline
240 | else:
241 |             raise ValueError(f"{dataset} is not supported, use one of [imagenet, imagenet100, custom]")
242 |
243 | train_pipeline = pipeline_class(
244 | data_dir / train_dir,
245 | validation=False,
246 | batch_size=self.batch_size,
247 | device=dali_device,
248 | device_id=device_id,
249 | shard_id=shard_id,
250 | num_shards=num_shards,
251 | num_threads=num_workers,
252 | )
253 | train_loader = Wrapper(
254 | train_pipeline,
255 | output_map=["x", "label"],
256 | reader_name="Reader",
257 | last_batch_policy=LastBatchPolicy.DROP,
258 | auto_reset=True,
259 | )
260 | return train_loader
261 |
262 | def val_dataloader(self) -> DALIGenericIterator:
263 | device_id = self.local_rank
264 | shard_id = self.global_rank
265 | num_shards = self.trainer.world_size
266 |
267 | num_workers = self.extra_args["num_workers"]
268 | dali_device = self.extra_args["dali_device"]
269 | data_dir = Path(self.extra_args["data_dir"])
270 | val_dir = Path(self.extra_args["val_dir"])
271 |
272 | # handle custom data by creating the needed pipeline
273 | dataset = self.extra_args["dataset"]
274 | if dataset in ["imagenet100", "imagenet"]:
275 | pipeline_class = NormalPipeline
276 | elif dataset == "custom":
277 | pipeline_class = CustomNormalPipeline
278 | else:
279 |             raise ValueError(f"{dataset} is not supported, use one of [imagenet, imagenet100, custom]")
280 |
281 | val_pipeline = pipeline_class(
282 | data_dir / val_dir,
283 | validation=True,
284 | batch_size=self.batch_size,
285 | device=dali_device,
286 | device_id=device_id,
287 | shard_id=shard_id,
288 | num_shards=num_shards,
289 | num_threads=num_workers,
290 | )
291 |
292 | val_loader = Wrapper(
293 | val_pipeline,
294 | output_map=["x", "label"],
295 | reader_name="Reader",
296 | last_batch_policy=LastBatchPolicy.PARTIAL,
297 | auto_reset=True,
298 | )
299 | return val_loader
300 |
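301 | # Usage sketch (an assumption mirroring the solo-learn pattern, not code from this
302 | # repository): both ABCs are meant to be mixed into a LightningModule that already
303 | # defines `extra_args`, `batch_size`, `local_rank`, `global_rank`, `trainer` and
304 | # `device`. For example, with a placeholder base class:
305 | #
306 | #     class MethodWithDALI(PretrainABC, SomeBaseMethod):
307 | #         """Inherits everything from SomeBaseMethod, but train_dataloader
308 | #         now builds the DALI PretrainPipeline defined above."""
309 | #
310 | # Listing PretrainABC first in the bases makes its train_dataloader take precedence.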
--------------------------------------------------------------------------------
/solo/methods/mocov2_distillation_AT.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 | from typing import Any, Dict, List, Sequence, Tuple
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from solo.losses.moco import moco_loss_func
27 | # from solo.losses.simclr import simclr_loss_func
28 | # from solo.methods.base import BaseDistillationMethod
29 | from solo.methods.base_for_adversarial_training import BaseDistillationATMethod
30 | from solo.utils.momentum import initialize_momentum_params
31 | from solo.utils.misc import gather
32 |
33 | from torchvision import models
34 | from solo.utils.metrics import accuracy_at_k, weighted_mean
35 |
36 |
37 | class MoCoV2KDAT(BaseDistillationATMethod):
38 | queue: torch.Tensor
39 |
40 | def __init__(
41 | self,
42 | proj_output_dim: int,
43 | proj_hidden_dim: int,
44 | temperature: float,
45 | queue_size: int,
46 | teacher_logit_fix: bool,
47 | loss_type: str,
48 | projector_ablation: str,
49 | epsilon: int = 8,
50 | num_steps: int = 5,
51 | step_size: int = 2,
52 | trades_k: float = 3,
53 | aux_data: bool = False,
54 | augmentation_ablation: bool = False,
55 | expriment_code: str = "000",
56 | **kwargs
57 | ):
58 | """Implements MoCo V2+ (https://arxiv.org/abs/2011.10566).
59 |
60 | Args:
61 | proj_output_dim (int): number of dimensions of projected features.
62 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
63 | temperature (float): temperature for the softmax in the contrastive loss.
64 | queue_size (int): number of samples to keep in the queue.
65 | """
66 |
67 | super().__init__(**kwargs)
68 |
69 | self.temperature = temperature
70 | self.queue_size = queue_size
71 | self.loss_type = loss_type
72 | self.projector_ablation = projector_ablation
73 | self.trades_k = trades_k
74 | self.aux_data = aux_data
75 |
76 | self.augmentation_ablation = augmentation_ablation
77 | self.expriment_code = expriment_code
78 |
79 |
80 |         self.epsilon = epsilon / 255.0  # attack budget, given in pixel units (e.g. 8 -> 8/255)
81 |         self.num_steps = num_steps
82 |         self.step_size = step_size / 255.0  # PGD step size, also given in pixel units
83 |
84 |
85 | self.projector = nn.Identity()
86 | self.momentum_projector = nn.Identity()
87 |
88 |
89 | @staticmethod
90 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
91 | parent_parser = super(MoCoV2KDAT, MoCoV2KDAT).add_model_specific_args(parent_parser)
92 | parser = parent_parser.add_argument_group("mocov2_kd_at")
93 |
94 | # projector
95 | parser.add_argument("--proj_output_dim", type=int, default=128)
96 | parser.add_argument("--proj_hidden_dim", type=int, default=2048)
97 |
98 | # parameters
99 | parser.add_argument("--temperature", type=float, default=0.1)
100 |
101 | # queue settings
102 | parser.add_argument("--queue_size", default=65536, type=int)
103 |
104 |
105 | parser.add_argument("--teacher_logit_fix", action="store_true")
106 |
107 | # choose loss to train the model
108 | SUPPORT_LOSS = ["ce2gt_ae2ce",
109 | "ae2gt_ae2ce",
110 | "ce2gt_ae2ce_infonce",
111 | "pgd",
112 | "ae2gt_infonce",
113 |
114 | "ae2gt_kl",
115 | "ce2gt_ae2ce_kl",
116 | "cs.ce2gt._kl.ae2ce.",
117 | "kl.ce2gt._cs.ae2ce.",
118 | "cs.ce2gt._infonce.ae2ce.",
119 | "infonce.ce2gt._cs.ae2ce.",
120 | "infonce.ce2gt._kl.ae2ce.",
121 | "kl.ce2gt._infonce.ae2ce."
122 | ]
123 |
124 | parser.add_argument("--loss_type", default="ce2gt_ae2ce", type=str, choices=SUPPORT_LOSS)
125 |
126 |
127 | # projector exploration
128 |         PROJECTOR_ABLATION = ["remove", "same", "correspond", "only_student", "only_teacher"]
129 | parser.add_argument("--projector_ablation", default="remove", type=str, choices=PROJECTOR_ABLATION)
130 |
131 |
132 | # training adversarial hyper parameter
133 | parser.add_argument("--epsilon", type=int, default=8)
134 | parser.add_argument("--step_size", type=int, default=2)
135 | parser.add_argument("--num_steps", type=int, default=5)
136 |
137 | # for loss factor
138 | parser.add_argument("--trades_k", type=float, default=2)
139 |
140 | # for augmentation ablation study
141 | parser.add_argument("--augmentation_ablation", action="store_true")
142 | parser.add_argument("--expriment_code", default="00", type=str)
143 |
144 |
145 |
146 |
147 |
148 | return parent_parser
149 |
150 | @property
151 | def learnable_params(self) -> List[dict]:
152 | """Adds projector parameters together with parent's learnable parameters.
153 |
154 | Returns:
155 | List[dict]: list of learnable parameters.
156 | """
157 |
158 | extra_learnable_params = [{"params": self.projector.parameters()}]
159 | return super().learnable_params + extra_learnable_params
160 |
161 | @property
162 | def momentum_pairs(self) -> List[Tuple[Any, Any]]:
163 | """Adds (projector, momentum_projector) to the parent's momentum pairs.
164 |
165 | Returns:
166 | List[Tuple[Any, Any]]: list of momentum pairs.
167 | """
168 |
169 | extra_momentum_pairs = [(self.projector, self.momentum_projector)]
170 | return super().momentum_pairs + extra_momentum_pairs
171 |
172 |
173 | def forward(self, X: torch.Tensor):
174 | """Performs the forward pass of the online backbone and projector.
175 |
176 | Args:
177 | X (torch.Tensor): a batch of images in the tensor format.
178 |
179 | Returns:
180 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
181 | """
182 | out = self.backbone(X)
183 |
184 | return out
185 |
186 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
187 | """
188 |         Training step for DeACL adversarial distillation, reusing the BaseDistillationATMethod training step.
189 |
190 | Args:
191 | batch (Sequence[Any]): a batch of data in the
192 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_large_crops
193 | containing batches of images.
194 | batch_idx (int): index of the batch.
195 |
196 | Returns:
197 |             torch.Tensor: total loss composed of the adversarial distillation loss and the online classifier's classification loss.
198 |
199 | """
200 |
201 |
202 | self.momentum_backbone.eval()
203 |
204 |
205 | # import ipdb; ipdb.set_trace()
206 | if self.aux_data:
207 | image_tau1, image_weak = batch[0]
208 | targets = batch[1]
209 |
210 | else:
211 | image_tau1, image_weak = batch[1]
212 | targets = batch[2]
213 |
214 |
215 | ############################################################################
216 | # Adversarial Training (CAT)
217 | ############################################################################
218 | # if self.trainer.current_epoch ==2:
219 | # import ipdb; ipdb.set_trace()
220 | # import ipdb; ipdb.set_trace()
221 |
222 | away_target = self.momentum_projector(self.momentum_backbone(image_weak))
223 |
224 | AE_generation_image = image_weak
225 | image_AE = self.generate_training_AE(AE_generation_image, away_target)
226 |
227 | image_CAT = torch.cat([image_weak, image_AE])
228 | logits_all = self.projector(self.backbone(image_CAT))
229 | bs = image_weak.size(0)
230 | student_logits_clean = logits_all[:bs]
231 | student_logits_AE = logits_all[bs:]
232 |
233 | # Cosine Similarity loss
234 | adv_loss = -F.cosine_similarity(student_logits_clean, away_target).mean()
235 | adv_loss += -self.trades_k*F.cosine_similarity(student_logits_AE, student_logits_clean).mean()
236 |
237 | ############################################################################
238 | # Adversarial Training (CAT)
239 | ############################################################################
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 | ############################################################################
250 | # Online clean classifier training
251 | ############################################################################
252 |         # bug fix: keep the backbone in eval mode while training the online classifier
253 | self.backbone.eval()
254 | # import ipdb; ipdb.set_trace()
255 | outs_image_weak = self._base_shared_step(image_weak, targets)
256 | self.backbone.train()
257 | metrics = {
258 | "train_class_loss": outs_image_weak["loss"],
259 | "train_acc1": outs_image_weak["acc1"],
260 | "train_acc5": outs_image_weak["acc5"],
261 | }
262 | class_loss_clean = outs_image_weak["loss"]
263 | self.log_dict(metrics, on_epoch=True)
264 | ############################################################################
265 | # Online clean classifier training
266 | ############################################################################
267 |
268 |
269 | ae_std = F.normalize(student_logits_AE, dim=-1).std(dim=0).mean()
270 | clean_std = F.normalize(student_logits_clean, dim=-1).std(dim=0).mean()
271 | teacher_std = F.normalize(away_target, dim=-1).std(dim=0).mean()
272 |
273 |
274 | metrics = {
275 | "adv_loss": adv_loss,
276 | "ae_std": ae_std,
277 | "clean_std": clean_std,
278 | "teacher_std": teacher_std
279 | }
280 | self.log_dict(metrics, on_epoch=True, sync_dist=True)
281 |
282 | # return adv_loss + class_loss_adv + class_loss_clean
283 | return adv_loss + class_loss_clean
284 |
285 |
286 |
287 | def generate_training_AE(self, image: torch.Tensor, away_target: torch.Tensor):
288 | """
289 | images_org: weak aug
290 | away_target: from teacher
291 | """
292 |
293 | # self.epsilon = 8/255.
294 | # self.num_steps = 5
295 | # self.step_size = 2/255.
296 |
297 | x_cl = image.clone().detach()
298 |
299 | # if self.rand:
300 | x_cl = x_cl + torch.zeros_like(image).uniform_(-self.epsilon, self.epsilon)
301 |
302 | # f_ori_proj = self.model(images_org).detach()
303 | # Change the attack process of model to eval
304 | self.backbone.eval()
305 |
306 | for i in range(self.num_steps):
307 | x_cl.requires_grad_()
308 | with torch.enable_grad():
309 | f_proj = self.projector(self.backbone(x_cl))
310 |
311 |                 # scale the loss up to avoid gradient underflow in 16-bit training
312 |                 loss_contrast = -F.cosine_similarity(f_proj, away_target, dim=1).sum() * 256
313 | loss = loss_contrast
314 |
315 | # import ipdb ;ipdb.set_trace()
316 | grad_x_cl = torch.autograd.grad(loss, x_cl)[0]
317 | # grad_x_cl = torch.autograd.grad(loss, x_cl, grad_outputs=torch.ones_like(loss))[0]
318 | x_cl = x_cl.detach() + self.step_size * torch.sign(grad_x_cl.detach())
319 |
320 |             # project back into the epsilon ball around the clean image, then clamp to [0, 1]
321 | x_cl = torch.min(torch.max(x_cl, image - self.epsilon), image + self.epsilon)
322 | x_cl = torch.clamp(x_cl, 0, 1)
323 |
324 | self.backbone.train()
325 |
326 | return x_cl
327 |
328 |
329 |
330 |
331 |
332 |
333 |
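334 | # Standalone restatement of the PGD attack in generate_training_AE above (a
335 | # reference sketch; `f` stands in for the composed projector(backbone(.)) and is
336 | # an assumption, not an identifier from this file):
337 | def _pgd_away_from_teacher(f, image, teacher_feats, epsilon=8 / 255.0,
338 |                            step_size=2 / 255.0, num_steps=5):
339 |     """Maximizes the cosine distance between f(x_adv) and the teacher features."""
340 |     x = image.clone().detach()
341 |     # random start inside the L-inf ball
342 |     x = x + torch.zeros_like(image).uniform_(-epsilon, epsilon)
343 |     for _ in range(num_steps):
344 |         x.requires_grad_()
345 |         loss = -F.cosine_similarity(f(x), teacher_feats, dim=1).sum()
346 |         grad = torch.autograd.grad(loss, x)[0]
347 |         x = x.detach() + step_size * grad.detach().sign()
348 |         # project back into the epsilon ball and the valid image range
349 |         x = torch.min(torch.max(x, image - epsilon), image + epsilon)
350 |         x = torch.clamp(x, 0, 1)
351 |     return x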
--------------------------------------------------------------------------------