├── .gitignore
├── LICENSE.md
├── README.md
├── asset
└── image.png
├── bash_files
├── moco.sh
├── simco.sh
└── simmoco.sh
├── main_pretrain.py
└── solo
├── __init__.py
├── args
├── __init__.py
├── dataset.py
├── setup.py
└── utils.py
├── losses
├── __init__.py
├── dual_temperature_loss.py
└── moco.py
├── methods
├── __init__.py
├── base.py
├── dali.py
├── mocov2.py
├── mocov2plus.py
├── simco_dual_temperature.py
└── simmoco_dual_temperature.py
└── utils
├── __init__.py
├── auto_umap.py
├── backbones.py
├── checkpointer.py
├── classification_dataloader.py
├── dali_dataloader.py
├── kmeans.py
├── knn.py
├── lars.py
├── metrics.py
├── misc.py
├── momentum.py
├── pretrain_dataloader.py
├── sinkhorn_knopp.py
└── whitening.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb
2 | trained_models
3 | code
4 | data
5 | .env
6 |
7 | **/__pycache__/**
8 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Kang Zhang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dual Temperature Helps Contrastive Learning Without Many Negative Samples: Towards Understanding and Simplifying MoCo (Accepted by CVPR2022)
2 |
3 |
4 | Chaoning Zhang, Kang Zhang, Trung X. Pham, Axi Niu, Zhinan Qiao, Chang D. Yoo, In So Kweon
5 |
6 | Contrastive learning (CL) is widely known to require many negative samples, 65536 in MoCo for instance, for which the performance of a dictionary-free framework is often inferior because the negative sample size (NSS) is limited by its mini-batch size (MBS). To decouple the NSS from the MBS, a dynamic dictionary has been adopted in a large volume of CL frameworks, among which arguably the most popular one is MoCo family. In essence, MoCo adopts a momentum-based queue dictionary, for which we perform a fine-grained analysis of its size and consistency. We point out that InfoNCE loss used in MoCo implicitly attract anchors to their corresponding positive sample with various strength of penalties and identify such inter-anchor hardness-awareness property as a major reason for the necessity of a large dictionary. Our findings motivate us to simplify MoCo v2 via the removal of its dictionary as well as momentum. Based on an InfoNCE with the proposed dual temperature, our simplified frameworks, SimMoCo and SimCo, outperform MoCo v2 by a visible margin. Moreover, our work bridges the gap between CL and non-CL frameworks, contributing to a more unified understanding of these two mainstream frameworks in SSL.
7 |
8 |
9 | This repository is the official implementation of ["Dual Temperature Helps Contrastive Learning Without Many Negative Samples: Towards Understanding and Simplifying MoCo"](https://arxiv.org/abs/2203.17248).
10 |
11 |
12 |
13 |
14 | ---
15 | See also our other works:
16 |
17 | Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness (Accepted by ECCV2022 oral presentation) [code](https://github.com/pantheon5100/DeACL.git) [paper](https://arxiv.org/abs/2207.10899)
18 |
19 | ---
20 |
21 | # Dual Temperature InfoNCE Loss
22 | You can simply replace your original loss with dual-temperature loss from the following code:
23 | ```python
24 | # q1 is the anchor and k2 is the positive sample
25 | # The intra-anchor hardness-awareness is controlled by `temperature` parameter.
26 | # The inter-anchor hardness awareness is controlled by `dt_m` parameter,
27 | # and temperature is calculated by dt_m * temperature.
28 | nce_loss = dual_temperature_loss_func(q1, k2,
29 | temperature=temperature,
30 | dt_m=dt_m)
31 |
32 | def dual_temperature_loss_func(
33 | query: torch.Tensor,
34 | key: torch.Tensor,
35 | temperature=0.1,
36 | dt_m=10,
37 | ) -> torch.Tensor:
38 |
39 | """
40 | query: anchor sample.
41 | key: positive sample.
42 | temperature: intra-anchor hardness-awareness control temperature.
43 | dt_m: the scalar number to get inter-anchor hardness awareness temperature.
44 | inter-anchor hardness awareness temperature is calculated by dt_m * temperature
45 | """
46 |
47 | # intra-anchor hardness-awareness
48 | b = query.size(0)
49 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
50 |
51 | # Selecte the intra negative samples according the updata time,
52 | neg = torch.einsum("nc,ck->nk", [query, key.T])
53 | mask_neg = torch.ones_like(neg, dtype=bool)
54 | mask_neg.fill_diagonal_(False)
55 | neg = neg[mask_neg].reshape(neg.size(0), neg.size(1)-1)
56 | logits = torch.cat([pos, neg], dim=1)
57 |
58 | logits_intra = logits / temperature
59 | prob_intra = F.softmax(logits_intra, dim=1)
60 |
61 | # inter-anchor hardness-awareness
62 | logits_inter = logits / (temperature*dt_m)
63 | prob_inter = F.softmax(logits_inter, dim=1)
64 |
65 | # hardness-awareness factor
66 | inter_intra = (1 - prob_inter[:, 0]) / (1 - prob_intra[:, 0])
67 |
68 | loss = -torch.nn.functional.log_softmax(logits_intra, dim=-1)[:, 0]
69 |
70 | # final loss
71 | loss = inter_intra.detach() * loss
72 | loss = loss.mean()
73 |
74 | return loss
75 |
76 | ```
77 |
78 | # 🔧 Enviroment
79 |
80 | Please refer [solo-learn](https://github.com/vturrisi/solo-learn) to install the enviroment.
81 |
82 | > First clone the repo.
83 | >
84 | > Then, to install solo-learn with Dali and/or UMAP support, use:
85 | >
86 | > `pip3 install .[dali,umap,h5] --extra-index-url https://developer.download.nvidia.com/compute/redist`
87 |
88 |
89 | # Dataset
90 | CIFAR10 and CIFAR100 will be automately downloaded.
91 |
92 | # ⚡ Training
93 | To train SimCo, SimMoCo, and MoCoV2, use the script in folder `./bash_files`.
94 |
95 | You should change the entity and project name to enable the wandb logging. `--project --entity `. Or you can simply remove `--wandb` to disable wandb logging.
96 |
97 | # Results
98 |
99 | | Batch size | 64 | 128 | 256 | 512 | 1024 |
100 | |------------|-------|-------|----------------|-------|-------|
101 | | MoCo v2 | 52.58 | 54.40 | 53.28 | 51.47 | 48.90 |
102 | | SimMoCo | 54.02 | 54.93 | 54.11 | 52.45 | 49.70 |
103 | | SimCo | 58.04 | 58.29 | **58.35** | 57.08 | 55.34 |
104 |
105 | More result can be found in the paper.
106 |
107 | This code is developed based on [solo-learn](https://github.com/vturrisi/solo-learn).
108 |
109 | # Citation
110 | ```
111 | @article{zhang2022dual,
112 | title={Dual temperature helps contrastive learning without many negative samples: Towards understanding and simplifying moco},
113 | author={Zhang, Chaoning and Zhang, Kang and Pham, Trung X and Niu, Axi and Qiao, Zhinan and Yoo, Chang D and Kweon, In So},
114 | journal={CVPR},
115 | year={2022}
116 | }
117 | ```
118 |
119 |
120 | # Acknowledgement
121 |
122 | This work was partly supported by Institute for Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) under grant No.2019-0-01396 (Development of framework for analyzing, detecting, mitigating of bias in AI model and training data), No.2021-0-01381 (Development of Causal AI through Video Understanding and Reinforcement Learning, and Its Applications to Real Environments) and No.2021-0-02068 (Artificial Intelligence Innovation Hub).
123 |
--------------------------------------------------------------------------------
/asset/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChaoningZhang/Dual-temperature/e8d5e768255721b23643d948631f80c0b8a0851b/asset/image.png
--------------------------------------------------------------------------------
/bash_files/moco.sh:
--------------------------------------------------------------------------------
1 | # this is the script to train original moco
2 | python3 main_pretrain.py \
3 | --dataset cifar100 \
4 | --encoder resnet18 \
5 | --data_dir ./data \
6 | --max_epochs 200 \
7 | --gpus 0 \
8 | --precision 16 \
9 | --optimizer sgd \
10 | --scheduler warmup_cosine \
11 | --lr 0.03 \
12 | --classifier_lr 0.3 \
13 | --weight_decay 5e-4 \
14 | --batch_size 256 \
15 | --num_workers 3 \
16 | --brightness 0.4 \
17 | --contrast 0.4 \
18 | --saturation 0.4 \
19 | --hue 0.1 \
20 | --gaussian_prob 0.0 0.0 \
21 | --name mocov2 \
22 | --project \
23 | --entity \
24 | --wandb \
25 | --method mocov2 \
26 | --proj_hidden_dim 128 \
27 | --temperature 0.1 \
28 | --base_tau_momentum 0.99 \
29 | --final_tau_momentum 0.99 \
30 | --momentum_classifier
31 |
--------------------------------------------------------------------------------
/bash_files/simco.sh:
--------------------------------------------------------------------------------
1 | # This is the script to train simco
2 | python3 main_pretrain.py \
3 | --dataset cifar100 \
4 | --encoder resnet18 \
5 | --data_dir ./data \
6 | --max_epochs 200 \
7 | --gpus 1 \
8 | --precision 16 \
9 | --optimizer sgd \
10 | --scheduler warmup_cosine \
11 | --lr 0.03 \
12 | --classifier_lr 0.3 \
13 | --weight_decay 5e-4 \
14 | --batch_size 256 \
15 | --num_workers 3 \
16 | --brightness 0.4 \
17 | --contrast 0.4 \
18 | --saturation 0.4 \
19 | --hue 0.1 \
20 | --gaussian_prob 0.0 0.0 \
21 | --name simco \
22 | --project \
23 | --entity \
24 | --wandb \
25 | --method simco_dual_temperature \
26 | --proj_hidden_dim 128 \
27 | --temperature 0.1 \
28 | --dt_m 10 \
29 | --base_tau_momentum 0 \
30 | --final_tau_momentum 0 \
31 | --momentum_classifier
32 |
33 |
--------------------------------------------------------------------------------
/bash_files/simmoco.sh:
--------------------------------------------------------------------------------
1 | # This is the script to train simmoco
2 | python3 main_pretrain.py \
3 | --dataset cifar100 \
4 | --encoder resnet18 \
5 | --data_dir ./data \
6 | --max_epochs 200 \
7 | --gpus 3 \
8 | --precision 16 \
9 | --optimizer sgd \
10 | --scheduler warmup_cosine \
11 | --lr 0.03 \
12 | --classifier_lr 0.3 \
13 | --weight_decay 5e-4 \
14 | --batch_size 256 \
15 | --num_workers 3 \
16 | --brightness 0.4 \
17 | --contrast 0.4 \
18 | --saturation 0.4 \
19 | --hue 0.1 \
20 | --gaussian_prob 0.0 0.0 \
21 | --name simmoco \
22 | --project \
23 | --entity \
24 | --wandb \
25 | --method simmoco_dual_temperature \
26 | --proj_hidden_dim 128 \
27 | --temperature 0.1 \
28 | --dt_m 10 \
29 | --base_tau_momentum 0.99 \
30 | --final_tau_momentum 0.99 \
31 | --momentum_classifier
32 |
33 |
--------------------------------------------------------------------------------
/main_pretrain.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from pprint import pprint
22 |
23 | from pytorch_lightning import Trainer, seed_everything
24 | from pytorch_lightning.callbacks import LearningRateMonitor
25 | from pytorch_lightning.loggers import WandbLogger
26 | from pytorch_lightning.plugins import DDPPlugin
27 |
28 | from solo.args.setup import parse_args_pretrain
29 | from solo.methods import METHODS
30 |
31 | try:
32 | from solo.methods.dali import PretrainABC
33 | except ImportError:
34 | _dali_avaliable = False
35 | else:
36 | _dali_avaliable = True
37 |
38 | try:
39 | from solo.utils.auto_umap import AutoUMAP
40 | except ImportError:
41 | _umap_available = False
42 | else:
43 | _umap_available = True
44 |
45 | from solo.utils.checkpointer import Checkpointer
46 | from solo.utils.classification_dataloader import prepare_data as prepare_data_classification
47 | from solo.utils.pretrain_dataloader import (
48 | prepare_dataloader,
49 | prepare_datasets,
50 | prepare_multicrop_transform,
51 | prepare_n_crop_transform,
52 | prepare_transform,
53 | )
54 | import shutil
55 | import sys
56 | import glob
57 |
58 |
59 | def main():
60 | # set the seed
61 | seed_everything(15)
62 |
63 | args = parse_args_pretrain()
64 |
65 | assert args.method in METHODS, f"Choose from {METHODS.keys()}"
66 |
67 | MethodClass = METHODS[args.method]
68 | if args.dali:
69 | assert (
70 | _dali_avaliable
71 | ), "Dali is not currently avaiable, please install it first with [dali]."
72 | MethodClass = type(f"Dali{MethodClass.__name__}", (MethodClass, PretrainABC), {})
73 |
74 | model = MethodClass(**args.__dict__)
75 |
76 | # contrastive dataloader
77 | if not args.dali:
78 | # asymmetric augmentations
79 | if args.unique_augs > 1:
80 | transform = [
81 | prepare_transform(args.dataset, multicrop=args.multicrop, **kwargs)
82 | for kwargs in args.transform_kwargs
83 | ]
84 | else:
85 | transform = prepare_transform(
86 | args.dataset, multicrop=args.multicrop, **args.transform_kwargs
87 | )
88 |
89 | if args.debug_augmentations:
90 | print("Transforms:")
91 | pprint(transform)
92 |
93 | if args.multicrop:
94 | assert not args.unique_augs == 1
95 |
96 | if args.dataset in ["cifar10", "cifar100"]:
97 | size_crops = [32, 24]
98 | elif args.dataset == "stl10":
99 | size_crops = [96, 58]
100 | # imagenet or custom dataset
101 | else:
102 | size_crops = [224, 96]
103 |
104 | transform = prepare_multicrop_transform(
105 | transform, size_crops=size_crops, num_crops=[args.num_crops, args.num_small_crops]
106 | )
107 | else:
108 | if args.num_crops != 2:
109 | # import pdb; pdb.set_trace()
110 | assert args.method == "wmse" or args.method == "simsiam_eoa" or args.method == "simclr_neg_size"
111 |
112 | transform = prepare_n_crop_transform(transform, num_crops=args.num_crops)
113 |
114 | train_dataset = prepare_datasets(
115 | args.dataset,
116 | transform,
117 | data_dir=args.data_dir,
118 | train_dir=args.train_dir,
119 | no_labels=args.no_labels,
120 | )
121 | train_loader = prepare_dataloader(
122 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers
123 | )
124 |
125 | # normal dataloader for when it is available
126 | if args.dataset == "custom" and (args.no_labels or args.val_dir is None):
127 | val_loader = None
128 | elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None:
129 | val_loader = None
130 | else:
131 | _, val_loader = prepare_data_classification(
132 | args.dataset,
133 | data_dir=args.data_dir,
134 | train_dir=args.train_dir,
135 | val_dir=args.val_dir,
136 | batch_size=args.batch_size,
137 | num_workers=args.num_workers,
138 | )
139 |
140 | callbacks = []
141 |
142 | # wandb logging
143 | if args.wandb:
144 | wandb_logger = WandbLogger(
145 | name=args.name,
146 | project=args.project,
147 | entity=args.entity,
148 | offline=args.offline,
149 | )
150 | wandb_logger.watch(model, log="gradients", log_freq=100)
151 | wandb_logger.log_hyperparams(args)
152 |
153 | # lr logging
154 | lr_monitor = LearningRateMonitor(logging_interval="epoch")
155 | callbacks.append(lr_monitor)
156 |
157 | # save checkpoint on last epoch only
158 | ckpt = Checkpointer(
159 | args,
160 | logdir=os.path.join(args.checkpoint_dir, args.method),
161 | frequency=args.checkpoint_frequency,
162 | )
163 | callbacks.append(ckpt)
164 |
165 | if args.auto_umap:
166 | assert (
167 | _umap_available
168 | ), "UMAP is not currently avaiable, please install it first with [umap]."
169 | auto_umap = AutoUMAP(
170 | args,
171 | logdir=os.path.join(args.auto_umap_dir, args.method),
172 | frequency=args.auto_umap_frequency,
173 | )
174 | callbacks.append(auto_umap)
175 |
176 | trainer = Trainer.from_argparse_args(
177 | args,
178 | logger=wandb_logger if args.wandb else None,
179 | callbacks=callbacks,
180 | plugins=DDPPlugin(find_unused_parameters=True),
181 | checkpoint_callback=False,
182 | terminate_on_nan=True,
183 | accelerator="ddp",
184 | log_every_n_steps=args.log_frenquence,
185 | )
186 |
187 | # save code for each run to make each run reproducible
188 | #################################################################
189 | if args.wandb:
190 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_{trainer.logger.version}"
191 | args.codepath = experimentdir
192 | else:
193 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_test"
194 |
195 | if not os.path.exists("code"):
196 | os.mkdir("code")
197 |
198 | if os.path.exists(experimentdir):
199 | print(experimentdir + ' : exists. overwrite it.')
200 | shutil.rmtree(experimentdir)
201 | os.mkdir(experimentdir)
202 | else:
203 | os.mkdir(experimentdir)
204 |
205 | shutil.copytree(f"solo", os.path.join(experimentdir, 'solo'))
206 | shutil.copytree(f"bash_files", os.path.join(experimentdir, 'bash_files'))
207 | shutil.copyfile(f"main_pretrain.py", os.path.join(experimentdir, 'main_pretrain.py'))
208 | #################################################################
209 |
210 | if args.dali:
211 | trainer.fit(model, val_dataloaders=val_loader)
212 | else:
213 | trainer.fit(model, train_loader, val_loader)
214 |
215 |
216 | if __name__ == "__main__":
217 | main()
218 |
--------------------------------------------------------------------------------
/solo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from solo import args, losses, methods, utils
22 |
23 | __all__ = ["args", "losses", "methods", "utils"]
24 |
--------------------------------------------------------------------------------
/solo/args/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.args import dataset, setup, utils
21 |
22 | __all__ = ["dataset", "setup", "utils"]
23 |
24 |
25 |
--------------------------------------------------------------------------------
/solo/args/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from argparse import ArgumentParser
21 | from pathlib import Path
22 |
23 |
24 | def dataset_args(parser: ArgumentParser):
25 | """Adds dataset-related arguments to a parser.
26 |
27 | Args:
28 | parser (ArgumentParser): parser to add dataset args to.
29 | """
30 |
31 | SUPPORTED_DATASETS = [
32 | "cifar10",
33 | "cifar100",
34 | "stl10",
35 | "imagenet",
36 | "imagenet100",
37 | "custom",
38 | ]
39 |
40 | parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, type=str, required=True)
41 |
42 | # dataset path
43 | parser.add_argument("--data_dir", type=Path, required=True)
44 | parser.add_argument("--train_dir", type=Path, default=None)
45 | parser.add_argument("--val_dir", type=Path, default=None)
46 |
47 | # dali (imagenet-100/imagenet/custom only)
48 | parser.add_argument("--dali", action="store_true")
49 | parser.add_argument("--dali_device", type=str, default="gpu")
50 |
51 | # custom dataset only
52 | parser.add_argument("--no_labels", action="store_true")
53 |
54 |
55 | def augmentations_args(parser: ArgumentParser):
56 | """Adds augmentation-related arguments to a parser.
57 |
58 | Args:
59 | parser (ArgumentParser): parser to add augmentation args to.
60 | """
61 |
62 | # cropping
63 | parser.add_argument("--multicrop", action="store_true")
64 | parser.add_argument("--num_crops", type=int, default=2)
65 | parser.add_argument("--num_small_crops", type=int, default=0)
66 |
67 | # augmentations
68 | parser.add_argument("--brightness", type=float, required=True, nargs="+")
69 | parser.add_argument("--contrast", type=float, required=True, nargs="+")
70 | parser.add_argument("--saturation", type=float, required=True, nargs="+")
71 | parser.add_argument("--hue", type=float, required=True, nargs="+")
72 | parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+")
73 | parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+")
74 | parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+")
75 |
76 | # for imagenet or custom dataset
77 | parser.add_argument("--size", type=int, default=[224], nargs="+")
78 |
79 | # for custom dataset
80 | parser.add_argument("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+")
81 | parser.add_argument("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+")
82 |
83 | # debug
84 | parser.add_argument("--debug_augmentations", action="store_true")
85 |
--------------------------------------------------------------------------------
/solo/args/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 |
22 | import pytorch_lightning as pl
23 | from solo.args.dataset import augmentations_args, dataset_args
24 | from solo.args.utils import additional_setup_linear, additional_setup_pretrain
25 | from solo.methods import METHODS
26 | from solo.utils.checkpointer import Checkpointer
27 |
28 | try:
29 | from solo.utils.auto_umap import AutoUMAP
30 | except ImportError:
31 | _umap_available = False
32 | else:
33 | _umap_available = True
34 | import os
35 |
36 |
37 | def parse_args_pretrain() -> argparse.Namespace:
38 | """Parses dataset, augmentation, pytorch lightning, model specific and additional args.
39 |
40 | First adds shared args such as dataset, augmentation and pytorch lightning args, then pulls the
41 | model name from the command and proceeds to add model specific args from the desired class. If
42 | wandb is enabled, it adds checkpointer args. Finally, adds additional non-user given parameters.
43 |
44 | Returns:
45 | argparse.Namespace: a namespace containing all args needed for pretraining.
46 | """
47 |
48 | parser = argparse.ArgumentParser()
49 |
50 | # add current working path
51 | parser.add_argument("--runpath", default=os.getcwd(), type=str)
52 |
53 | # add code saving path
54 | parser.add_argument("--codepath", default=os.getcwd(), type=str)
55 |
56 | # add current working path
57 | parser.add_argument("--log_frenquence", default=50, type=int)
58 |
59 | # add shared arguments
60 | dataset_args(parser)
61 | augmentations_args(parser)
62 |
63 | # add pytorch lightning trainer args
64 | parser = pl.Trainer.add_argparse_args(parser)
65 |
66 | # add method-specific arguments
67 | parser.add_argument("--method", type=str)
68 |
69 | # THIS LINE IS KEY TO PULL THE MODEL NAME
70 | temp_args, _ = parser.parse_known_args()
71 |
72 | # add model specific args
73 | parser = METHODS[temp_args.method].add_model_specific_args(parser)
74 |
75 | # add auto umap args
76 | parser.add_argument("--auto_umap", action="store_true")
77 |
78 | # optionally add checkpointer and AutoUMAP args
79 | temp_args, _ = parser.parse_known_args()
80 | if temp_args.wandb:
81 | parser = Checkpointer.add_checkpointer_args(parser)
82 |
83 | if _umap_available and temp_args.auto_umap:
84 | parser = AutoUMAP.add_auto_umap_args(parser)
85 |
86 | # parse args
87 | args = parser.parse_args()
88 |
89 |
90 | # prepare arguments with additional setup
91 | additional_setup_pretrain(args)
92 |
93 | return args
94 |
95 |
96 | def parse_args_linear() -> argparse.Namespace:
97 | """Parses feature extractor, dataset, pytorch lightning, linear eval specific and additional args.
98 |
99 | First adds and arg for the pretrained feature extractor, then adds dataset, pytorch lightning
100 | and linear eval specific args. If wandb is enabled, it adds checkpointer args. Finally, adds
101 | additional non-user given parameters.
102 |
103 | Returns:
104 | argparse.Namespace: a namespace containing all args needed for pretraining.
105 | """
106 |
107 | parser = argparse.ArgumentParser()
108 |
109 | parser.add_argument("--pretrained_feature_extractor", type=str)
110 |
111 | # add shared arguments
112 | dataset_args(parser)
113 |
114 | # add pytorch lightning trainer args
115 | parser = pl.Trainer.add_argparse_args(parser)
116 |
117 | # linear model
118 | parser = METHODS["linear"].add_model_specific_args(parser)
119 |
120 | # THIS LINE IS KEY TO PULL WANDB
121 | temp_args, _ = parser.parse_known_args()
122 |
123 | # add checkpointer args (only if logging is enabled)
124 | if temp_args.wandb:
125 | parser = Checkpointer.add_checkpointer_args(parser)
126 |
127 | # parse args
128 | args = parser.parse_args()
129 | additional_setup_linear(args)
130 |
131 | return args
132 |
--------------------------------------------------------------------------------
/solo/args/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from argparse import Namespace
22 |
23 | N_CLASSES_PER_DATASET = {
24 | "cifar10": 10,
25 | "cifar100": 100,
26 | "stl10": 10,
27 | "imagenet": 1000,
28 | "imagenet100": 100,
29 | }
30 |
31 |
32 | def additional_setup_pretrain(args: Namespace):
33 | """Provides final setup for pretraining to non-user given parameters by changing args.
34 |
35 | Parsers arguments to extract the number of classes of a dataset, create
36 | transformations kwargs, correctly parse gpus, identify if a cifar dataset
37 | is being used and adjust the lr.
38 |
39 | Args:
40 | args (Namespace): object that needs to contain, at least:
41 | - dataset: dataset name.
42 | - brightness, contrast, saturation, hue, min_scale: required augmentations
43 | settings.
44 | - multicrop: flag to use multicrop.
45 | - dali: flag to use dali.
46 | - optimizer: optimizer name being used.
47 | - gpus: list of gpus to use.
48 | - lr: learning rate.
49 |
50 | [optional]
51 | - gaussian_prob, solarization_prob: optional augmentations settings.
52 | """
53 |
54 | if args.dataset in N_CLASSES_PER_DATASET:
55 | args.num_classes = N_CLASSES_PER_DATASET[args.dataset]
56 | else:
57 | # hack to maintain the current pipeline
58 | # even if the custom dataset doesn't have any labels
59 | dir_path = args.data_dir / args.train_dir
60 | args.num_classes = max(
61 | 1,
62 | len([entry.name for entry in os.scandir(dir_path) if entry.is_dir]),
63 | )
64 |
65 | unique_augs = max(
66 | len(p)
67 | for p in [
68 | args.brightness,
69 | args.contrast,
70 | args.saturation,
71 | args.hue,
72 | args.gaussian_prob,
73 | args.solarization_prob,
74 | args.min_scale,
75 | args.size,
76 | ]
77 | )
78 | # if args.method != "simclr_interintra_neg":
79 | # assert unique_augs == args.num_crops or unique_augs == 1
80 |
81 | # assert that either all unique augmentation pipelines have a unique
82 | # parameter or that a single parameter is replicated to all pipelines
83 | for p in [
84 | "brightness",
85 | "contrast",
86 | "saturation",
87 | "hue",
88 | "gaussian_prob",
89 | "solarization_prob",
90 | "min_scale",
91 | "size",
92 | ]:
93 | values = getattr(args, p)
94 | n = len(values)
95 | assert n == unique_augs or n == 1
96 |
97 | if n == 1:
98 | setattr(args, p, getattr(args, p) * unique_augs)
99 |
100 | args.unique_augs = unique_augs
101 |
102 | if unique_augs > 1:
103 | args.transform_kwargs = [
104 | dict(
105 | brightness=brightness,
106 | contrast=contrast,
107 | saturation=saturation,
108 | hue=hue,
109 | gaussian_prob=gaussian_prob,
110 | solarization_prob=solarization_prob,
111 | min_scale=min_scale,
112 | size=size,
113 | )
114 | for (
115 | brightness,
116 | contrast,
117 | saturation,
118 | hue,
119 | gaussian_prob,
120 | solarization_prob,
121 | min_scale,
122 | size,
123 | ) in zip(
124 | args.brightness,
125 | args.contrast,
126 | args.saturation,
127 | args.hue,
128 | args.gaussian_prob,
129 | args.solarization_prob,
130 | args.min_scale,
131 | args.size,
132 | )
133 | ]
134 |
135 | elif not args.multicrop:
136 | args.transform_kwargs = dict(
137 | brightness=args.brightness[0],
138 | contrast=args.contrast[0],
139 | saturation=args.saturation[0],
140 | hue=args.hue[0],
141 | gaussian_prob=args.gaussian_prob[0],
142 | solarization_prob=args.solarization_prob[0],
143 | min_scale=args.min_scale[0],
144 | size=args.size[0],
145 | )
146 | else:
147 | args.transform_kwargs = dict(
148 | brightness=args.brightness[0],
149 | contrast=args.contrast[0],
150 | saturation=args.saturation[0],
151 | hue=args.hue[0],
152 | gaussian_prob=args.gaussian_prob[0],
153 | solarization_prob=args.solarization_prob[0],
154 | )
155 |
156 | # add support for custom mean and std
157 | if args.dataset == "custom":
158 | if isinstance(args.transform_kwargs, dict):
159 | args.transform_kwargs["mean"] = args.mean
160 | args.transform_kwargs["std"] = args.std
161 | else:
162 | for kwargs in args.transform_kwargs:
163 | kwargs["mean"] = args.mean
164 | kwargs["std"] = args.std
165 |
166 | if args.dataset in ["cifar10", "cifar100", "stl10"]:
167 | if isinstance(args.transform_kwargs, dict):
168 | del args.transform_kwargs["size"]
169 | else:
170 | for kwargs in args.transform_kwargs:
171 | del kwargs["size"]
172 |
173 | # create backbone-specific arguments
174 | args.backbone_args = {"cifar": True if args.dataset in ["cifar10", "cifar100"] else False}
175 | if "resnet" in args.encoder:
176 | args.backbone_args["zero_init_residual"] = args.zero_init_residual
177 | else:
178 | # dataset related for all transformers
179 | dataset = args.dataset
180 | if "cifar" in dataset:
181 | args.backbone_args["img_size"] = 32
182 | elif "stl" in dataset:
183 | args.backbone_args["img_size"] = 96
184 | elif "imagenet" in dataset:
185 | args.backbone_args["img_size"] = 224
186 | elif "custom" in dataset:
187 | transform_kwargs = args.transform_kwargs
188 | if isinstance(transform_kwargs, list):
189 | args.backbone_args["img_size"] = transform_kwargs[0]["size"]
190 | else:
191 | args.backbone_args["img_size"] = transform_kwargs["size"]
192 |
193 | if "vit" in args.encoder:
194 | args.backbone_args["patch_size"] = args.patch_size
195 |
196 | del args.zero_init_residual
197 | del args.patch_size
198 |
199 | if args.dali:
200 | assert args.dataset in ["imagenet100", "imagenet", "custom"]
201 |
202 | args.extra_optimizer_args = {}
203 | if args.optimizer == "sgd":
204 | args.extra_optimizer_args["momentum"] = 0.9
205 |
206 | if isinstance(args.gpus, int):
207 | args.gpus = [args.gpus]
208 | elif isinstance(args.gpus, str):
209 | args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu]
210 |
211 | # adjust lr according to batch size
212 | args.lr = args.lr * args.batch_size * len(args.gpus) / 256
213 |
214 |
215 | def additional_setup_linear(args: Namespace):
216 | """Provides final setup for linear evaluation to non-user given parameters by changing args.
217 |
218 | Parsers arguments to extract the number of classes of a dataset, correctly parse gpus, identify
219 | if a cifar dataset is being used and adjust the lr.
220 |
221 | Args:
222 | args: Namespace object that needs to contain, at least:
223 | - dataset: dataset name.
224 | - optimizer: optimizer name being used.
225 | - gpus: list of gpus to use.
226 | - lr: learning rate.
227 | """
228 |
229 | assert args.dataset in N_CLASSES_PER_DATASET
230 | args.num_classes = N_CLASSES_PER_DATASET[args.dataset]
231 |
232 | # create backbone-specific arguments
233 | args.backbone_args = {"cifar": True if args.dataset in ["cifar10", "cifar100"] else False}
234 |
235 | if "resnet" not in args.encoder:
236 | # dataset related for all transformers
237 | dataset = args.dataset
238 | if "cifar" in dataset:
239 | args.backbone_args["img_size"] = 32
240 | elif "stl" in dataset:
241 | args.backbone_args["img_size"] = 96
242 | elif "imagenet" in dataset:
243 | args.backbone_args["img_size"] = 224
244 | elif "custom" in dataset:
245 | transform_kwargs = args.transform_kwargs
246 | if isinstance(transform_kwargs, list):
247 | args.backbone_args["img_size"] = transform_kwargs[0]["size"]
248 | else:
249 | args.backbone_args["img_size"] = transform_kwargs["size"]
250 |
251 | if "vit" in args.encoder:
252 | args.backbone_args["patch_size"] = args.patch_size
253 |
254 | del args.patch_size
255 |
256 | if args.dali:
257 | assert args.dataset in ["imagenet100", "imagenet"]
258 |
259 | args.extra_optimizer_args = {}
260 | if args.optimizer == "sgd":
261 | args.extra_optimizer_args["momentum"] = 0.9
262 |
263 | if isinstance(args.gpus, int):
264 | args.gpus = [args.gpus]
265 | elif isinstance(args.gpus, str):
266 | args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu]
267 |
--------------------------------------------------------------------------------
/solo/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.losses.moco import moco_loss_func
21 | from solo.losses.dual_temperature_loss import dual_temperature_loss_func
22 |
23 | __all__ = [
24 | "moco_loss_func",
25 | "dual_temperature_loss_func",
26 | ]
27 |
--------------------------------------------------------------------------------
/solo/losses/dual_temperature_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 | def dual_temperature_loss_func(
24 | query: torch.Tensor,
25 | key: torch.Tensor,
26 | temperature=0.1,
27 | dt_m=10,
28 | ) -> torch.Tensor:
29 | """
30 | query: anchor sample.
31 | key: positive sample.
32 | temperature: intra-anchor hardness-awareness control temperature.
33 | dt_m: the scalar number to get inter-anchor hardness awareness temperature.
34 | inter-anchor hardness awareness temperature is calculated by dt_m * temperature
35 | """
36 |
37 | # intra-anchor hardness-awareness
38 | b = query.size(0)
39 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
40 |
41 | # Selecte the intra negative samples according the updata time,
42 | neg = torch.einsum("nc,ck->nk", [query, key.T])
43 | mask_neg = torch.ones_like(neg, dtype=bool)
44 | mask_neg.fill_diagonal_(False)
45 | neg = neg[mask_neg].reshape(neg.size(0), neg.size(1)-1)
46 | logits = torch.cat([pos, neg], dim=1)
47 |
48 | logits_intra = logits / temperature
49 | prob_intra = F.softmax(logits_intra, dim=1)
50 |
51 | # inter-anchor hardness-awareness
52 | logits_inter = logits / (temperature*dt_m)
53 | prob_inter = F.softmax(logits_inter, dim=1)
54 |
55 | # hardness-awareness factor
56 | inter_intra = (1 - prob_inter[:, 0]) / (1 - prob_intra[:, 0])
57 |
58 | loss = -torch.nn.functional.log_softmax(logits_intra, dim=-1)[:, 0]
59 |
60 | # final loss
61 | loss = inter_intra.detach() * loss
62 | loss = loss.mean()
63 |
64 | return loss
65 |
--------------------------------------------------------------------------------
/solo/losses/moco.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import torch
21 | import torch.nn.functional as F
22 |
23 |
24 | def moco_loss_func(
25 | query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1
26 | ) -> torch.Tensor:
27 | """Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2 and a
28 | queue of past elements.
29 |
30 | Args:
31 | query (torch.Tensor): NxD Tensor containing the queries from view 1.
32 | key (torch.Tensor): NxD Tensor containing the queries from view 2.
33 | queue (torch.Tensor): a queue of negative samples for the contrastive loss.
34 | temperature (float, optional): [description]. temperature of the softmax in the contrastive
35 | loss. Defaults to 0.1.
36 |
37 | Returns:
38 | torch.Tensor: MoCo loss.
39 | """
40 |
41 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1)
42 | neg = torch.einsum("nc,ck->nk", [query, queue])
43 | logits = torch.cat([pos, neg], dim=1)
44 | logits /= temperature
45 | targets = torch.zeros(query.size(0), device=query.device, dtype=torch.long)
46 | return F.cross_entropy(logits, targets)
47 |
--------------------------------------------------------------------------------
/solo/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from solo.methods.base import BaseMethod
22 | from solo.methods.mocov2plus import MoCoV2Plus
23 |
24 | # dual temperature method
25 | from solo.methods.simco_dual_temperature import SimCo_DualTemperature
26 | from solo.methods.simmoco_dual_temperature import SimMoCo_DualTemperature
27 | from solo.methods.mocov2 import MoCoV2
28 |
29 | METHODS = {
30 | # base classes
31 | "base": BaseMethod,
32 | # methods
33 | "mocov2plus": MoCoV2Plus,
34 |
35 | "simco_dual_temperature": SimCo_DualTemperature,
36 | "simmoco_dual_temperature": SimMoCo_DualTemperature,
37 | "mocov2": MoCoV2,
38 |
39 | }
40 |
41 | __all__ = [
42 | "BaseMethod",
43 | "MoCoV2Plus",
44 | "SimCo_DualTemperature",
45 | "SimMoCo_DualTemperature",
46 | "MoCoV2",
47 | ]
48 |
49 | try:
50 | from solo.methods import dali # noqa: F401
51 | except ImportError:
52 | pass
53 | else:
54 | __all__.append("dali")
55 |
--------------------------------------------------------------------------------
/solo/methods/dali.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 | from abc import ABC
22 | from pathlib import Path
23 | from typing import List
24 |
25 | import torch
26 | import torch.nn as nn
27 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
28 | from solo.utils.dali_dataloader import (
29 | CustomNormalPipeline,
30 | CustomTransform,
31 | ImagenetTransform,
32 | MulticropPretrainPipeline,
33 | NormalPipeline,
34 | PretrainPipeline,
35 | )
36 |
37 |
38 | class BaseWrapper(DALIGenericIterator):
39 | """Temporary fix to handle LastBatchPolicy.DROP."""
40 |
41 | def __len__(self):
42 | size = (
43 | self._size_no_pad // self._shards_num
44 | if self._last_batch_policy == LastBatchPolicy.DROP
45 | else self.size
46 | )
47 | if self._reader_name:
48 | if self._last_batch_policy != LastBatchPolicy.DROP:
49 | return math.ceil(size / self.batch_size)
50 | else:
51 | return size // self.batch_size
52 | else:
53 | if self._last_batch_policy != LastBatchPolicy.DROP:
54 | return math.ceil(size / (self._num_gpus * self.batch_size))
55 | else:
56 | return size // (self._num_gpus * self.batch_size)
57 |
58 |
59 | class PretrainWrapper(BaseWrapper):
60 | def __init__(
61 | self,
62 | model_batch_size: int,
63 | model_rank: int,
64 | model_device: str,
65 | conversion_map: List[int] = None,
66 | *args,
67 | **kwargs,
68 | ):
69 | """Adds indices to a batch fetched from the parent.
70 |
71 | Args:
72 | model_batch_size (int): batch size.
73 | model_rank (int): rank of the current process.
74 | model_device (str): id of the current device.
75 | conversion_map (List[int], optional): list of integeres that map each index
76 | to a class label. If nothing is passed, no label mapping needs to be done.
77 | Defaults to None.
78 | """
79 |
80 | super().__init__(*args, **kwargs)
81 | self.model_batch_size = model_batch_size
82 | self.model_rank = model_rank
83 | self.model_device = model_device
84 | self.conversion_map = conversion_map
85 | if self.conversion_map is not None:
86 | self.conversion_map = torch.tensor(
87 | self.conversion_map, dtype=torch.float32, device=self.model_device
88 | ).reshape(-1, 1)
89 | self.conversion_map = nn.Embedding.from_pretrained(self.conversion_map)
90 |
91 | def __next__(self):
92 | batch = super().__next__()[0]
93 | # PyTorch Lightning does double buffering
94 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1316,
95 | # and as DALI owns the tensors it returns the content of it is trashed so the copy needs,
96 | # to be made before returning.
97 |
98 | if self.conversion_map is not None:
99 | *all_X, indexes = [batch[v] for v in self.output_map]
100 | targets = self.conversion_map(indexes).flatten().long().detach().clone()
101 | indexes = indexes.flatten().long().detach().clone()
102 | else:
103 | *all_X, targets = [batch[v] for v in self.output_map]
104 | targets = targets.squeeze(-1).long().detach().clone()
105 | # creates dummy indexes
106 | indexes = (
107 | (
108 | torch.arange(self.model_batch_size, device=self.model_device)
109 | + (self.model_rank * self.model_batch_size)
110 | )
111 | .detach()
112 | .clone()
113 | )
114 |
115 | all_X = [x.detach().clone() for x in all_X]
116 | return [indexes, all_X, targets]
117 |
118 |
119 | class Wrapper(BaseWrapper):
120 | def __next__(self):
121 | batch = super().__next__()
122 | x, target = batch[0]["x"], batch[0]["label"]
123 | target = target.squeeze(-1).long()
124 | # PyTorch Lightning does double buffering
125 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1316,
126 | # and as DALI owns the tensors it returns the content of it is trashed so the copy needs,
127 | # to be made before returning.
128 | x = x.detach().clone()
129 | target = target.detach().clone()
130 | return x, target
131 |
132 |
133 | class PretrainABC(ABC):
134 | """Abstract pretrain class that returns a train_dataloader using dali."""
135 |
136 | def train_dataloader(self) -> DALIGenericIterator:
137 | """Returns a train dataloader using dali. Supports multi-crop and asymmetric augmentations.
138 |
139 | Returns:
140 | DALIGenericIterator: a train dataloader in the form of a dali pipeline object wrapped
141 | with PretrainWrapper.
142 | """
143 |
144 | device_id = self.local_rank
145 | shard_id = self.global_rank
146 | num_shards = self.trainer.world_size
147 |
148 | # get data arguments from model
149 | dali_device = self.extra_args["dali_device"]
150 |
151 | # data augmentations
152 | unique_augs = self.extra_args["unique_augs"]
153 | transform_kwargs = self.extra_args["transform_kwargs"]
154 |
155 | num_workers = self.extra_args["num_workers"]
156 | data_dir = Path(self.extra_args["data_dir"])
157 | train_dir = Path(self.extra_args["train_dir"])
158 |
159 | # hack to encode image indexes into the labels
160 | self.encode_indexes_into_labels = self.extra_args["encode_indexes_into_labels"]
161 |
162 | # handle custom data by creating the needed pipeline
163 | dataset = self.extra_args["dataset"]
164 | if dataset in ["imagenet100", "imagenet"]:
165 | transform_pipeline = ImagenetTransform
166 | elif dataset == "custom":
167 | transform_pipeline = CustomTransform
168 | else:
169 | raise ValueError(dataset, "is not supported, used [imagenet, imagenet100 or custom]")
170 |
171 | if self.multicrop:
172 | num_crops = [self.num_crops, self.num_small_crops]
173 | size_crops = [224, 96]
174 | min_scales = [0.14, 0.05]
175 | max_scale_crops = [1.0, 0.14]
176 |
177 | transforms = []
178 | for size, min_scale, max_scale in zip(size_crops, min_scales, max_scale_crops):
179 | transform = transform_pipeline(
180 | device=dali_device,
181 | **transform_kwargs,
182 | size=size,
183 | min_scale=min_scale,
184 | max_scale=max_scale,
185 | )
186 | transforms.append(transform)
187 | train_pipeline = MulticropPretrainPipeline(
188 | data_dir / train_dir,
189 | batch_size=self.batch_size,
190 | transforms=transforms,
191 | num_crops=num_crops,
192 | device=dali_device,
193 | device_id=device_id,
194 | shard_id=shard_id,
195 | num_shards=num_shards,
196 | num_threads=num_workers,
197 | no_labels=self.extra_args["no_labels"],
198 | encode_indexes_into_labels=self.encode_indexes_into_labels,
199 | )
200 | output_map = [
201 | *[f"large{i}" for i in range(num_crops[0])],
202 | *[f"small{i}" for i in range(num_crops[1])],
203 | "label",
204 | ]
205 |
206 | else:
207 | if unique_augs > 1:
208 | transform = [
209 | transform_pipeline(
210 | device=dali_device,
211 | **kwargs,
212 | max_scale=1.0,
213 | )
214 | for kwargs in transform_kwargs
215 | ]
216 | else:
217 | transform = transform_pipeline(
218 | device=dali_device,
219 | **transform_kwargs,
220 | max_scale=1.0,
221 | )
222 |
223 | train_pipeline = PretrainPipeline(
224 | data_dir / train_dir,
225 | batch_size=self.batch_size,
226 | transform=transform,
227 | device=dali_device,
228 | device_id=device_id,
229 | shard_id=shard_id,
230 | num_shards=num_shards,
231 | num_threads=num_workers,
232 | no_labels=self.extra_args["no_labels"],
233 | encode_indexes_into_labels=self.encode_indexes_into_labels,
234 | )
235 | output_map = [f"large{i}" for i in range(self.num_crops)] + ["label"]
236 |
237 | policy = LastBatchPolicy.DROP
238 | conversion_map = train_pipeline.conversion_map if self.encode_indexes_into_labels else None
239 | train_loader = PretrainWrapper(
240 | model_batch_size=self.batch_size,
241 | model_rank=device_id,
242 | model_device=self.device,
243 | conversion_map=conversion_map,
244 | pipelines=train_pipeline,
245 | output_map=output_map,
246 | reader_name="Reader",
247 | last_batch_policy=policy,
248 | auto_reset=True,
249 | )
250 |
251 | self.dali_epoch_size = train_pipeline.epoch_size("Reader")
252 |
253 | return train_loader
254 |
255 |
256 | class ClassificationABC(ABC):
257 | """Abstract classification class that returns a train_dataloader and val_dataloader using
258 | dali."""
259 |
260 | def train_dataloader(self) -> DALIGenericIterator:
261 | device_id = self.local_rank
262 | shard_id = self.global_rank
263 | num_shards = self.trainer.world_size
264 |
265 | num_workers = self.extra_args["num_workers"]
266 | dali_device = self.extra_args["dali_device"]
267 | data_dir = Path(self.extra_args["data_dir"])
268 | train_dir = Path(self.extra_args["train_dir"])
269 |
270 | # handle custom data by creating the needed pipeline
271 | dataset = self.extra_args["dataset"]
272 | if dataset in ["imagenet100", "imagenet"]:
273 | pipeline_class = NormalPipeline
274 | elif dataset == "custom":
275 | pipeline_class = CustomNormalPipeline
276 | else:
277 | raise ValueError(dataset, "is not supported, used [imagenet, imagenet100 or custom]")
278 |
279 | train_pipeline = pipeline_class(
280 | data_dir / train_dir,
281 | validation=False,
282 | batch_size=self.batch_size,
283 | device=dali_device,
284 | device_id=device_id,
285 | shard_id=shard_id,
286 | num_shards=num_shards,
287 | num_threads=num_workers,
288 | )
289 | train_loader = Wrapper(
290 | train_pipeline,
291 | output_map=["x", "label"],
292 | reader_name="Reader",
293 | last_batch_policy=LastBatchPolicy.DROP,
294 | auto_reset=True,
295 | )
296 | return train_loader
297 |
298 | def val_dataloader(self) -> DALIGenericIterator:
299 | device_id = self.local_rank
300 | shard_id = self.global_rank
301 | num_shards = self.trainer.world_size
302 |
303 | num_workers = self.extra_args["num_workers"]
304 | dali_device = self.extra_args["dali_device"]
305 | data_dir = Path(self.extra_args["data_dir"])
306 | val_dir = Path(self.extra_args["val_dir"])
307 |
308 | # handle custom data by creating the needed pipeline
309 | dataset = self.extra_args["dataset"]
310 | if dataset in ["imagenet100", "imagenet"]:
311 | pipeline_class = NormalPipeline
312 | elif dataset == "custom":
313 | pipeline_class = CustomNormalPipeline
314 | else:
315 | raise ValueError(dataset, "is not supported, used [imagenet, imagenet100 or custom]")
316 |
317 | val_pipeline = pipeline_class(
318 | data_dir / val_dir,
319 | validation=True,
320 | batch_size=self.batch_size,
321 | device=dali_device,
322 | device_id=device_id,
323 | shard_id=shard_id,
324 | num_shards=num_shards,
325 | num_threads=num_workers,
326 | )
327 |
328 | val_loader = Wrapper(
329 | val_pipeline,
330 | output_map=["x", "label"],
331 | reader_name="Reader",
332 | last_batch_policy=LastBatchPolicy.PARTIAL,
333 | auto_reset=True,
334 | )
335 | return val_loader
336 |
--------------------------------------------------------------------------------
/solo/methods/mocov2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 | from typing import Any, Dict, List, Sequence, Tuple
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from solo.losses.moco import moco_loss_func
27 | from solo.methods.base import BaseMomentumMethod
28 | from solo.utils.momentum import initialize_momentum_params
29 | from solo.utils.misc import gather
30 |
31 |
32 | class MoCoV2(BaseMomentumMethod):
33 | queue: torch.Tensor
34 |
35 | def __init__(
36 | self,
37 | proj_output_dim: int,
38 | proj_hidden_dim: int,
39 | temperature: float,
40 | queue_size: int,
41 | **kwargs
42 | ):
43 | """Implements MoCo.
44 |
45 | Args:
46 | proj_output_dim (int): number of dimensions of projected features.
47 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
48 | temperature (float): temperature for the softmax in the contrastive loss.
49 | queue_size (int): number of samples to keep in the queue.
50 | """
51 |
52 | super().__init__(**kwargs)
53 |
54 | self.temperature = temperature
55 | self.queue_size = queue_size
56 |
57 | # projector
58 | self.projector = nn.Sequential(
59 | nn.Linear(self.features_dim, proj_hidden_dim),
60 | nn.ReLU(),
61 | nn.Linear(proj_hidden_dim, proj_output_dim),
62 | )
63 |
64 | # momentum projector
65 | self.momentum_projector = nn.Sequential(
66 | nn.Linear(self.features_dim, proj_hidden_dim),
67 | nn.ReLU(),
68 | nn.Linear(proj_hidden_dim, proj_output_dim),
69 | )
70 | initialize_momentum_params(self.projector, self.momentum_projector)
71 |
72 | # create the queue
73 | self.register_buffer("queue", torch.randn(2, proj_output_dim, queue_size))
74 | self.queue = nn.functional.normalize(self.queue, dim=1)
75 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
76 |
77 | @staticmethod
78 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
79 | parent_parser = super(MoCoV2, MoCoV2).add_model_specific_args(parent_parser)
80 | parser = parent_parser.add_argument_group("mocov2")
81 |
82 | # projector
83 | parser.add_argument("--proj_output_dim", type=int, default=128)
84 | parser.add_argument("--proj_hidden_dim", type=int, default=2048)
85 |
86 | # parameters
87 | parser.add_argument("--temperature", type=float, default=0.1)
88 |
89 | # queue settings
90 | parser.add_argument("--queue_size", default=65536, type=int)
91 |
92 | return parent_parser
93 |
94 | @property
95 | def learnable_params(self) -> List[dict]:
96 | """Adds projector parameters together with parent's learnable parameters.
97 |
98 | Returns:
99 | List[dict]: list of learnable parameters.
100 | """
101 |
102 | extra_learnable_params = [{"params": self.projector.parameters()}]
103 | return super().learnable_params + extra_learnable_params
104 |
105 | @property
106 | def momentum_pairs(self) -> List[Tuple[Any, Any]]:
107 | """Adds (projector, momentum_projector) to the parent's momentum pairs.
108 |
109 | Returns:
110 | List[Tuple[Any, Any]]: list of momentum pairs.
111 | """
112 |
113 | extra_momentum_pairs = [(self.projector, self.momentum_projector)]
114 | return super().momentum_pairs + extra_momentum_pairs
115 |
116 | @torch.no_grad()
117 | def _dequeue_and_enqueue(self, keys: torch.Tensor):
118 | """Adds new samples and removes old samples from the queue in a fifo manner.
119 |
120 | Args:
121 | keys (torch.Tensor): output features of the momentum encoder.
122 | """
123 |
124 | batch_size = keys.shape[1]
125 | ptr = int(self.queue_ptr) # type: ignore
126 | assert self.queue_size % batch_size == 0 # for simplicity
127 |
128 | # replace the keys at ptr (dequeue and enqueue)
129 | keys = keys.permute(0, 2, 1)
130 | self.queue[:, :, ptr : ptr + batch_size] = keys
131 | ptr = (ptr + batch_size) % self.queue_size # move pointer
132 | self.queue_ptr[0] = ptr # type: ignore
133 |
134 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
135 | """Performs the forward pass of the online encoder and the online projection.
136 |
137 | Args:
138 | X (torch.Tensor): a batch of images in the tensor format.
139 |
140 | Returns:
141 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
142 | """
143 |
144 | out = super().forward(X, *args, **kwargs)
145 | q = F.normalize(self.projector(out["feats"]), dim=-1)
146 | return {**out, "q": q}
147 |
148 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
149 | """
150 | Training step for MoCo reusing BaseMomentumMethod training step.
151 |
152 | Args:
153 | batch (Sequence[Any]): a batch of data in the
154 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops
155 | containing batches of images.
156 | batch_idx (int): index of the batch.
157 |
158 | Returns:
159 | torch.Tensor: total loss composed of MOCO loss and classification loss.
160 |
161 | """
162 |
163 | out = super().training_step(batch, batch_idx)
164 | class_loss = out["loss"]
165 | feats1, _ = out["feats"]
166 | _, momentum_feats2 = out["momentum_feats"]
167 |
168 | q1 = self.projector(feats1)
169 | q1 = F.normalize(q1, dim=-1)
170 |
171 | with torch.no_grad():
172 | k2 = self.momentum_projector(momentum_feats2)
173 | k2 = F.normalize(k2, dim=-1)
174 |
175 | # ------- contrastive loss -------
176 | # symmetric
177 | queue = self.queue.clone().detach()
178 | nce_loss = moco_loss_func(q1, k2, queue[1], self.temperature)
179 |
180 | # ------- update queue -------
181 | keys = torch.stack((torch.zeros_like(gather(k2)), gather(k2)))
182 | self._dequeue_and_enqueue(keys)
183 |
184 | self.log("train_nce_loss", nce_loss, on_epoch=True, sync_dist=True)
185 |
186 | return nce_loss + class_loss
187 |
--------------------------------------------------------------------------------
/solo/methods/mocov2plus.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 | from typing import Any, Dict, List, Sequence, Tuple
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from solo.losses.moco import moco_loss_func
27 | from solo.methods.base import BaseMomentumMethod
28 | from solo.utils.momentum import initialize_momentum_params
29 | from solo.utils.misc import gather
30 |
31 |
32 | class MoCoV2Plus(BaseMomentumMethod):
33 | queue: torch.Tensor
34 |
35 | def __init__(
36 | self,
37 | proj_output_dim: int,
38 | proj_hidden_dim: int,
39 | temperature: float,
40 | queue_size: int,
41 | **kwargs
42 | ):
43 | """Implements MoCo V2+ (https://arxiv.org/abs/2011.10566).
44 |
45 | Args:
46 | proj_output_dim (int): number of dimensions of projected features.
47 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
48 | temperature (float): temperature for the softmax in the contrastive loss.
49 | queue_size (int): number of samples to keep in the queue.
50 | """
51 |
52 | super().__init__(**kwargs)
53 |
54 | self.temperature = temperature
55 | self.queue_size = queue_size
56 |
57 | # projector
58 | self.projector = nn.Sequential(
59 | nn.Linear(self.features_dim, proj_hidden_dim),
60 | nn.ReLU(),
61 | nn.Linear(proj_hidden_dim, proj_output_dim),
62 | )
63 |
64 | # momentum projector
65 | self.momentum_projector = nn.Sequential(
66 | nn.Linear(self.features_dim, proj_hidden_dim),
67 | nn.ReLU(),
68 | nn.Linear(proj_hidden_dim, proj_output_dim),
69 | )
70 | initialize_momentum_params(self.projector, self.momentum_projector)
71 |
72 | # create the queue
73 | self.register_buffer("queue", torch.randn(2, proj_output_dim, queue_size))
74 | self.queue = nn.functional.normalize(self.queue, dim=1)
75 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
76 |
77 | @staticmethod
78 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
79 | parent_parser = super(MoCoV2Plus, MoCoV2Plus).add_model_specific_args(parent_parser)
80 | parser = parent_parser.add_argument_group("mocov2plus")
81 |
82 | # projector
83 | parser.add_argument("--proj_output_dim", type=int, default=128)
84 | parser.add_argument("--proj_hidden_dim", type=int, default=2048)
85 |
86 | # parameters
87 | parser.add_argument("--temperature", type=float, default=0.1)
88 |
89 | # queue settings
90 | parser.add_argument("--queue_size", default=65536, type=int)
91 |
92 | return parent_parser
93 |
94 | @property
95 | def learnable_params(self) -> List[dict]:
96 | """Adds projector parameters together with parent's learnable parameters.
97 |
98 | Returns:
99 | List[dict]: list of learnable parameters.
100 | """
101 |
102 | extra_learnable_params = [{"params": self.projector.parameters()}]
103 | return super().learnable_params + extra_learnable_params
104 |
105 | @property
106 | def momentum_pairs(self) -> List[Tuple[Any, Any]]:
107 | """Adds (projector, momentum_projector) to the parent's momentum pairs.
108 |
109 | Returns:
110 | List[Tuple[Any, Any]]: list of momentum pairs.
111 | """
112 |
113 | extra_momentum_pairs = [(self.projector, self.momentum_projector)]
114 | return super().momentum_pairs + extra_momentum_pairs
115 |
116 | @torch.no_grad()
117 | def _dequeue_and_enqueue(self, keys: torch.Tensor):
118 | """Adds new samples and removes old samples from the queue in a fifo manner.
119 |
120 | Args:
121 | keys (torch.Tensor): output features of the momentum encoder.
122 | """
123 |
124 | batch_size = keys.shape[1]
125 | ptr = int(self.queue_ptr) # type: ignore
126 | assert self.queue_size % batch_size == 0 # for simplicity
127 |
128 | # replace the keys at ptr (dequeue and enqueue)
129 | keys = keys.permute(0, 2, 1)
130 | self.queue[:, :, ptr : ptr + batch_size] = keys
131 | ptr = (ptr + batch_size) % self.queue_size # move pointer
132 | self.queue_ptr[0] = ptr # type: ignore
133 |
134 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
135 | """Performs the forward pass of the online encoder and the online projection.
136 |
137 | Args:
138 | X (torch.Tensor): a batch of images in the tensor format.
139 |
140 | Returns:
141 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
142 | """
143 |
144 | out = super().forward(X, *args, **kwargs)
145 | q = F.normalize(self.projector(out["feats"]), dim=-1)
146 | return {**out, "q": q}
147 |
148 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
149 | """
150 | Training step for MoCo reusing BaseMomentumMethod training step.
151 |
152 | Args:
153 | batch (Sequence[Any]): a batch of data in the
154 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops
155 | containing batches of images.
156 | batch_idx (int): index of the batch.
157 |
158 | Returns:
159 | torch.Tensor: total loss composed of MOCO loss and classification loss.
160 |
161 | """
162 |
163 | out = super().training_step(batch, batch_idx)
164 | class_loss = out["loss"]
165 | feats1, feats2 = out["feats"]
166 | momentum_feats1, momentum_feats2 = out["momentum_feats"]
167 |
168 | q1 = self.projector(feats1)
169 | q2 = self.projector(feats2)
170 | q1 = F.normalize(q1, dim=-1)
171 | q2 = F.normalize(q2, dim=-1)
172 |
173 | with torch.no_grad():
174 | k1 = self.momentum_projector(momentum_feats1)
175 | k2 = self.momentum_projector(momentum_feats2)
176 | k1 = F.normalize(k1, dim=-1)
177 | k2 = F.normalize(k2, dim=-1)
178 |
179 | # ------- contrastive loss -------
180 | # symmetric
181 | queue = self.queue.clone().detach()
182 | nce_loss = (
183 | moco_loss_func(q1, k2, queue[1], self.temperature)
184 | + moco_loss_func(q2, k1, queue[0], self.temperature)
185 | ) / 2
186 |
187 | # ------- update queue -------
188 | keys = torch.stack((gather(k1), gather(k2)))
189 | self._dequeue_and_enqueue(keys)
190 |
191 | self.log("train_nce_loss", nce_loss, on_epoch=True, sync_dist=True)
192 |
193 | return nce_loss + class_loss
194 |
--------------------------------------------------------------------------------
/solo/methods/simco_dual_temperature.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 | from typing import Any, Dict, List, Sequence, Tuple
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from solo.methods.base import BaseMomentumMethod
27 | from solo.utils.momentum import initialize_momentum_params
28 | from solo.losses.dual_temperature_loss import dual_temperature_loss_func
29 |
30 |
31 | class SimCo_DualTemperature(BaseMomentumMethod):
32 | queue: torch.Tensor
33 |
34 | def __init__(
35 | self,
36 | proj_output_dim: int,
37 | proj_hidden_dim: int,
38 | temperature: float,
39 | dt_m: float,
40 | **kwargs
41 | ):
42 | """Implements simco with dual temperature.
43 |
44 | Args:
45 | proj_output_dim (int): number of dimensions of projected features.
46 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
47 | temperature (float): temperature for the softmax in the contrastive loss.
48 | queue_size (int): number of samples to keep in the queue.
49 | """
50 |
51 | super().__init__(**kwargs)
52 |
53 | self.temperature = temperature
54 | self.dt_m = dt_m
55 |
56 | # projector
57 | self.projector = nn.Sequential(
58 | nn.Linear(self.features_dim, proj_hidden_dim),
59 | nn.ReLU(),
60 | nn.Linear(proj_hidden_dim, proj_output_dim),
61 | )
62 |
63 | # momentum projector
64 | self.momentum_projector = nn.Sequential(
65 | nn.Linear(self.features_dim, proj_hidden_dim),
66 | nn.ReLU(),
67 | nn.Linear(proj_hidden_dim, proj_output_dim),
68 | )
69 |
70 | initialize_momentum_params(self.projector, self.momentum_projector)
71 |
72 |
73 | @staticmethod
74 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
75 | parent_parser = super(SimCo_DualTemperature, SimCo_DualTemperature).add_model_specific_args(parent_parser)
76 | parser = parent_parser.add_argument_group("simco_dual_temperature")
77 |
78 | # projector
79 | parser.add_argument("--proj_output_dim", type=int, default=128)
80 | parser.add_argument("--proj_hidden_dim", type=int, default=2048)
81 |
82 | # parameters
83 | parser.add_argument("--temperature", type=float, default=0.1)
84 | parser.add_argument("--dt_m", type=float, default=10)
85 |
86 | return parent_parser
87 |
88 | @property
89 | def learnable_params(self) -> List[dict]:
90 | """Adds projector parameters together with parent's learnable parameters.
91 |
92 | Returns:
93 | List[dict]: list of learnable parameters.
94 | """
95 |
96 | extra_learnable_params = [{"params": self.projector.parameters()}]
97 | return super().learnable_params + extra_learnable_params
98 |
99 | @property
100 | def momentum_pairs(self) -> List[Tuple[Any, Any]]:
101 | """Adds (projector, momentum_projector) to the parent's momentum pairs.
102 |
103 | Returns:
104 | List[Tuple[Any, Any]]: list of momentum pairs.
105 | """
106 |
107 | extra_momentum_pairs = [(self.projector, self.momentum_projector)]
108 | return super().momentum_pairs + extra_momentum_pairs
109 |
110 |
111 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
112 | """Performs the forward pass of the online encoder and the online projection.
113 |
114 | Args:
115 | X (torch.Tensor): a batch of images in the tensor format.
116 |
117 | Returns:
118 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
119 | """
120 |
121 | out = super().forward(X, *args, **kwargs)
122 | q = F.normalize(self.projector(out["feats"]), dim=-1)
123 | return {**out, "q": q}
124 |
125 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
126 | """
127 | Training step for MoCo reusing BaseMomentumMethod training step.
128 |
129 | Args:
130 | batch (Sequence[Any]): a batch of data in the
131 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops
132 | containing batches of images.
133 | batch_idx (int): index of the batch.
134 |
135 | Returns:
136 | torch.Tensor: total loss composed of MOCO loss and classification loss.
137 |
138 | """
139 |
140 | out = super().training_step(batch, batch_idx)
141 | class_loss = out["loss"]
142 | feats1, feats2 = out["feats"]
143 |
144 | q1 = self.projector(feats1)
145 | q2 = self.projector(feats2)
146 |
147 | q1 = F.normalize(q1, dim=-1)
148 | q2 = F.normalize(q2, dim=-1)
149 |
150 |
151 | nce_loss = (
152 | dual_temperature_loss_func(q1, q2,
153 | temperature=self.temperature,
154 | dt_m=self.dt_m)
155 | + dual_temperature_loss_func(q2, q1,
156 | temperature=self.temperature,
157 | dt_m=self.dt_m)
158 | ) / 2
159 |
160 | # calculate std of features
161 | z1_std = F.normalize(q1, dim=-1).std(dim=0).mean()
162 | z2_std = F.normalize(q2, dim=-1).std(dim=0).mean()
163 | z_std = (z1_std + z2_std) / 2
164 |
165 | metrics = {
166 | "train_nce_loss": nce_loss,
167 | "train_z_std": z_std,
168 | }
169 | self.log_dict(metrics, on_epoch=True, sync_dist=True)
170 |
171 | return nce_loss + class_loss
172 |
173 |
--------------------------------------------------------------------------------
/solo/methods/simmoco_dual_temperature.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import argparse
21 | from typing import Any, Dict, List, Sequence, Tuple
22 |
23 | import torch
24 | import torch.nn as nn
25 | import torch.nn.functional as F
26 | from solo.methods.base import BaseMomentumMethod
27 | from solo.utils.momentum import initialize_momentum_params
28 | from solo.losses.dual_temperature_loss import dual_temperature_loss_func
29 |
30 |
31 | class SimMoCo_DualTemperature(BaseMomentumMethod):
32 | queue: torch.Tensor
33 |
34 | def __init__(
35 | self,
36 | proj_output_dim: int,
37 | proj_hidden_dim: int,
38 | temperature: float,
39 | dt_m: float,
40 | plus_version: bool,
41 | **kwargs
42 | ):
43 | """Implements simmoco with dual temperature.
44 |
45 | Args:
46 | proj_output_dim (int): number of dimensions of projected features.
47 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector.
48 | temperature (float): temperature for the softmax in the contrastive loss.
49 | queue_size (int): number of samples to keep in the queue.
50 | """
51 |
52 | super().__init__(**kwargs)
53 |
54 | self.temperature = temperature
55 | self.dt_m = dt_m
56 | self.plus_version = plus_version
57 |
58 |
59 | # projector
60 | self.projector = nn.Sequential(
61 | nn.Linear(self.features_dim, proj_hidden_dim),
62 | nn.ReLU(),
63 | nn.Linear(proj_hidden_dim, proj_output_dim),
64 | )
65 |
66 | # momentum projector
67 | self.momentum_projector = nn.Sequential(
68 | nn.Linear(self.features_dim, proj_hidden_dim),
69 | nn.ReLU(),
70 | nn.Linear(proj_hidden_dim, proj_output_dim),
71 | )
72 |
73 | initialize_momentum_params(self.projector, self.momentum_projector)
74 |
75 |
76 | @staticmethod
77 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
78 | parent_parser = super(SimMoCo_DualTemperature, SimMoCo_DualTemperature).add_model_specific_args(parent_parser)
79 | parser = parent_parser.add_argument_group("simmoco_dual_temperature")
80 |
81 | # projector
82 | parser.add_argument("--proj_output_dim", type=int, default=128)
83 | parser.add_argument("--proj_hidden_dim", type=int, default=2048)
84 |
85 | # parameters
86 | parser.add_argument("--temperature", type=float, default=0.1)
87 | parser.add_argument("--dt_m", type=float, default=10)
88 |
89 | # train the plus version which uses symmetric loss
90 | parser.add_argument("--plus_version", action="store_true")
91 |
92 | return parent_parser
93 |
94 | @property
95 | def learnable_params(self) -> List[dict]:
96 | """Adds projector parameters together with parent's learnable parameters.
97 |
98 | Returns:
99 | List[dict]: list of learnable parameters.
100 | """
101 |
102 | extra_learnable_params = [{"params": self.projector.parameters()}]
103 | return super().learnable_params + extra_learnable_params
104 |
105 | @property
106 | def momentum_pairs(self) -> List[Tuple[Any, Any]]:
107 | """Adds (projector, momentum_projector) to the parent's momentum pairs.
108 |
109 | Returns:
110 | List[Tuple[Any, Any]]: list of momentum pairs.
111 | """
112 |
113 | extra_momentum_pairs = [(self.projector, self.momentum_projector)]
114 | return super().momentum_pairs + extra_momentum_pairs
115 |
116 |
117 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]:
118 | """Performs the forward pass of the online encoder and the online projection.
119 |
120 | Args:
121 | X (torch.Tensor): a batch of images in the tensor format.
122 |
123 | Returns:
124 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
125 | """
126 |
127 | out = super().forward(X, *args, **kwargs)
128 | q = F.normalize(self.projector(out["feats"]), dim=-1)
129 | return {**out, "q": q}
130 |
131 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor:
132 | """
133 | Training step for MoCo reusing BaseMomentumMethod training step.
134 |
135 | Args:
136 | batch (Sequence[Any]): a batch of data in the
137 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops
138 | containing batches of images.
139 | batch_idx (int): index of the batch.
140 |
141 | Returns:
142 | torch.Tensor: total loss composed of MOCO loss and classification loss.
143 |
144 | """
145 |
146 | if self.plus_version:
147 | out = super().training_step(batch, batch_idx)
148 | class_loss = out["loss"]
149 | feats1, feats2 = out["feats"]
150 | momentum_feats1, momentum_feats2 = out["momentum_feats"]
151 |
152 | q1 = self.projector(feats1)
153 | q2 = self.projector(feats2)
154 |
155 | q1 = F.normalize(q1, dim=-1)
156 | q2 = F.normalize(q2, dim=-1)
157 | with torch.no_grad():
158 | k1 = self.momentum_projector(momentum_feats1)
159 | k2 = self.momentum_projector(momentum_feats2)
160 | k1 = F.normalize(k1, dim=-1).detach()
161 | k2 = F.normalize(k2, dim=-1).detach()
162 |
163 |
164 | nce_loss = (
165 | dual_temperature_loss_func(q1, k2,
166 | temperature=self.temperature,
167 | dt_m=self.dt_m)
168 | + dual_temperature_loss_func(q2, k1,
169 | temperature=self.temperature,
170 | dt_m=self.dt_m)
171 | ) / 2
172 |
173 | # calculate std of features
174 | z1_std = F.normalize(q1, dim=-1).std(dim=0).mean()
175 | z2_std = F.normalize(q2, dim=-1).std(dim=0).mean()
176 | z_std = (z1_std + z2_std) / 2
177 |
178 | metrics = {
179 | "train_nce_loss": nce_loss,
180 | "train_z_std": z_std,
181 | }
182 | self.log_dict(metrics, on_epoch=True, sync_dist=True)
183 |
184 | return nce_loss + class_loss
185 |
186 | else:
187 | out = super().training_step(batch, batch_idx)
188 | class_loss = out["loss"]
189 | feats1, _ = out["feats"]
190 | _, momentum_feats2 = out["momentum_feats"]
191 |
192 | q1 = self.projector(feats1)
193 |
194 | q1 = F.normalize(q1, dim=-1)
195 |
196 | with torch.no_grad():
197 | k2 = self.momentum_projector(momentum_feats2)
198 | k2 = F.normalize(k2, dim=-1).detach()
199 |
200 | nce_loss = dual_temperature_loss_func(q1, k2,
201 | temperature=self.temperature,
202 | dt_m=self.dt_m)
203 |
204 | # calculate std of features
205 | z1_std = F.normalize(q1, dim=-1).std(dim=0).mean()
206 | z_std = z1_std
207 |
208 | metrics = {
209 | "train_nce_loss": nce_loss,
210 | "train_z_std": z_std,
211 | }
212 | self.log_dict(metrics, on_epoch=True, sync_dist=True)
213 |
214 | return nce_loss + class_loss
215 |
--------------------------------------------------------------------------------
/solo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from solo.utils import (
21 | backbones,
22 | checkpointer,
23 | classification_dataloader,
24 | knn,
25 | lars,
26 | metrics,
27 | misc,
28 | momentum,
29 | pretrain_dataloader,
30 | sinkhorn_knopp,
31 | )
32 |
33 | __all__ = [
34 | "backbones",
35 | "classification_dataloader",
36 | "pretrain_dataloader",
37 | "checkpointer",
38 | "knn",
39 | "misc",
40 | "lars",
41 | "metrics",
42 | "momentum",
43 | "sinkhorn_knopp",
44 | ]
45 |
46 | try:
47 | from solo.utils import dali_dataloader # noqa: F401
48 | except ImportError:
49 | pass
50 | else:
51 | __all__.append("dali_dataloader")
52 |
53 | try:
54 | from solo.utils import auto_umap # noqa: F401
55 | except ImportError:
56 | pass
57 | else:
58 | __all__.append("auto_umap")
59 |
--------------------------------------------------------------------------------
/solo/utils/auto_umap.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 | import os
22 | from argparse import ArgumentParser, Namespace
23 | from pathlib import Path
24 | from typing import Optional, Union
25 |
26 | import pandas as pd
27 | import pytorch_lightning as pl
28 | import seaborn as sns
29 | import torch
30 | import umap
31 | import wandb
32 | from matplotlib import pyplot as plt
33 | from pytorch_lightning.callbacks import Callback
34 |
35 | from .misc import gather
36 |
37 |
38 | class AutoUMAP(Callback):
39 | def __init__(
40 | self,
41 | args: Namespace,
42 | logdir: Union[str, Path] = Path("auto_umap"),
43 | frequency: int = 1,
44 | keep_previous: bool = False,
45 | color_palette: str = "hls",
46 | ):
47 | """UMAP callback that automatically runs UMAP on the validation dataset and uploads the
48 | figure to wandb.
49 |
50 | Args:
51 | args (Namespace): namespace object containing at least an attribute name.
52 | logdir (Union[str, Path], optional): base directory to store checkpoints.
53 | Defaults to Path("auto_umap").
54 | frequency (int, optional): number of epochs between each UMAP. Defaults to 1.
55 | color_palette (str, optional): color scheme for the classes. Defaults to "hls".
56 | keep_previous (bool, optional): whether to keep previous plots or not.
57 | Defaults to False.
58 | """
59 |
60 | super().__init__()
61 |
62 | self.args = args
63 | self.logdir = Path(logdir)
64 | self.frequency = frequency
65 | self.color_palette = color_palette
66 | self.keep_previous = keep_previous
67 |
68 | @staticmethod
69 | def add_auto_umap_args(parent_parser: ArgumentParser):
70 | """Adds user-required arguments to a parser.
71 |
72 | Args:
73 | parent_parser (ArgumentParser): parser to add new args to.
74 | """
75 |
76 | parser = parent_parser.add_argument_group("auto_umap")
77 | parser.add_argument("--auto_umap_dir", default=Path("auto_umap"), type=Path)
78 | parser.add_argument("--auto_umap_frequency", default=1, type=int)
79 | return parent_parser
80 |
81 | def initial_setup(self, trainer: pl.Trainer):
82 | """Creates the directories and does the initial setup needed.
83 |
84 | Args:
85 | trainer (pl.Trainer): pytorch lightning trainer object.
86 | """
87 |
88 | if trainer.logger is None:
89 | version = None
90 | else:
91 | version = str(trainer.logger.version)
92 | if version is not None:
93 | self.path = self.logdir / version
94 | self.umap_placeholder = f"{self.args.name}-{version}" + "-ep={}.pdf"
95 | else:
96 | self.path = self.logdir
97 | self.umap_placeholder = f"{self.args.name}" + "-ep={}.pdf"
98 | self.last_ckpt: Optional[str] = None
99 |
100 | # create logging dirs
101 | if trainer.is_global_zero:
102 | os.makedirs(self.path, exist_ok=True)
103 |
104 | def on_train_start(self, trainer: pl.Trainer, _):
105 | """Performs initial setup on training start.
106 |
107 | Args:
108 | trainer (pl.Trainer): pytorch lightning trainer object.
109 | """
110 |
111 | self.initial_setup(trainer)
112 |
113 | def plot(self, trainer: pl.Trainer, module: pl.LightningModule):
114 | """Produces a UMAP visualization by forwarding all data of the
115 | first validation dataloader through the module.
116 |
117 | Args:
118 | trainer (pl.Trainer): pytorch lightning trainer object.
119 | module (pl.LightningModule): current module object.
120 | """
121 |
122 | device = module.device
123 | data = []
124 | Y = []
125 |
126 | # set module to eval model and collect all feature representations
127 | module.eval()
128 | with torch.no_grad():
129 | for x, y in trainer.val_dataloaders[0]:
130 | x = x.to(device, non_blocking=True)
131 | y = y.to(device, non_blocking=True)
132 |
133 | feats = module(x)["feats"]
134 |
135 | feats = gather(feats)
136 | y = gather(y)
137 |
138 | data.append(feats.cpu())
139 | Y.append(y.cpu())
140 | module.train()
141 |
142 | if trainer.is_global_zero and len(data):
143 | data = torch.cat(data, dim=0).numpy()
144 | Y = torch.cat(Y, dim=0)
145 | num_classes = len(torch.unique(Y))
146 | Y = Y.numpy()
147 |
148 | data = umap.UMAP(n_components=2).fit_transform(data)
149 |
150 | # passing to dataframe
151 | df = pd.DataFrame()
152 | df["feat_1"] = data[:, 0]
153 | df["feat_2"] = data[:, 1]
154 | df["Y"] = Y
155 | plt.figure(figsize=(9, 9))
156 | ax = sns.scatterplot(
157 | x="feat_1",
158 | y="feat_2",
159 | hue="Y",
160 | palette=sns.color_palette(self.color_palette, num_classes),
161 | data=df,
162 | legend="full",
163 | alpha=0.3,
164 | )
165 | ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[])
166 | ax.tick_params(left=False, right=False, bottom=False, top=False)
167 |
168 | # manually improve quality of imagenet umaps
169 | if num_classes > 100:
170 | anchor = (0.5, 1.8)
171 | else:
172 | anchor = (0.5, 1.35)
173 |
174 | plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10))
175 | plt.tight_layout()
176 |
177 | if isinstance(trainer.logger, pl.loggers.WandbLogger):
178 | wandb.log(
179 | {"validation_umap": wandb.Image(ax)},
180 | commit=False,
181 | )
182 |
183 | # save plot locally as well
184 | epoch = trainer.current_epoch # type: ignore
185 | plt.savefig(self.path / self.umap_placeholder.format(epoch))
186 | plt.close()
187 |
188 | def on_validation_end(self, trainer: pl.Trainer, module: pl.LightningModule):
189 | """Tries to generate an up-to-date UMAP visualization of the features
190 | at the end of each validation epoch.
191 |
192 | Args:
193 | trainer (pl.Trainer): pytorch lightning trainer object.
194 | """
195 |
196 | epoch = trainer.current_epoch # type: ignore
197 | if epoch % self.frequency == 0 and not trainer.sanity_checking:
198 | self.plot(trainer, module)
199 |
--------------------------------------------------------------------------------
/solo/utils/backbones.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | # Copy-pasted from timm (https://github.com/rwightman/pytorch-image-models/blob/master/timm/),
21 | # but allowing different window sizes.
22 |
23 |
24 | from timm.models.swin_transformer import _create_swin_transformer, register_model
25 | from timm.models.vision_transformer import _create_vision_transformer
26 |
27 |
28 | @register_model
29 | def swin_tiny(window_size=7, **kwargs):
30 | model_kwargs = dict(
31 | patch_size=4,
32 | window_size=window_size,
33 | embed_dim=96,
34 | depths=(2, 2, 6, 2),
35 | num_heads=(3, 6, 12, 24),
36 | num_classes=0,
37 | **kwargs,
38 | )
39 | return _create_swin_transformer("swin_tiny_patch4_window7_224", **model_kwargs)
40 |
41 |
42 | @register_model
43 | def swin_small(window_size=7, **kwargs):
44 | model_kwargs = dict(
45 | patch_size=4,
46 | window_size=window_size,
47 | embed_dim=96,
48 | depths=(2, 2, 18, 2),
49 | num_heads=(3, 6, 12, 24),
50 | num_classes=0,
51 | **kwargs,
52 | )
53 | return _create_swin_transformer(
54 | "swin_small_patch4_window7_224", pretrained=False, **model_kwargs
55 | )
56 |
57 |
58 | @register_model
59 | def swin_base(window_size=7, **kwargs):
60 | model_kwargs = dict(
61 | patch_size=4,
62 | window_size=window_size,
63 | embed_dim=128,
64 | depths=(2, 2, 18, 2),
65 | num_heads=(4, 8, 16, 32),
66 | num_classes=0,
67 | **kwargs,
68 | )
69 | return _create_swin_transformer(
70 | "swin_base_patch4_window7_224", pretrained=False, **model_kwargs
71 | )
72 |
73 |
74 | @register_model
75 | def swin_large(window_size=7, **kwargs):
76 | model_kwargs = dict(
77 | patch_size=4,
78 | window_size=window_size,
79 | embed_dim=192,
80 | depths=(2, 2, 18, 2),
81 | num_heads=(6, 12, 24, 48),
82 | num_classes=0,
83 | **kwargs,
84 | )
85 | return _create_swin_transformer(
86 | "swin_large_patch4_window7_224", pretrained=False, **model_kwargs
87 | )
88 |
89 |
90 | @register_model
91 | def vit_tiny(patch_size=16, **kwargs):
92 | """ViT-Tiny (Vit-Ti/16)"""
93 | model_kwargs = dict(
94 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, num_classes=0, **kwargs
95 | )
96 | model = _create_vision_transformer("vit_tiny_patch16_224", pretrained=False, **model_kwargs)
97 | return model
98 |
99 |
100 | @register_model
101 | def vit_small(patch_size=16, **kwargs):
102 | model_kwargs = dict(
103 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, num_classes=0, **kwargs
104 | )
105 | model = _create_vision_transformer("vit_small_patch16_224", pretrained=False, **model_kwargs)
106 | return model
107 |
108 |
109 | @register_model
110 | def vit_base(patch_size=16, **kwargs):
111 | model_kwargs = dict(
112 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, num_classes=0, **kwargs
113 | )
114 | model = _create_vision_transformer("vit_base_patch16_224", pretrained=False, **model_kwargs)
115 | return model
116 |
117 |
118 | @register_model
119 | def vit_large(patch_size=16, **kwargs):
120 | model_kwargs = dict(
121 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, num_classes=0, **kwargs
122 | )
123 | model = _create_vision_transformer("vit_large_patch16_224", pretrained=False, **model_kwargs)
124 | return model
125 |
--------------------------------------------------------------------------------
/solo/utils/checkpointer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import json
21 | import os
22 | from argparse import ArgumentParser, Namespace
23 | from pathlib import Path
24 | from typing import Optional, Union
25 |
26 | import pytorch_lightning as pl
27 | from pytorch_lightning.callbacks import Callback
28 |
29 |
30 | class Checkpointer(Callback):
31 | def __init__(
32 | self,
33 | args: Namespace,
34 | logdir: Union[str, Path] = Path("trained_models"),
35 | frequency: int = 1,
36 | keep_previous_checkpoints: bool = False,
37 | ):
38 | """Custom checkpointer callback that stores checkpoints in an easier to access way.
39 |
40 | Args:
41 | args (Namespace): namespace object containing at least an attribute name.
42 | logdir (Union[str, Path], optional): base directory to store checkpoints.
43 | Defaults to "trained_models".
44 | frequency (int, optional): number of epochs between each checkpoint. Defaults to 1.
45 | keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not.
46 | Defaults to False.
47 | """
48 |
49 | super().__init__()
50 |
51 | self.args = args
52 | self.logdir = Path(logdir)
53 | self.frequency = frequency
54 | self.keep_previous_checkpoints = keep_previous_checkpoints
55 |
56 | @staticmethod
57 | def add_checkpointer_args(parent_parser: ArgumentParser):
58 | """Adds user-required arguments to a parser.
59 |
60 | Args:
61 | parent_parser (ArgumentParser): parser to add new args to.
62 | """
63 |
64 | parser = parent_parser.add_argument_group("checkpointer")
65 | parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path)
66 | parser.add_argument("--checkpoint_frequency", default=1, type=int)
67 | return parent_parser
68 |
69 | def initial_setup(self, trainer: pl.Trainer):
70 | """Creates the directories and does the initial setup needed.
71 |
72 | Args:
73 | trainer (pl.Trainer): pytorch lightning trainer object.
74 | """
75 |
76 | if trainer.logger is None:
77 | version = None
78 | else:
79 | version = str(trainer.logger.version)
80 | if version is not None:
81 | self.path = self.logdir / version
82 | self.ckpt_placeholder = f"{self.args.name}-{version}" + "-ep={}.ckpt"
83 | else:
84 | self.path = self.logdir
85 | self.ckpt_placeholder = f"{self.args.name}" + "-ep={}.ckpt"
86 | self.last_ckpt: Optional[str] = None
87 |
88 | # create logging dirs
89 | if trainer.is_global_zero:
90 | os.makedirs(self.path, exist_ok=True)
91 |
92 | def save_args(self, trainer: pl.Trainer):
93 | """Stores arguments into a json file.
94 |
95 | Args:
96 | trainer (pl.Trainer): pytorch lightning trainer object.
97 | """
98 |
99 | if trainer.is_global_zero:
100 | args = vars(self.args)
101 | json_path = self.path / "args.json"
102 | json.dump(args, open(json_path, "w"), default=lambda o: "")
103 |
104 | def save(self, trainer: pl.Trainer):
105 | """Saves current checkpoint.
106 |
107 | Args:
108 | trainer (pl.Trainer): pytorch lightning trainer object.
109 | """
110 |
111 | if trainer.is_global_zero and not trainer.sanity_checking:
112 | epoch = trainer.current_epoch # type: ignore
113 | ckpt = self.path / self.ckpt_placeholder.format(epoch)
114 | trainer.save_checkpoint(ckpt)
115 |
116 | if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints:
117 | os.remove(self.last_ckpt)
118 | self.last_ckpt = ckpt
119 |
120 | def on_train_start(self, trainer: pl.Trainer, _):
121 | """Executes initial setup and saves arguments.
122 |
123 | Args:
124 | trainer (pl.Trainer): pytorch lightning trainer object.
125 | """
126 |
127 | self.initial_setup(trainer)
128 | self.save_args(trainer)
129 |
130 | def on_validation_end(self, trainer: pl.Trainer, _):
131 | """Tries to save current checkpoint at the end of each validation epoch.
132 |
133 | Args:
134 | trainer (pl.Trainer): pytorch lightning trainer object.
135 | """
136 |
137 | epoch = trainer.current_epoch # type: ignore
138 | if epoch % self.frequency == 0:
139 | self.save(trainer)
140 |
--------------------------------------------------------------------------------
/solo/utils/classification_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from pathlib import Path
22 | from typing import Callable, Optional, Tuple, Union
23 |
24 | import torchvision
25 | from torch import nn
26 | from torch.utils.data import DataLoader, Dataset
27 | from torchvision import transforms
28 | from torchvision.datasets import STL10, ImageFolder
29 |
30 |
31 | def build_custom_pipeline():
32 | """Builds augmentation pipelines for custom data.
33 | If you want to do exoteric augmentations, you can just re-write this function.
34 | Needs to return a dict with the same structure.
35 | """
36 |
37 | pipeline = {
38 | "T_train": transforms.Compose(
39 | [
40 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
41 | transforms.RandomHorizontalFlip(),
42 | transforms.ToTensor(),
43 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
44 | ]
45 | ),
46 | "T_val": transforms.Compose(
47 | [
48 | transforms.Resize(256), # resize shorter
49 | transforms.CenterCrop(224), # take center crop
50 | transforms.ToTensor(),
51 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
52 | ]
53 | ),
54 | }
55 | return pipeline
56 |
57 |
58 | def prepare_transforms(dataset: str) -> Tuple[nn.Module, nn.Module]:
59 | """Prepares pre-defined train and test transformation pipelines for some datasets.
60 |
61 | Args:
62 | dataset (str): dataset name.
63 |
64 | Returns:
65 | Tuple[nn.Module, nn.Module]: training and validation transformation pipelines.
66 | """
67 |
68 | cifar_pipeline = {
69 | "T_train": transforms.Compose(
70 | [
71 | transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0)),
72 | transforms.RandomHorizontalFlip(),
73 | transforms.ToTensor(),
74 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
75 | ]
76 | ),
77 | "T_val": transforms.Compose(
78 | [
79 | transforms.ToTensor(),
80 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
81 | ]
82 | ),
83 | }
84 |
85 | stl_pipeline = {
86 | "T_train": transforms.Compose(
87 | [
88 | transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0)),
89 | transforms.RandomHorizontalFlip(),
90 | transforms.ToTensor(),
91 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
92 | ]
93 | ),
94 | "T_val": transforms.Compose(
95 | [
96 | transforms.Resize((96, 96)),
97 | transforms.ToTensor(),
98 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
99 | ]
100 | ),
101 | }
102 |
103 | imagenet_pipeline = {
104 | "T_train": transforms.Compose(
105 | [
106 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)),
107 | transforms.RandomHorizontalFlip(),
108 | transforms.ToTensor(),
109 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
110 | ]
111 | ),
112 | "T_val": transforms.Compose(
113 | [
114 | transforms.Resize(256), # resize shorter
115 | transforms.CenterCrop(224), # take center crop
116 | transforms.ToTensor(),
117 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
118 | ]
119 | ),
120 | }
121 |
122 | custom_pipeline = build_custom_pipeline()
123 |
124 | pipelines = {
125 | "cifar10": cifar_pipeline,
126 | "cifar100": cifar_pipeline,
127 | "stl10": stl_pipeline,
128 | "imagenet100": imagenet_pipeline,
129 | "imagenet": imagenet_pipeline,
130 | "custom": custom_pipeline,
131 | }
132 |
133 | assert dataset in pipelines
134 |
135 | pipeline = pipelines[dataset]
136 | T_train = pipeline["T_train"]
137 | T_val = pipeline["T_val"]
138 |
139 | return T_train, T_val
140 |
141 |
142 | def prepare_datasets(
143 | dataset: str,
144 | T_train: Callable,
145 | T_val: Callable,
146 | data_dir: Optional[Union[str, Path]] = None,
147 | train_dir: Optional[Union[str, Path]] = None,
148 | val_dir: Optional[Union[str, Path]] = None,
149 | ) -> Tuple[Dataset, Dataset]:
150 | """Prepares train and val datasets.
151 |
152 | Args:
153 | dataset (str): dataset name.
154 | T_train (Callable): pipeline of transformations for training dataset.
155 | T_val (Callable): pipeline of transformations for validation dataset.
156 | data_dir Optional[Union[str, Path]]: path where to download/locate the dataset.
157 | train_dir Optional[Union[str, Path]]: subpath where the training data is located.
158 | val_dir Optional[Union[str, Path]]: subpath where the validation data is located.
159 |
160 | Returns:
161 | Tuple[Dataset, Dataset]: training dataset and validation dataset.
162 | """
163 |
164 | if data_dir is None:
165 | sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
166 | data_dir = sandbox_dir / "datasets"
167 | else:
168 | data_dir = Path(data_dir)
169 |
170 | if train_dir is None:
171 | train_dir = Path(f"{dataset}/train")
172 | else:
173 | train_dir = Path(train_dir)
174 |
175 | if val_dir is None:
176 | val_dir = Path(f"{dataset}/val")
177 | else:
178 | val_dir = Path(val_dir)
179 |
180 | assert dataset in ["cifar10", "cifar100", "stl10", "imagenet", "imagenet100", "custom"]
181 |
182 | if dataset in ["cifar10", "cifar100"]:
183 | DatasetClass = vars(torchvision.datasets)[dataset.upper()]
184 | train_dataset = DatasetClass(
185 | data_dir / train_dir,
186 | train=True,
187 | download=True,
188 | transform=T_train,
189 | )
190 |
191 | val_dataset = DatasetClass(
192 | data_dir / val_dir,
193 | train=False,
194 | download=True,
195 | transform=T_val,
196 | )
197 |
198 | elif dataset == "stl10":
199 | train_dataset = STL10(
200 | data_dir / train_dir,
201 | split="train",
202 | download=True,
203 | transform=T_train,
204 | )
205 | val_dataset = STL10(
206 | data_dir / val_dir,
207 | split="test",
208 | download=True,
209 | transform=T_val,
210 | )
211 |
212 | elif dataset in ["imagenet", "imagenet100", "custom"]:
213 | train_dir = data_dir / train_dir
214 | val_dir = data_dir / val_dir
215 |
216 | train_dataset = ImageFolder(train_dir, T_train)
217 | val_dataset = ImageFolder(val_dir, T_val)
218 |
219 | return train_dataset, val_dataset
220 |
221 |
222 | def prepare_dataloaders(
223 | train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
224 | ) -> Tuple[DataLoader, DataLoader]:
225 | """Wraps a train and a validation dataset with a DataLoader.
226 |
227 | Args:
228 | train_dataset (Dataset): object containing training data.
229 | val_dataset (Dataset): object containing validation data.
230 | batch_size (int): batch size.
231 | num_workers (int): number of parallel workers.
232 | Returns:
233 | Tuple[DataLoader, DataLoader]: training dataloader and validation dataloader.
234 | """
235 |
236 | train_loader = DataLoader(
237 | train_dataset,
238 | batch_size=batch_size,
239 | shuffle=True,
240 | num_workers=num_workers,
241 | pin_memory=True,
242 | drop_last=True,
243 | )
244 | val_loader = DataLoader(
245 | val_dataset,
246 | batch_size=batch_size,
247 | num_workers=num_workers,
248 | pin_memory=True,
249 | drop_last=False,
250 | )
251 | return train_loader, val_loader
252 |
253 |
254 | def prepare_data(
255 | dataset: str,
256 | data_dir: Optional[Union[str, Path]] = None,
257 | train_dir: Optional[Union[str, Path]] = None,
258 | val_dir: Optional[Union[str, Path]] = None,
259 | batch_size: int = 64,
260 | num_workers: int = 4,
261 | ) -> Tuple[DataLoader, DataLoader]:
262 | """Prepares transformations, creates dataset objects and wraps them in dataloaders.
263 |
264 | Args:
265 | dataset (str): dataset name.
266 | data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset.
267 | Defaults to None.
268 | train_dir (Optional[Union[str, Path]], optional): subpath where the
269 | training data is located. Defaults to None.
270 | val_dir (Optional[Union[str, Path]], optional): subpath where the
271 | validation data is located. Defaults to None.
272 | batch_size (int, optional): batch size. Defaults to 64.
273 | num_workers (int, optional): number of parallel workers. Defaults to 4.
274 |
275 | Returns:
276 | Tuple[DataLoader, DataLoader]: prepared training and validation dataloader;.
277 | """
278 |
279 | T_train, T_val = prepare_transforms(dataset)
280 | train_dataset, val_dataset = prepare_datasets(
281 | dataset,
282 | T_train,
283 | T_val,
284 | data_dir=data_dir,
285 | train_dir=train_dir,
286 | val_dir=val_dir,
287 | )
288 | train_loader, val_loader = prepare_dataloaders(
289 | train_dataset,
290 | val_dataset,
291 | batch_size=batch_size,
292 | num_workers=num_workers,
293 | )
294 | return train_loader, val_loader
295 |
--------------------------------------------------------------------------------
/solo/utils/dali_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | from pathlib import Path
22 | from typing import Callable, Iterable, List, Sequence, Union
23 |
24 | import nvidia.dali.fn as fn
25 | import nvidia.dali.ops as ops
26 | import nvidia.dali.types as types
27 | from nvidia.dali.pipeline import Pipeline
28 |
29 |
30 | class Mux:
31 | def __init__(self, prob: float):
32 | """Implements mutex operation for dali in order to support probabilitic augmentations.
33 |
34 | Args:
35 | prob (float): probability value
36 | """
37 |
38 | self.to_bool = ops.Cast(dtype=types.DALIDataType.BOOL)
39 | self.rng = ops.random.CoinFlip(probability=prob)
40 |
41 | def __call__(self, true_case, false_case):
42 | condition = self.to_bool(self.rng())
43 | neg_condition = condition ^ True
44 | return condition * true_case + neg_condition * false_case
45 |
46 |
47 | class RandomGrayScaleConversion:
48 | def __init__(self, prob: float = 0.2, device: str = "gpu"):
49 | """Converts image to greyscale with probability.
50 |
51 | Args:
52 | prob (float, optional): probability of conversion. Defaults to 0.2.
53 | device (str, optional): device on which the operation will be performed.
54 | Defaults to "gpu".
55 | """
56 |
57 | self.mux = Mux(prob=prob)
58 | self.grayscale = ops.ColorSpaceConversion(
59 | device=device, image_type=types.RGB, output_type=types.GRAY
60 | )
61 |
62 | def __call__(self, images):
63 | out = self.grayscale(images)
64 | out = fn.cat(out, out, out, axis=2)
65 | return self.mux(true_case=out, false_case=images)
66 |
67 |
68 | class RandomColorJitter:
69 | def __init__(
70 | self,
71 | brightness: float,
72 | contrast: float,
73 | saturation: float,
74 | hue: float,
75 | prob: float = 0.8,
76 | device: str = "gpu",
77 | ):
78 | """Applies random color jittering with probability.
79 |
80 | Args:
81 | brightness (float): brightness value for samplying uniformly
82 | in [max(0, 1 - brightness), 1 + brightness].
83 | contrast (float): contrast value for samplying uniformly
84 | in [max(0, 1 - contrast), 1 + contrast].
85 | saturation (float): saturation value for samplying uniformly
86 | in [max(0, 1 - saturation), 1 + saturation].
87 | hue (float): hue value for samplying uniformly in [-hue, hue].
88 | prob (float, optional): probability of applying jitter. Defaults to 0.8.
89 | device (str, optional): device on which the operation will be performed.
90 | Defaults to "gpu".
91 | """
92 |
93 | assert 0 <= hue <= 0.5
94 |
95 | self.mux = Mux(prob=prob)
96 |
97 | self.color = ops.ColorTwist(device=device)
98 |
99 | # look at torchvision docs to see how colorjitter samples stuff
100 | # for bright, cont and sat, it samples from [1-v, 1+v]
101 | # for hue, it samples from [-hue, hue]
102 |
103 | self.brightness = 1
104 | self.contrast = 1
105 | self.saturation = 1
106 | self.hue = 0
107 |
108 | if brightness:
109 | self.brightness = ops.random.Uniform(range=[max(0, 1 - brightness), 1 + brightness])
110 |
111 | if contrast:
112 | self.contrast = ops.random.Uniform(range=[max(0, 1 - contrast), 1 + contrast])
113 |
114 | if saturation:
115 | self.saturation = ops.random.Uniform(range=[max(0, 1 - saturation), 1 + saturation])
116 |
117 | if hue:
118 | # dali uses hue in degrees for some reason...
119 | hue = 360 * hue
120 | self.hue = ops.random.Uniform(range=[-hue, hue])
121 |
122 | def __call__(self, images):
123 | out = self.color(
124 | images,
125 | brightness=self.brightness() if callable(self.brightness) else self.brightness,
126 | contrast=self.contrast() if callable(self.contrast) else self.contrast,
127 | saturation=self.saturation() if callable(self.saturation) else self.saturation,
128 | hue=self.hue() if callable(self.hue) else self.hue,
129 | )
130 | return self.mux(true_case=out, false_case=images)
131 |
132 |
133 | class RandomGaussianBlur:
134 | def __init__(self, prob: float = 0.5, window_size: int = 23, device: str = "gpu"):
135 | """Applies random gaussian blur with probability.
136 |
137 | Args:
138 | prob (float, optional): probability of applying random gaussian blur. Defaults to 0.5.
139 | window_size (int, optional): window size for gaussian blur. Defaults to 23.
140 | device (str, optional): device on which the operation will be performe.
141 | Defaults to "gpu".
142 | """
143 |
144 | self.mux = Mux(prob=prob)
145 | # gaussian blur
146 | self.gaussian_blur = ops.GaussianBlur(device=device, window_size=(window_size, window_size))
147 | self.sigma = ops.random.Uniform(range=[0, 1])
148 |
149 | def __call__(self, images):
150 | sigma = self.sigma() * 1.9 + 0.1
151 | out = self.gaussian_blur(images, sigma=sigma)
152 | return self.mux(true_case=out, false_case=images)
153 |
154 |
155 | class RandomSolarize:
156 | def __init__(self, threshold: int = 128, prob: float = 0.0):
157 | """Applies random solarization with probability.
158 |
159 | Args:
160 | threshold (int, optional): threshold for inversion. Defaults to 128.
161 | prob (float, optional): probability of solarization. Defaults to 0.0.
162 | """
163 |
164 | self.mux = Mux(prob=prob)
165 |
166 | self.threshold = threshold
167 |
168 | def __call__(self, images):
169 | inverted_img = 255 - images
170 | mask = images >= self.threshold
171 | out = mask * inverted_img + (True ^ mask) * images
172 | return self.mux(true_case=out, false_case=images)
173 |
174 |
175 | class NormalPipeline(Pipeline):
176 | def __init__(
177 | self,
178 | data_path: str,
179 | batch_size: int,
180 | device: str,
181 | validation: bool = False,
182 | device_id: int = 0,
183 | shard_id: int = 0,
184 | num_shards: int = 1,
185 | num_threads: int = 4,
186 | seed: int = 12,
187 | ):
188 | """Initializes the pipeline for validation or linear eval training.
189 |
190 | If validation is set to True then images will only be resized to 256px and center cropped
191 | to 224px, otherwise random resized crop, horizontal flip are applied. In both cases images
192 | are normalized.
193 |
194 | Args:
195 | data_path (str): directory that contains the data.
196 | batch_size (int): batch size.
197 | device (str): device on which the operation will be performed.
198 | validation (bool): whether it is validation or training. Defaults to False. Defaults to
199 | False.
200 | device_id (int): id of the device used to initialize the seed and for parent class.
201 | Defaults to 0.
202 | shard_id (int): id of the shard (chuck of samples). Defaults to 0.
203 | num_shards (int): total number of shards. Defaults to 1.
204 | num_threads (int): number of threads to run in parallel. Defaults to 4.
205 | seed (int): seed for random number generation. Defaults to 12.
206 | """
207 |
208 | seed += device_id
209 | super().__init__(batch_size, num_threads, device_id, seed)
210 |
211 | self.device = device
212 | self.validation = validation
213 |
214 | self.reader = ops.readers.File(
215 | file_root=data_path,
216 | shard_id=shard_id,
217 | num_shards=num_shards,
218 | shuffle_after_epoch=True if not self.validation else False,
219 | )
220 | decoder_device = "mixed" if self.device == "gpu" else "cpu"
221 | device_memory_padding = 211025920 if decoder_device == "mixed" else 0
222 | host_memory_padding = 140544512 if decoder_device == "mixed" else 0
223 | self.decode = ops.decoders.Image(
224 | device=decoder_device,
225 | output_type=types.RGB,
226 | device_memory_padding=device_memory_padding,
227 | host_memory_padding=host_memory_padding,
228 | )
229 |
230 | # crop operations
231 | if self.validation:
232 | self.resize = ops.Resize(
233 | device=self.device,
234 | resize_shorter=256,
235 | interp_type=types.INTERP_CUBIC,
236 | )
237 | # center crop and normalize
238 | self.cmn = ops.CropMirrorNormalize(
239 | device=self.device,
240 | dtype=types.FLOAT,
241 | output_layout=types.NCHW,
242 | crop=(224, 224),
243 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
244 | std=[0.228 * 255, 0.224 * 255, 0.225 * 255],
245 | )
246 | else:
247 | self.resize = ops.RandomResizedCrop(
248 | device=self.device,
249 | size=224,
250 | random_area=(0.08, 1.0),
251 | interp_type=types.INTERP_CUBIC,
252 | )
253 | # normalize and horizontal flip
254 | self.cmn = ops.CropMirrorNormalize(
255 | device=self.device,
256 | dtype=types.FLOAT,
257 | output_layout=types.NCHW,
258 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
259 | std=[0.228 * 255, 0.224 * 255, 0.225 * 255],
260 | )
261 |
262 | self.coin05 = ops.random.CoinFlip(probability=0.5)
263 | self.to_int64 = ops.Cast(dtype=types.INT64, device=device)
264 |
265 | def define_graph(self):
266 | """Defines the computational graph for dali operations."""
267 |
268 | # read images from memory
269 | inputs, labels = self.reader(name="Reader")
270 | images = self.decode(inputs)
271 |
272 | # crop into large and small images
273 | images = self.resize(images)
274 |
275 | if self.validation:
276 | # crop and normalize
277 | images = self.cmn(images)
278 | else:
279 | # normalize and maybe apply horizontal flip with 0.5 chance
280 | images = self.cmn(images, mirror=self.coin05())
281 |
282 | if self.device == "gpu":
283 | labels = labels.gpu()
284 | # PyTorch expects labels as INT64
285 | labels = self.to_int64(labels)
286 |
287 | return (images, labels)
288 |
289 |
290 | class CustomNormalPipeline(NormalPipeline):
291 | """Initializes the custom pipeline for validation or linear eval training.
292 | This acts as a placeholder and behaves exactly like NormalPipeline.
293 | If you want to do exoteric augmentations, you can just re-write this class.
294 | """
295 |
296 | pass
297 |
298 |
299 | class ImagenetTransform:
300 | def __init__(
301 | self,
302 | device: str,
303 | brightness: float,
304 | contrast: float,
305 | saturation: float,
306 | hue: float,
307 | gaussian_prob: float = 0.5,
308 | solarization_prob: float = 0.0,
309 | size: int = 224,
310 | min_scale: float = 0.08,
311 | max_scale: float = 1.0,
312 | ):
313 | """Applies Imagenet transformations to a batch of images.
314 |
315 | Args:
316 | device (str): device on which the operations will be performed.
317 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
318 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
319 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
320 | hue (float): sampled uniformly in [-hue, hue].
321 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5.
322 | solarization_prob (float, optional): probability of applying solarization. Defaults
323 | to 0.0.
324 | size (int, optional): size of the side of the image after transformation. Defaults
325 | to 224.
326 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
327 | max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
328 | """
329 |
330 | # random crop
331 | self.random_crop = ops.RandomResizedCrop(
332 | device=device,
333 | size=size,
334 | random_area=(min_scale, max_scale),
335 | interp_type=types.INTERP_CUBIC,
336 | )
337 |
338 | # color jitter
339 | self.random_color_jitter = RandomColorJitter(
340 | brightness=brightness,
341 | contrast=contrast,
342 | saturation=saturation,
343 | hue=hue,
344 | prob=0.8,
345 | device=device,
346 | )
347 |
348 | # grayscale conversion
349 | self.random_grayscale = RandomGrayScaleConversion(prob=0.2, device=device)
350 |
351 | # gaussian blur
352 | self.random_gaussian_blur = RandomGaussianBlur(prob=gaussian_prob, device=device)
353 |
354 | # solarization
355 | self.random_solarization = RandomSolarize(prob=solarization_prob)
356 |
357 | # normalize and horizontal flip
358 | self.cmn = ops.CropMirrorNormalize(
359 | device=device,
360 | dtype=types.FLOAT,
361 | output_layout=types.NCHW,
362 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
363 | std=[0.228 * 255, 0.224 * 255, 0.225 * 255],
364 | )
365 | self.coin05 = ops.random.CoinFlip(probability=0.5)
366 |
367 | self.str = (
368 | "ImagenetTransform("
369 | f"random_crop({min_scale}, {max_scale}), "
370 | f"random_color_jitter(brightness={brightness}, "
371 | f"contrast={contrast}, saturation={saturation}, hue={hue}), "
372 | f"random_gray_scale, random_gaussian_blur({gaussian_prob}), "
373 | f"random_solarization({solarization_prob}), "
374 | "crop_mirror_resize())"
375 | )
376 |
377 | def __str__(self) -> str:
378 | return self.str
379 |
380 | def __call__(self, images):
381 | out = self.random_crop(images)
382 | out = self.random_color_jitter(out)
383 | out = self.random_grayscale(out)
384 | out = self.random_gaussian_blur(out)
385 | out = self.random_solarization(out)
386 | out = self.cmn(out, mirror=self.coin05())
387 | return out
388 |
389 |
390 | class CustomTransform:
391 | def __init__(
392 | self,
393 | device: str,
394 | brightness: float,
395 | contrast: float,
396 | saturation: float,
397 | hue: float,
398 | gaussian_prob: float = 0.5,
399 | solarization_prob: float = 0.0,
400 | size: int = 224,
401 | min_scale: float = 0.08,
402 | max_scale: float = 1.0,
403 | mean: Sequence[float] = (0.485, 0.456, 0.406),
404 | std: Sequence[float] = (0.228, 0.224, 0.225),
405 | ):
406 | """Applies Custom transformations.
407 | If you want to do exoteric augmentations, you can just re-write this class.
408 |
409 | Args:
410 | device (str): device on which the operations will be performed.
411 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
412 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
413 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
414 | hue (float): sampled uniformly in [-hue, hue].
415 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5.
416 | solarization_prob (float, optional): probability of applying solarization. Defaults
417 | to 0.0.
418 | size (int, optional): size of the side of the image after transformation. Defaults
419 | to 224.
420 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
421 | max_scale (float, optional): maximum scale of the crops. Defaults to 1.0.
422 | mean (Sequence[float], optional): mean values for normalization.
423 | Defaults to (0.485, 0.456, 0.406).
424 | std (Sequence[float], optional): std values for normalization.
425 | Defaults to (0.228, 0.224, 0.225).
426 | """
427 |
428 | # random crop
429 | self.random_crop = ops.RandomResizedCrop(
430 | device=device,
431 | size=size,
432 | random_area=(min_scale, max_scale),
433 | interp_type=types.INTERP_CUBIC,
434 | )
435 |
436 | # color jitter
437 | self.random_color_jitter = RandomColorJitter(
438 | brightness=brightness,
439 | contrast=contrast,
440 | saturation=saturation,
441 | hue=hue,
442 | prob=0.8,
443 | device=device,
444 | )
445 |
446 | # grayscale conversion
447 | self.random_grayscale = RandomGrayScaleConversion(prob=0.2, device=device)
448 |
449 | # gaussian blur
450 | self.random_gaussian_blur = RandomGaussianBlur(prob=gaussian_prob, device=device)
451 |
452 | # solarization
453 | self.random_solarization = RandomSolarize(prob=solarization_prob)
454 |
455 | # normalize and horizontal flip
456 | self.cmn = ops.CropMirrorNormalize(
457 | device=device,
458 | dtype=types.FLOAT,
459 | output_layout=types.NCHW,
460 | mean=[v * 255 for v in mean],
461 | std=[v * 255 for v in std],
462 | )
463 | self.coin05 = ops.random.CoinFlip(probability=0.5)
464 |
465 | self.str = (
466 | "CustomTransform("
467 | f"random_crop({min_scale}, {max_scale}), "
468 | f"random_color_jitter(brightness={brightness}, "
469 | f"contrast={contrast}, saturation={saturation}, hue={hue}), "
470 | f"random_gray_scale, random_gaussian_blur({gaussian_prob}), "
471 | f"random_solarization({solarization_prob}), "
472 | "crop_mirror_resize())"
473 | )
474 |
475 | def __call__(self, images):
476 | out = self.random_crop(images)
477 | out = self.random_color_jitter(out)
478 | out = self.random_grayscale(out)
479 | out = self.random_gaussian_blur(out)
480 | out = self.random_solarization(out)
481 | out = self.cmn(out, mirror=self.coin05())
482 | return out
483 |
484 | def __str__(self):
485 | return self.str
486 |
487 |
488 | class PretrainPipeline(Pipeline):
489 | def __init__(
490 | self,
491 | data_path: Union[str, Path],
492 | batch_size: int,
493 | device: str,
494 | transform: Union[Callable, Iterable],
495 | num_crops: int = 2,
496 | random_shuffle: bool = True,
497 | device_id: int = 0,
498 | shard_id: int = 0,
499 | num_shards: int = 1,
500 | num_threads: int = 4,
501 | seed: int = 12,
502 | no_labels: bool = False,
503 | encode_indexes_into_labels: bool = False,
504 | ):
505 | """Initializes the pipeline for pretraining.
506 |
507 | Args:
508 | data_path (str): directory that contains the data.
509 | batch_size (int): batch size.
510 | device (str): device on which the operation will be performed.
511 | transform (Union[Callable, Iterable]): a transformation or a sequence
512 | of transformations to be applied.
513 | num_crops (int, optional): number of crops. Defaults to 2.
514 | random_shuffle (bool, optional): whether to randomly shuffle the samples.
515 | Defaults to True.
516 | device_id (int, optional): id of the device used to initialize the seed and
517 | for parent class. Defaults to 0.
518 | shard_id (int, optional): id of the shard (chuck of samples). Defaults to 0.
519 | num_shards (int, optional): total number of shards. Defaults to 1.
520 | num_threads (int, optional): number of threads to run in parallel. Defaults to 4.
521 | seed (int, optional): seed for random number generation. Defaults to 12.
522 | no_labels (bool, optional): if the data has no labels. Defaults to False.
523 | encode_indexes_into_labels (bool, optional): uses sample indexes as labels
524 | and then gets the labels from a lookup table. This may use more CPU memory,
525 | so just use when needed. Defaults to False.
526 | """
527 |
528 | seed += device_id
529 | super().__init__(
530 | batch_size=batch_size,
531 | num_threads=num_threads,
532 | device_id=device_id,
533 | seed=seed,
534 | )
535 |
536 | self.device = device
537 |
538 | data_path = Path(data_path)
539 | if no_labels:
540 | files = [data_path / f for f in sorted(os.listdir(data_path))]
541 | labels = [-1] * len(files)
542 | self.reader = ops.readers.File(
543 | files=files,
544 | shard_id=shard_id,
545 | num_shards=num_shards,
546 | shuffle_after_epoch=random_shuffle,
547 | labels=labels,
548 | )
549 | elif encode_indexes_into_labels:
550 | labels = sorted(Path(entry.name) for entry in os.scandir(data_path) if entry.is_dir())
551 |
552 | data = [
553 | (data_path / label / file, label_idx)
554 | for label_idx, label in enumerate(labels)
555 | for file in sorted(os.listdir(data_path / label))
556 | ]
557 |
558 | files = []
559 | labels = []
560 | # for debugging
561 | true_labels = []
562 |
563 | self.conversion_map = []
564 | for file_idx, (file, label_idx) in enumerate(data):
565 | files.append(file)
566 | labels.append(file_idx)
567 | true_labels.append(label_idx)
568 | self.conversion_map.append(label_idx)
569 |
570 | # debugging
571 | for file, file_idx, label_idx in zip(files, labels, true_labels):
572 | assert self.conversion_map[file_idx] == label_idx
573 |
574 | self.reader = ops.readers.File(
575 | files=files,
576 | shard_id=shard_id,
577 | num_shards=num_shards,
578 | shuffle_after_epoch=random_shuffle,
579 | )
580 | else:
581 | self.reader = ops.readers.File(
582 | file_root=data_path,
583 | shard_id=shard_id,
584 | num_shards=num_shards,
585 | shuffle_after_epoch=random_shuffle,
586 | )
587 |
588 | decoder_device = "mixed" if self.device == "gpu" else "cpu"
589 | device_memory_padding = 211025920 if decoder_device == "mixed" else 0
590 | host_memory_padding = 140544512 if decoder_device == "mixed" else 0
591 | self.decode = ops.decoders.Image(
592 | device=decoder_device,
593 | output_type=types.RGB,
594 | device_memory_padding=device_memory_padding,
595 | host_memory_padding=host_memory_padding,
596 | )
597 | self.to_int64 = ops.Cast(dtype=types.INT64, device=device)
598 |
599 | self.num_crops = num_crops
600 |
601 | # transformations
602 | self.transform = transform
603 |
604 | if isinstance(transform, Iterable):
605 | self.one_transform_per_crop = True
606 | else:
607 | self.one_transform_per_crop = False
608 | self.num_crops = num_crops
609 |
610 | def define_graph(self):
611 | """Defines the computational graph for dali operations."""
612 |
613 | # read images from memory
614 | inputs, labels = self.reader(name="Reader")
615 |
616 | images = self.decode(inputs)
617 |
618 | if self.one_transform_per_crop:
619 | crops = [transform(images) for transform in self.transform]
620 | else:
621 | crops = [self.transform(images) for i in range(self.num_crops)]
622 |
623 | if self.device == "gpu":
624 | labels = labels.gpu()
625 | # PyTorch expects labels as INT64
626 | labels = self.to_int64(labels)
627 |
628 | return (*crops, labels)
629 |
630 |
631 | class MulticropPretrainPipeline(Pipeline):
632 | def __init__(
633 | self,
634 | data_path: Union[str, Path],
635 | batch_size: int,
636 | device: str,
637 | transforms: List,
638 | num_crops: List[int],
639 | random_shuffle: bool = True,
640 | device_id: int = 0,
641 | shard_id: int = 0,
642 | num_shards: int = 1,
643 | num_threads: int = 4,
644 | seed: int = 12,
645 | no_labels: bool = False,
646 | encode_indexes_into_labels: bool = False,
647 | ):
648 | """Initializes the pipeline for pretraining with multicrop.
649 |
650 | Args:
651 | data_path (str): directory that contains the data.
652 | batch_size (int): batch size.
653 | device (str): device on which the operation will be performed.
654 | transforms (List): list of transformations to be applied.
655 | num_crops (List[int]): number of crops.
656 | random_shuffle (bool, optional): whether to randomly shuffle the samples.
657 | Defaults to True.
658 | device_id (int, optional): id of the device used to initialize the seed and
659 | for parent class. Defaults to 0.
660 | shard_id (int, optional): id of the shard (chuck of samples). Defaults to 0.
661 | num_shards (int, optional): total number of shards. Defaults to 1.
662 | num_threads (int, optional): number of threads to run in parallel. Defaults to 4.
663 | seed (int, optional): seed for random number generation. Defaults to 12.
664 | no_labels (bool, optional): if the data has no labels. Defaults to False.
665 | encode_indexes_into_labels (bool, optional): uses sample indexes as labels
666 | and then gets the labels from a lookup table. This may use more CPU memory,
667 | so just use when needed. Defaults to False.
668 | """
669 |
670 | seed += device_id
671 | super().__init__(
672 | batch_size=batch_size,
673 | num_threads=num_threads,
674 | device_id=device_id,
675 | seed=seed,
676 | )
677 |
678 | self.device = device
679 |
680 | data_path = Path(data_path)
681 | if no_labels:
682 | files = [data_path / f for f in sorted(os.listdir(data_path))]
683 | labels = [-1] * len(files)
684 | self.reader = ops.readers.File(
685 | files=files,
686 | shard_id=shard_id,
687 | num_shards=num_shards,
688 | shuffle_after_epoch=random_shuffle,
689 | labels=labels,
690 | )
691 | elif encode_indexes_into_labels:
692 | labels = sorted(Path(entry.name) for entry in os.scandir(data_path) if entry.is_dir())
693 |
694 | data = [
695 | (data_path / label / file, label_idx)
696 | for label_idx, label in enumerate(labels)
697 | for file in sorted(os.listdir(data_path / label))
698 | ]
699 |
700 | files = []
701 | labels = []
702 | # for debugging
703 | true_labels = []
704 |
705 | self.conversion_map = []
706 | for file_idx, (file, label_idx) in enumerate(data):
707 | files.append(file)
708 | labels.append(file_idx)
709 | true_labels.append(label_idx)
710 | self.conversion_map.append(label_idx)
711 |
712 | # debugging
713 | for file, file_idx, label_idx in zip(files, labels, true_labels):
714 | assert self.conversion_map[file_idx] == label_idx
715 |
716 | self.reader = ops.readers.File(
717 | files=files,
718 | shard_id=shard_id,
719 | num_shards=num_shards,
720 | shuffle_after_epoch=random_shuffle,
721 | )
722 | else:
723 | self.reader = ops.readers.File(
724 | file_root=data_path,
725 | shard_id=shard_id,
726 | num_shards=num_shards,
727 | shuffle_after_epoch=random_shuffle,
728 | )
729 |
730 | decoder_device = "mixed" if self.device == "gpu" else "cpu"
731 | device_memory_padding = 211025920 if decoder_device == "mixed" else 0
732 | host_memory_padding = 140544512 if decoder_device == "mixed" else 0
733 | self.decode = ops.decoders.Image(
734 | device=decoder_device,
735 | output_type=types.RGB,
736 | device_memory_padding=device_memory_padding,
737 | host_memory_padding=host_memory_padding,
738 | )
739 | self.to_int64 = ops.Cast(dtype=types.INT64, device=device)
740 |
741 | self.num_crops = num_crops
742 | self.transforms = transforms
743 |
744 | assert len(transforms) == len(num_crops)
745 |
746 | def define_graph(self):
747 | """Defines the computational graph for dali operations."""
748 |
749 | # read images from memory
750 | inputs, labels = self.reader(name="Reader")
751 | images = self.decode(inputs)
752 |
753 | # crop into large and small images
754 | crops = []
755 | for i, transform in enumerate(self.transforms):
756 | for _ in range(self.num_crops[i]):
757 | crop = transform(images)
758 | crops.append(crop)
759 |
760 | if self.device == "gpu":
761 | labels = labels.gpu()
762 | # PyTorch expects labels as INT64
763 | labels = self.to_int64(labels)
764 |
765 | return (*crops, labels)
766 |
--------------------------------------------------------------------------------
/solo/utils/kmeans.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Any, Sequence
21 |
22 | import numpy as np
23 | import torch
24 | import torch.distributed as dist
25 | import torch.nn.functional as F
26 | from scipy.sparse import csr_matrix
27 |
28 |
29 | class KMeans:
30 | def __init__(
31 | self,
32 | world_size: int,
33 | rank: int,
34 | num_crops: int,
35 | dataset_size: int,
36 | proj_features_dim: int,
37 | num_prototypes: int,
38 | kmeans_iters: int = 10,
39 | ):
40 | """Class that performs K-Means on the hypersphere.
41 |
42 | Args:
43 | world_size (int): world size.
44 | rank (int): rank of the current process.
45 | num_crops (int): number of crops.
46 | dataset_size (int): total size of the dataset (number of samples).
47 | proj_features_dim (int): number of dimensions of the projected features.
48 | num_prototypes (int): number of prototypes.
49 | kmeans_iters (int, optional): number of iterations for the k-means clustering.
50 | Defaults to 10.
51 | """
52 | self.world_size = world_size
53 | self.rank = rank
54 | self.num_crops = num_crops
55 | self.dataset_size = dataset_size
56 | self.proj_features_dim = proj_features_dim
57 | self.num_prototypes = num_prototypes
58 | self.kmeans_iters = kmeans_iters
59 |
60 | @staticmethod
61 | def get_indices_sparse(data: np.ndarray):
62 | cols = np.arange(data.size)
63 | M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size))
64 | return [np.unravel_index(row.data, data.shape) for row in M]
65 |
66 | def cluster_memory(
67 | self,
68 | local_memory_index: torch.Tensor,
69 | local_memory_embeddings: torch.Tensor,
70 | ) -> Sequence[Any]:
71 | """Performs K-Means clustering on the hypersphere and returns centroids and
72 | assignments for each sample.
73 |
74 | Args:
75 | local_memory_index (torch.Tensor): memory bank cointaining indices of the
76 | samples.
77 | local_memory_embeddings (torch.Tensor): memory bank cointaining embeddings
78 | of the samples.
79 |
80 | Returns:
81 | Sequence[Any]: assignments and centroids.
82 | """
83 | j = 0
84 | device = local_memory_embeddings.device
85 | assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long()
86 | centroids_list = []
87 | with torch.no_grad():
88 | for i_K, K in enumerate(self.num_prototypes):
89 | # run distributed k-means
90 |
91 | # init centroids with elements from memory bank of rank 0
92 | centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True)
93 | if self.rank == 0:
94 | random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K]
95 | assert len(random_idx) >= K, "please reduce the number of centroids"
96 | centroids = local_memory_embeddings[j][random_idx]
97 | if dist.is_available() and dist.is_initialized():
98 | dist.broadcast(centroids, 0)
99 |
100 | for n_iter in range(self.kmeans_iters + 1):
101 |
102 | # E step
103 | dot_products = torch.mm(local_memory_embeddings[j], centroids.t())
104 | _, local_assignments = dot_products.max(dim=1)
105 |
106 | # finish
107 | if n_iter == self.kmeans_iters:
108 | break
109 |
110 | # M step
111 | where_helper = self.get_indices_sparse(local_assignments.cpu().numpy())
112 | counts = torch.zeros(K).to(device, non_blocking=True).int()
113 | emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True)
114 | for k in range(len(where_helper)):
115 | if len(where_helper[k][0]) > 0:
116 | emb_sums[k] = torch.sum(
117 | local_memory_embeddings[j][where_helper[k][0]],
118 | dim=0,
119 | )
120 | counts[k] = len(where_helper[k][0])
121 | if dist.is_available() and dist.is_initialized():
122 | dist.all_reduce(counts)
123 | dist.all_reduce(emb_sums)
124 | mask = counts > 0
125 | centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1)
126 |
127 | # normalize centroids
128 | centroids = F.normalize(centroids, dim=1, p=2)
129 |
130 | centroids_list.append(centroids)
131 |
132 | if dist.is_available() and dist.is_initialized():
133 | # gather the assignments
134 | assignments_all = torch.empty(
135 | self.world_size,
136 | local_assignments.size(0),
137 | dtype=local_assignments.dtype,
138 | device=local_assignments.device,
139 | )
140 | assignments_all = list(assignments_all.unbind(0))
141 |
142 | dist_process = dist.all_gather(
143 | assignments_all, local_assignments, async_op=True
144 | )
145 | dist_process.wait()
146 | assignments_all = torch.cat(assignments_all).cpu()
147 |
148 | # gather the indexes
149 | indexes_all = torch.empty(
150 | self.world_size,
151 | local_memory_index.size(0),
152 | dtype=local_memory_index.dtype,
153 | device=local_memory_index.device,
154 | )
155 | indexes_all = list(indexes_all.unbind(0))
156 | dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True)
157 | dist_process.wait()
158 | indexes_all = torch.cat(indexes_all).cpu()
159 |
160 | else:
161 | assignments_all = local_assignments
162 | indexes_all = local_memory_index
163 |
164 | # log assignments
165 | assignments[i_K][indexes_all] = assignments_all
166 |
167 | # next memory bank to use
168 | j = (j + 1) % self.num_crops
169 |
170 | return assignments, centroids_list
171 |
--------------------------------------------------------------------------------
/solo/utils/knn.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Sequence
21 |
22 | import torch
23 | from torchmetrics.metric import Metric
24 |
25 |
26 | class WeightedKNNClassifier(Metric):
27 | def __init__(
28 | self,
29 | k: int = 20,
30 | T: float = 0.07,
31 | num_chunks: int = 100,
32 | distance_fx: str = "cosine",
33 | epsilon: float = 0.00001,
34 | dist_sync_on_step: bool = False,
35 | ):
36 | """Implements the weighted k-NN classifier used for evaluation.
37 |
38 | Args:
39 | k (int, optional): number of neighbors. Defaults to 20.
40 | T (float, optional): temperature for the exponential. Only used with cosine
41 | distance. Defaults to 0.07.
42 | num_chunks (int, optional): number of chunks of test features. Defaults to 100.
43 | distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or
44 | "euclidean". Defaults to "cosine".
45 | epsilon (float, optional): Small value for numerical stability. Only used with
46 | euclidean distance. Defaults to 0.00001.
47 | dist_sync_on_step (bool, optional): whether to sync distributed values at every
48 | step. Defaults to False.
49 | """
50 |
51 | super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
52 |
53 | self.k = k
54 | self.T = T
55 | self.num_chunks = num_chunks
56 | self.distance_fx = distance_fx
57 | self.epsilon = epsilon
58 |
59 | self.add_state("train_features", default=[], persistent=False)
60 | self.add_state("train_targets", default=[], persistent=False)
61 | self.add_state("test_features", default=[], persistent=False)
62 | self.add_state("test_targets", default=[], persistent=False)
63 |
64 | def update(
65 | self,
66 | train_features: torch.Tensor = None,
67 | train_targets: torch.Tensor = None,
68 | test_features: torch.Tensor = None,
69 | test_targets: torch.Tensor = None,
70 | ):
71 | """Updates the memory banks. If train (test) features are passed as input, the
72 | corresponding train (test) targets must be passed as well.
73 |
74 | Args:
75 | train_features (torch.Tensor, optional): a batch of train features. Defaults to None.
76 | train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None.
77 | test_features (torch.Tensor, optional): a batch of test features. Defaults to None.
78 | test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None.
79 | """
80 | assert (train_features is None) == (train_targets is None)
81 | assert (test_features is None) == (test_targets is None)
82 |
83 | if train_features is not None:
84 | assert train_features.size(0) == train_targets.size(0)
85 | self.train_features.append(train_features)
86 | self.train_targets.append(train_targets)
87 |
88 | if test_features is not None:
89 | assert test_features.size(0) == test_targets.size(0)
90 | self.test_features.append(test_features)
91 | self.test_targets.append(test_targets)
92 |
93 | @torch.no_grad()
94 | def compute(self) -> Sequence[float]:
95 | """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected,
96 | the weight is computed using the exponential of the temperature scaled cosine
97 | distance of the samples. If euclidean distance is selected, the weight corresponds
98 | to the inverse of the euclidean distance.
99 |
100 | Returns:
101 | Sequence[float]: k-NN accuracy @1 and @5.
102 | """
103 | train_features = torch.cat(self.train_features)
104 | train_targets = torch.cat(self.train_targets)
105 | test_features = torch.cat(self.test_features)
106 | test_targets = torch.cat(self.test_targets)
107 |
108 | top1, top5, total = 0.0, 0.0, 0
109 | num_classes = torch.unique(test_targets).numel()
110 | num_test_images = test_targets.size(0)
111 | chunk_size = max(1, num_test_images // self.num_chunks)
112 | k = min(self.k, train_targets.size(0))
113 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device)
114 | for idx in range(0, num_test_images, chunk_size):
115 | # get the features for test images
116 | features = test_features[idx : min((idx + chunk_size), num_test_images), :]
117 | targets = test_targets[idx : min((idx + chunk_size), num_test_images)]
118 | batch_size = targets.size(0)
119 |
120 | # calculate the dot product and compute top-k neighbors
121 | if self.distance_fx == "cosine":
122 | similarity = torch.mm(features, train_features.t())
123 | elif self.distance_fx == "euclidean":
124 | similarity = 1 / (torch.cdist(features, train_features) + self.epsilon)
125 | else:
126 | raise NotImplementedError
127 |
128 | distances, indices = similarity.topk(k, largest=True, sorted=True)
129 | candidates = train_targets.view(1, -1).expand(batch_size, -1)
130 | retrieved_neighbors = torch.gather(candidates, 1, indices)
131 |
132 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_()
133 | # import pdb; pdb.set_trace()
134 |
135 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
136 |
137 | if self.distance_fx == "cosine":
138 | distances = distances.clone().div_(self.T).exp_()
139 |
140 | probs = torch.sum(
141 | torch.mul(
142 | retrieval_one_hot.view(batch_size, -1, num_classes),
143 | distances.view(batch_size, -1, 1),
144 | ),
145 | 1,
146 | )
147 | _, predictions = probs.sort(1, True)
148 |
149 | # find the predictions that match the target
150 | correct = predictions.eq(targets.data.view(-1, 1))
151 | # import pdb; pdb.set_trace()
152 | top1 = top1 + correct.narrow(1, 0, 1).sum().item()
153 | top5 = (
154 | top5 + correct.narrow(1, 0, min(5, k)).sum().item()
155 | ) # top5 does not make sense if k < 5
156 | total += targets.size(0)
157 |
158 | top1 = top1 * 100.0 / total
159 | top5 = top5 * 100.0 / total
160 |
161 | self.reset()
162 |
163 | return top1, top5
164 |
--------------------------------------------------------------------------------
/solo/utils/lars.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | # Copied from Pytorch Lightning (https://github.com/PyTorchLightning/pytorch-lightning/)
21 | # with extra documentations.
22 |
23 |
24 | import torch
25 | from torch.optim import Optimizer
26 |
27 |
28 | class LARSWrapper:
29 | def __init__(
30 | self,
31 | optimizer: Optimizer,
32 | eta: float = 1e-3,
33 | clip: bool = False,
34 | eps: float = 1e-8,
35 | exclude_bias_n_norm: bool = False,
36 | ):
37 | """Wrapper that adds LARS scheduling to any optimizer.
38 | This helps stability with huge batch sizes.
39 |
40 | Args:
41 | optimizer (Optimizer): torch optimizer.
42 | eta (float, optional): trust coefficient. Defaults to 1e-3.
43 | clip (bool, optional): clip gradient values. Defaults to False.
44 | eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8.
45 | exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars.
46 | Defaults to False.
47 | """
48 |
49 | self.optim = optimizer
50 | self.eta = eta
51 | self.eps = eps
52 | self.clip = clip
53 | self.exclude_bias_n_norm = exclude_bias_n_norm
54 |
55 | # transfer optim methods
56 | self.state_dict = self.optim.state_dict
57 | self.load_state_dict = self.optim.load_state_dict
58 | self.zero_grad = self.optim.zero_grad
59 | self.add_param_group = self.optim.add_param_group
60 |
61 | self.__setstate__ = self.optim.__setstate__ # type: ignore
62 | self.__getstate__ = self.optim.__getstate__ # type: ignore
63 | self.__repr__ = self.optim.__repr__ # type: ignore
64 |
65 | @property
66 | def defaults(self):
67 | return self.optim.defaults
68 |
69 | @defaults.setter
70 | def defaults(self, defaults):
71 | self.optim.defaults = defaults
72 |
73 | @property # type: ignore
74 | def __class__(self):
75 | return Optimizer
76 |
77 | @property
78 | def state(self):
79 | return self.optim.state
80 |
81 | @state.setter
82 | def state(self, state):
83 | self.optim.state = state
84 |
85 | @property
86 | def param_groups(self):
87 | return self.optim.param_groups
88 |
89 | @param_groups.setter
90 | def param_groups(self, value):
91 | self.optim.param_groups = value
92 |
93 | @torch.no_grad()
94 | def step(self, closure=None):
95 | weight_decays = []
96 |
97 | for group in self.optim.param_groups:
98 | weight_decay = group.get("weight_decay", 0)
99 | weight_decays.append(weight_decay)
100 |
101 | # reset weight decay
102 | group["weight_decay"] = 0
103 |
104 | # update the parameters
105 | for p in group["params"]:
106 | if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm):
107 | self.update_p(p, group, weight_decay)
108 |
109 | # update the optimizer
110 | self.optim.step(closure=closure)
111 |
112 | # return weight decay control to optimizer
113 | for group_idx, group in enumerate(self.optim.param_groups):
114 | group["weight_decay"] = weight_decays[group_idx]
115 |
116 | def update_p(self, p, group, weight_decay):
117 | # calculate new norms
118 | p_norm = torch.norm(p.data)
119 | g_norm = torch.norm(p.grad.data)
120 |
121 | if p_norm != 0 and g_norm != 0:
122 | # calculate new lr
123 | new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps)
124 |
125 | # clip lr
126 | if self.clip:
127 | new_lr = min(new_lr / group["lr"], 1)
128 |
129 | # update params with clipped lr
130 | p.grad.data += weight_decay * p.data
131 | p.grad.data *= new_lr
132 |
--------------------------------------------------------------------------------
/solo/utils/metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from typing import Dict, List, Sequence
21 |
22 | import torch
23 |
24 |
25 | def accuracy_at_k(
26 | outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5)
27 | ) -> Sequence[int]:
28 | """Computes the accuracy over the k top predictions for the specified values of k.
29 |
30 | Args:
31 | outputs (torch.Tensor): output of a classifier (logits or probabilities).
32 | targets (torch.Tensor): ground truth labels.
33 | top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over.
34 | Defaults to (1, 5).
35 |
36 | Returns:
37 | Sequence[int]: accuracies at the desired k.
38 | """
39 |
40 | with torch.no_grad():
41 | maxk = max(top_k)
42 | batch_size = targets.size(0)
43 |
44 | _, pred = outputs.topk(maxk, 1, True, True)
45 | pred = pred.t()
46 | correct = pred.eq(targets.view(1, -1).expand_as(pred))
47 |
48 | res = []
49 | for k in top_k:
50 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True)
51 | res.append(correct_k.mul_(100.0 / batch_size))
52 | return res
53 |
54 |
55 | def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float:
56 | """Computes the mean of the values of a key weighted by the batch size.
57 |
58 | Args:
59 | outputs (List[Dict]): list of dicts containing the outputs of a validation step.
60 | key (str): key of the metric of interest.
61 | batch_size_key (str): key of batch size values.
62 |
63 | Returns:
64 | float: weighted mean of the values of a key
65 | """
66 |
67 | value = 0
68 | n = 0
69 | for out in outputs:
70 | value += out[batch_size_key] * out[key]
71 | n += out[batch_size_key]
72 | value = value / n
73 | return value.squeeze(0)
74 |
--------------------------------------------------------------------------------
/solo/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 | import warnings
22 | from typing import List, Tuple
23 |
24 | import torch
25 | import torch.distributed as dist
26 | import torch.nn as nn
27 |
28 |
29 | def _1d_filter(tensor: torch.Tensor) -> torch.Tensor:
30 | return tensor.isfinite()
31 |
32 |
33 | def _2d_filter(tensor: torch.Tensor) -> torch.Tensor:
34 | return tensor.isfinite().all(dim=1)
35 |
36 |
37 | def _single_input_filter(tensor: torch.Tensor) -> Tuple[torch.Tensor]:
38 | if len(tensor.size()) == 1:
39 | filter_func = _1d_filter
40 | elif len(tensor.size()) == 2:
41 | filter_func = _2d_filter
42 | else:
43 | raise RuntimeError("Only 1d and 2d tensors are supported.")
44 |
45 | selected = filter_func(tensor)
46 | tensor = tensor[selected]
47 |
48 | return tensor, selected
49 |
50 |
51 | def _multi_input_filter(tensors: List[torch.Tensor]) -> Tuple[torch.Tensor]:
52 | if len(tensors[0].size()) == 1:
53 | filter_func = _1d_filter
54 | elif len(tensors[0].size()) == 2:
55 | filter_func = _2d_filter
56 | else:
57 | raise RuntimeError("Only 1d and 2d tensors are supported.")
58 |
59 | selected = filter_func(tensors[0])
60 | for tensor in tensors[1:]:
61 | selected = torch.logical_and(selected, filter_func(tensor))
62 | tensors = [tensor[selected] for tensor in tensors]
63 |
64 | return tensors, selected
65 |
66 |
67 | def filter_inf_n_nan(tensors: List[torch.Tensor], return_indexes: bool = False):
68 | """Filters out inf and nans from any tensor.
69 | This is usefull when there are instability issues,
70 | which cause a small number of values to go bad.
71 |
72 | Args:
73 | tensor (List): tensor to remove nans and infs from.
74 |
75 | Returns:
76 | torch.Tensor: filtered view of the tensor without nans or infs.
77 | """
78 |
79 | if isinstance(tensors, torch.Tensor):
80 | tensors, selected = _single_input_filter(tensors)
81 | else:
82 | tensors, selected = _multi_input_filter(tensors)
83 |
84 | if return_indexes:
85 | return tensors, selected
86 | return tensors
87 |
88 |
89 | class FilterInfNNan(nn.Module):
90 | def __init__(self, module):
91 | """Layer that filters out inf and nans from any tensor.
92 | This is usefull when there are instability issues,
93 | which cause a small number of values to go bad.
94 |
95 | Args:
96 | tensor (List): tensor to remove nans and infs from.
97 |
98 | Returns:
99 | torch.Tensor: filtered view of the tensor without nans or infs.
100 | """
101 | super().__init__()
102 |
103 | self.module = module
104 |
105 | def forward(self, x: torch.Tensor) -> torch.Tensor:
106 | out = self.module(x)
107 | out = filter_inf_n_nan(out)
108 | return out
109 |
110 | def __getattr__(self, name):
111 | try:
112 | return super().__getattr__(name)
113 | except AttributeError:
114 | if name == "module":
115 | raise AttributeError()
116 | return getattr(self.module, name)
117 |
118 |
119 | def _no_grad_trunc_normal_(tensor, mean, std, a, b):
120 | """Copy & paste from PyTorch official master until it's in a few official releases - RW
121 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
122 | """
123 |
124 | def norm_cdf(x):
125 | """Computes standard normal cumulative distribution function"""
126 |
127 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
128 |
129 | if (mean < a - 2 * std) or (mean > b + 2 * std):
130 | warnings.warn(
131 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
132 | "The distribution of values may be incorrect.",
133 | stacklevel=2,
134 | )
135 |
136 | with torch.no_grad():
137 | # Values are generated by using a truncated uniform distribution and
138 | # then using the inverse CDF for the normal distribution.
139 | # Get upper and lower cdf values
140 | l = norm_cdf((a - mean) / std)
141 | u = norm_cdf((b - mean) / std)
142 |
143 | # Uniformly fill tensor with values from [l, u], then translate to
144 | # [2l-1, 2u-1].
145 | tensor.uniform_(2 * l - 1, 2 * u - 1)
146 |
147 | # Use inverse cdf transform for normal distribution to get truncated
148 | # standard normal
149 | tensor.erfinv_()
150 |
151 | # Transform to proper mean, std
152 | tensor.mul_(std * math.sqrt(2.0))
153 | tensor.add_(mean)
154 |
155 | # Clamp to ensure it's in the proper range
156 | tensor.clamp_(min=a, max=b)
157 | return tensor
158 |
159 |
160 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
161 | """Copy & paste from PyTorch official master until it's in a few official releases - RW
162 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
163 | """
164 |
165 | return _no_grad_trunc_normal_(tensor, mean, std, a, b)
166 |
167 |
168 | class GatherLayer(torch.autograd.Function):
169 | """Gathers tensors from all processes, supporting backward propagation."""
170 |
171 | @staticmethod
172 | def forward(ctx, input):
173 | ctx.save_for_backward(input)
174 | if dist.is_available() and dist.is_initialized():
175 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
176 | dist.all_gather(output, input)
177 | else:
178 | output = [input]
179 | return tuple(output)
180 |
181 | @staticmethod
182 | def backward(ctx, *grads):
183 | (input,) = ctx.saved_tensors
184 | if dist.is_available() and dist.is_initialized():
185 | grad_out = torch.zeros_like(input)
186 | grad_out[:] = grads[dist.get_rank()]
187 | else:
188 | grad_out = grads[0]
189 | return grad_out
190 |
191 |
192 | def gather(X, dim=0):
193 | """Gathers tensors from all processes, supporting backward propagation."""
194 | return torch.cat(GatherLayer.apply(X), dim=dim)
195 |
--------------------------------------------------------------------------------
/solo/utils/momentum.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import math
21 |
22 | import torch
23 | from torch import nn
24 |
25 |
26 | @torch.no_grad()
27 | def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module):
28 | """Copies the parameters of the online network to the momentum network.
29 |
30 | Args:
31 | online_net (nn.Module): online network (e.g. online encoder, online projection, etc...).
32 | momentum_net (nn.Module): momentum network (e.g. momentum encoder,
33 | momentum projection, etc...).
34 | """
35 |
36 | params_online = online_net.parameters()
37 | params_momentum = momentum_net.parameters()
38 | for po, pm in zip(params_online, params_momentum):
39 | pm.data.copy_(po.data)
40 | pm.requires_grad = False
41 |
42 |
43 | class MomentumUpdater:
44 | def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0):
45 | """Updates momentum parameters using exponential moving average.
46 |
47 | Args:
48 | base_tau (float, optional): base value of the weight decrease coefficient
49 | (should be in [0,1]). Defaults to 0.996.
50 | final_tau (float, optional): final value of the weight decrease coefficient
51 | (should be in [0,1]). Defaults to 1.0.
52 | """
53 |
54 | super().__init__()
55 |
56 | assert 0 <= base_tau <= 1
57 | assert 0 <= final_tau <= 1 and base_tau <= final_tau
58 |
59 | self.base_tau = base_tau
60 | self.cur_tau = base_tau
61 | self.final_tau = final_tau
62 |
63 | @torch.no_grad()
64 | def update(self, online_net: nn.Module, momentum_net: nn.Module):
65 | """Performs the momentum update for each param group.
66 |
67 | Args:
68 | online_net (nn.Module): online network (e.g. online encoder, online projection, etc...).
69 | momentum_net (nn.Module): momentum network (e.g. momentum encoder,
70 | momentum projection, etc...).
71 | """
72 |
73 | for op, mp in zip(online_net.parameters(), momentum_net.parameters()):
74 | mp.data = self.cur_tau * mp.data + (1 - self.cur_tau) * op.data
75 |
76 | def update_tau(self, cur_step: int, max_steps: int):
77 | """Computes the next value for the weighting decrease coefficient tau using cosine annealing.
78 |
79 | Args:
80 | cur_step (int): number of gradient steps so far.
81 | max_steps (int): overall number of gradient steps in the whole training.
82 | """
83 |
84 | self.cur_tau = (
85 | self.final_tau
86 | - (self.final_tau - self.base_tau) * (math.cos(math.pi * cur_step / max_steps) + 1) / 2
87 | )
88 |
--------------------------------------------------------------------------------
/solo/utils/pretrain_dataloader.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | import os
21 | import random
22 | from pathlib import Path
23 | from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, Union
24 |
25 | import torch
26 | import torchvision
27 | from PIL import Image, ImageFilter, ImageOps
28 | from torch.utils.data import DataLoader
29 | from torch.utils.data.dataset import Dataset
30 | from torchvision import transforms
31 | from torchvision.datasets import STL10, ImageFolder
32 |
33 |
34 | def dataset_with_index(DatasetClass: Type[Dataset]) -> Type[Dataset]:
35 | """Factory for datasets that also returns the data index.
36 |
37 | Args:
38 | DatasetClass (Type[Dataset]): Dataset class to be wrapped.
39 |
40 | Returns:
41 | Type[Dataset]: dataset with index.
42 | """
43 |
44 | class DatasetWithIndex(DatasetClass):
45 | def __getitem__(self, index):
46 | data = super().__getitem__(index)
47 | return (index, *data)
48 |
49 | return DatasetWithIndex
50 |
51 |
52 | class CustomDatasetWithoutLabels(Dataset):
53 | def __init__(self, root, transform=None):
54 | self.root = Path(root)
55 | self.transform = transform
56 | self.images = os.listdir(root)
57 |
58 | def __getitem__(self, index):
59 | path = self.root / self.images[index]
60 | x = Image.open(path).convert("RGB")
61 | if self.transform is not None:
62 | x = self.transform(x)
63 | return x, -1
64 |
65 | def __len__(self):
66 | return len(self.images)
67 |
68 |
69 | class GaussianBlur:
70 | def __init__(self, sigma: Sequence[float] = [0.1, 2.0]):
71 | """Gaussian blur as a callable object.
72 |
73 | Args:
74 | sigma (Sequence[float]): range to sample the radius of the gaussian blur filter.
75 | Defaults to [0.1, 2.0].
76 | """
77 |
78 | self.sigma = sigma
79 |
80 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
81 | """Applies gaussian blur to an input image.
82 |
83 | Args:
84 | x (torch.Tensor): an image in the tensor format.
85 |
86 | Returns:
87 | torch.Tensor: returns a blurred image.
88 | """
89 |
90 | sigma = random.uniform(self.sigma[0], self.sigma[1])
91 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
92 | return x
93 |
94 |
95 | class Solarization:
96 | """Solarization as a callable object."""
97 |
98 | def __call__(self, img: Image) -> Image:
99 | """Applies solarization to an input image.
100 |
101 | Args:
102 | img (Image): an image in the PIL.Image format.
103 |
104 | Returns:
105 | Image: a solarized image.
106 | """
107 |
108 | return ImageOps.solarize(img)
109 |
110 |
111 | class NCropAugmentation:
112 | def __init__(self, transform: Union[Callable, Sequence], num_crops: Optional[int] = None):
113 | """Creates a pipeline that apply a transformation pipeline multiple times.
114 |
115 | Args:
116 | transform (Union[Callable, Sequence]): transformation pipeline or list of
117 | transformation pipelines.
118 | num_crops: if transformation pipeline is not a list, applies the same
119 | pipeline num_crops times, if it is a list, this is ignored and each
120 | element of the list is applied once.
121 | """
122 |
123 | self.transform = transform
124 |
125 | if isinstance(transform, Iterable):
126 | self.one_transform_per_crop = True
127 | assert num_crops == len(transform)
128 | else:
129 | self.one_transform_per_crop = False
130 | self.num_crops = num_crops
131 |
132 | def __call__(self, x: Image) -> List[torch.Tensor]:
133 | """Applies transforms n times to generate n crops.
134 |
135 | Args:
136 | x (Image): an image in the PIL.Image format.
137 |
138 | Returns:
139 | List[torch.Tensor]: an image in the tensor format.
140 | """
141 |
142 | if self.one_transform_per_crop:
143 | return [transform(x) for transform in self.transform]
144 | else:
145 | return [self.transform(x) for _ in range(self.num_crops)]
146 |
147 |
148 | class BaseTransform:
149 | """Adds callable base class to implement different transformation pipelines."""
150 |
151 | def __call__(self, x: Image) -> torch.Tensor:
152 | return self.transform(x)
153 |
154 | def __repr__(self) -> str:
155 | return str(self.transform)
156 |
157 |
158 | class CifarTransform(BaseTransform):
159 | def __init__(
160 | self,
161 | brightness: float,
162 | contrast: float,
163 | saturation: float,
164 | hue: float,
165 | gaussian_prob: float = 0.0,
166 | solarization_prob: float = 0.0,
167 | min_scale: float = 0.08,
168 | ):
169 | """Applies cifar transformations.
170 |
171 | Args:
172 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
173 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
174 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
175 | hue (float): sampled uniformly in [-hue, hue].
176 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0.
177 | solarization_prob (float, optional): probability of applying solarization. Defaults
178 | to 0.0.
179 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
180 | """
181 |
182 | super().__init__()
183 |
184 | self.transform = transforms.Compose(
185 | [
186 | transforms.RandomResizedCrop(
187 | (32, 32),
188 | scale=(min_scale, 1.0),
189 | interpolation=transforms.InterpolationMode.BICUBIC,
190 | ),
191 | transforms.RandomApply(
192 | [transforms.ColorJitter(brightness, contrast, saturation, hue)], p=0.8
193 | ),
194 | transforms.RandomGrayscale(p=0.2),
195 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
196 | transforms.RandomApply([Solarization()], p=solarization_prob),
197 | transforms.RandomHorizontalFlip(p=0.5),
198 | transforms.ToTensor(),
199 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
200 | ]
201 | )
202 |
203 |
204 | class STLTransform(BaseTransform):
205 | def __init__(
206 | self,
207 | brightness: float,
208 | contrast: float,
209 | saturation: float,
210 | hue: float,
211 | gaussian_prob: float = 0.0,
212 | solarization_prob: float = 0.0,
213 | min_scale: float = 0.08,
214 | ):
215 | """Applies STL10 transformations.
216 |
217 | Args:
218 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
219 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
220 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
221 | hue (float): sampled uniformly in [-hue, hue].
222 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0.
223 | solarization_prob (float, optional): probability of applying solarization. Defaults
224 | to 0.0.
225 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
226 | """
227 |
228 | super().__init__()
229 | self.transform = transforms.Compose(
230 | [
231 | transforms.RandomResizedCrop(
232 | (96, 96),
233 | scale=(min_scale, 1.0),
234 | interpolation=transforms.InterpolationMode.BICUBIC,
235 | ),
236 | transforms.RandomApply(
237 | [transforms.ColorJitter(brightness, contrast, saturation, hue)], p=0.8
238 | ),
239 | transforms.RandomGrayscale(p=0.2),
240 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
241 | transforms.RandomApply([Solarization()], p=solarization_prob),
242 | transforms.RandomHorizontalFlip(p=0.5),
243 | transforms.ToTensor(),
244 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
245 | ]
246 | )
247 |
248 |
249 | class ImagenetTransform(BaseTransform):
250 | def __init__(
251 | self,
252 | brightness: float,
253 | contrast: float,
254 | saturation: float,
255 | hue: float,
256 | gaussian_prob: float = 0.5,
257 | solarization_prob: float = 0.0,
258 | size: int = 224,
259 | min_scale: float = 0.08,
260 | ):
261 | """Class that applies Imagenet transformations.
262 |
263 | Args:
264 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
265 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
266 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
267 | hue (float): sampled uniformly in [-hue, hue].
268 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0.
269 | solarization_prob (float, optional): probability of applying solarization. Defaults
270 | to 0.0.
271 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
272 | size (int, optional): size of the crop. Defaults to 224.
273 | """
274 |
275 | super().__init__()
276 | self.transform = transforms.Compose(
277 | [
278 | transforms.RandomResizedCrop(
279 | size,
280 | scale=(min_scale, 1.0),
281 | interpolation=transforms.InterpolationMode.BICUBIC,
282 | ),
283 | transforms.RandomApply(
284 | [transforms.ColorJitter(brightness, contrast, saturation, hue)],
285 | p=0.8,
286 | ),
287 | transforms.RandomGrayscale(p=0.2),
288 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
289 | transforms.RandomApply([Solarization()], p=solarization_prob),
290 | transforms.RandomHorizontalFlip(p=0.5),
291 | transforms.ToTensor(),
292 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
293 | ]
294 | )
295 |
296 |
297 | class CustomTransform(BaseTransform):
298 | def __init__(
299 | self,
300 | brightness: float,
301 | contrast: float,
302 | saturation: float,
303 | hue: float,
304 | gaussian_prob: float = 0.5,
305 | solarization_prob: float = 0.0,
306 | min_scale: float = 0.08,
307 | size: int = 224,
308 | mean: Sequence[float] = (0.485, 0.456, 0.406),
309 | std: Sequence[float] = (0.228, 0.224, 0.225),
310 | ):
311 | """Class that applies Custom transformations.
312 | If you want to do exoteric augmentations, you can just re-write this class.
313 |
314 | Args:
315 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
316 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
317 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
318 | hue (float): sampled uniformly in [-hue, hue].
319 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0.
320 | solarization_prob (float, optional): probability of applying solarization. Defaults
321 | to 0.0.
322 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08.
323 | size (int, optional): size of the crop. Defaults to 224.
324 | mean (Sequence[float], optional): mean values for normalization.
325 | Defaults to (0.485, 0.456, 0.406).
326 | std (Sequence[float], optional): std values for normalization.
327 | Defaults to (0.228, 0.224, 0.225).
328 | """
329 |
330 | super().__init__()
331 | self.transform = transforms.Compose(
332 | [
333 | transforms.RandomResizedCrop(
334 | size,
335 | scale=(min_scale, 1.0),
336 | interpolation=transforms.InterpolationMode.BICUBIC,
337 | ),
338 | transforms.RandomApply(
339 | [transforms.ColorJitter(brightness, contrast, saturation, hue)],
340 | p=0.8,
341 | ),
342 | transforms.RandomGrayscale(p=0.2),
343 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
344 | transforms.RandomApply([Solarization()], p=solarization_prob),
345 | transforms.RandomHorizontalFlip(p=0.5),
346 | transforms.ToTensor(),
347 | transforms.Normalize(mean=mean, std=std),
348 | ]
349 | )
350 |
351 |
352 | class MulticropAugmentation:
353 | def __init__(
354 | self,
355 | transform: Callable,
356 | size_crops: Sequence[int],
357 | num_crops: Sequence[int],
358 | min_scales: Sequence[float],
359 | max_scale_crops: Sequence[float],
360 | ):
361 | """Class that applies multi crop augmentation.
362 |
363 | Args:
364 | transform (Callable): transformation callable without cropping.
365 | size_crops (Sequence[int]): a sequence of sizes of the crops.
366 | num_crops (Sequence[int]): a sequence number of crops per crop size.
367 | min_scales (Sequence[float]): sequence of minimum crop scales per crop
368 | size.
369 | max_scale_crops (Sequence[float]): sequence of maximum crop scales per crop
370 | size.
371 | """
372 |
373 | self.size_crops = size_crops
374 | self.num_crops = num_crops
375 | self.min_scales = min_scales
376 | self.max_scale_crops = max_scale_crops
377 |
378 | self.transforms = []
379 | for i in range(len(size_crops)):
380 | rrc = transforms.RandomResizedCrop(
381 | size_crops[i],
382 | scale=(min_scales[i], max_scale_crops[i]),
383 | interpolation=transforms.InterpolationMode.BICUBIC,
384 | )
385 | full_transform = transforms.Compose([rrc, transform])
386 | self.transforms.append(full_transform)
387 |
388 | def __call__(self, x: Image) -> List[torch.Tensor]:
389 | """Applies multi crop augmentations.
390 |
391 | Args:
392 | x (Image): an image in the PIL.Image format.
393 |
394 | Returns:
395 | List[torch.Tensor]: a list of crops in the tensor format.
396 | """
397 |
398 | imgs = []
399 | for n, transform in zip(self.num_crops, self.transforms):
400 | imgs.extend([transform(x) for i in range(n)])
401 | return imgs
402 |
403 |
404 | class MulticropCifarTransform(BaseTransform):
405 | def __init__(self):
406 | """Class that applies multicrop transform for CIFAR"""
407 |
408 | super().__init__()
409 |
410 | self.transform = transforms.Compose(
411 | [
412 | transforms.RandomHorizontalFlip(p=0.5),
413 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
414 | transforms.RandomGrayscale(p=0.2),
415 | transforms.ToTensor(),
416 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
417 | ]
418 | )
419 |
420 |
421 | class MulticropSTLTransform(BaseTransform):
422 | def __init__(self):
423 | """Class that applies multicrop transform for STL10"""
424 |
425 | super().__init__()
426 | self.transform = transforms.Compose(
427 | [
428 | transforms.RandomHorizontalFlip(p=0.5),
429 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
430 | transforms.RandomGrayscale(p=0.2),
431 | transforms.ToTensor(),
432 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
433 | ]
434 | )
435 |
436 |
437 | class MulticropImagenetTransform(BaseTransform):
438 | def __init__(
439 | self,
440 | brightness: float,
441 | contrast: float,
442 | saturation: float,
443 | hue: float,
444 | gaussian_prob: float = 0.5,
445 | solarization_prob: float = 0.0,
446 | ):
447 | """Class that applies multicrop transform for Imagenet.
448 |
449 | Args:
450 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
451 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
452 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
453 | hue (float): sampled uniformly in [-hue, hue].
454 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5.
455 | solarization_prob (float, optional): minimum scale of the crops. Defaults to 0.0.
456 | """
457 |
458 | super().__init__()
459 | self.transform = transforms.Compose(
460 | [
461 | transforms.RandomApply(
462 | [transforms.ColorJitter(brightness, contrast, saturation, hue)],
463 | p=0.8,
464 | ),
465 | transforms.RandomGrayscale(p=0.2),
466 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
467 | transforms.RandomApply([Solarization()], p=solarization_prob),
468 | transforms.RandomHorizontalFlip(p=0.5),
469 | transforms.ToTensor(),
470 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
471 | ]
472 | )
473 |
474 |
475 | class MulticropCustomTransform(BaseTransform):
476 | def __init__(
477 | self,
478 | brightness: float,
479 | contrast: float,
480 | saturation: float,
481 | hue: float,
482 | gaussian_prob: float = 0.5,
483 | solarization_prob: float = 0.0,
484 | mean: Sequence[float] = (0.485, 0.456, 0.406),
485 | std: Sequence[float] = (0.228, 0.224, 0.225),
486 | ):
487 | """Class that applies multicrop transform for Custom Datasets.
488 | If you want to do exoteric augmentations, you can just re-write this class.
489 |
490 | Args:
491 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
492 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
493 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
494 | hue (float): sampled uniformly in [-hue, hue].
495 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5.
496 | solarization_prob (float, optional): minimum scale of the crops. Defaults to 0.0.
497 | mean (Sequence[float], optional): mean values for normalization.
498 | Defaults to (0.485, 0.456, 0.406).
499 | std (Sequence[float], optional): std values for normalization.
500 | Defaults to (0.228, 0.224, 0.225).
501 | """
502 |
503 | super().__init__()
504 | self.transform = transforms.Compose(
505 | [
506 | transforms.RandomApply(
507 | [transforms.ColorJitter(brightness, contrast, saturation, hue)],
508 | p=0.8,
509 | ),
510 | transforms.RandomGrayscale(p=0.2),
511 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
512 | transforms.RandomApply([Solarization()], p=solarization_prob),
513 | transforms.RandomHorizontalFlip(p=0.5),
514 | transforms.ToTensor(),
515 | transforms.Normalize(mean=mean, std=std),
516 | ]
517 | )
518 |
519 |
520 | def prepare_transform(dataset: str, multicrop: bool = False, **kwargs) -> Any:
521 | """Prepares transforms for a specific dataset. Optionally uses multi crop.
522 |
523 | Args:
524 | dataset (str): name of the dataset.
525 | multicrop (bool, optional): whether or not to use multi crop. Defaults to False.
526 |
527 | Returns:
528 | Any: a transformation for a specific dataset.
529 | """
530 |
531 | if dataset in ["cifar10", "cifar100"]:
532 | return CifarTransform(**kwargs) if not multicrop else MulticropCifarTransform()
533 | elif dataset == "stl10":
534 | return STLTransform(**kwargs) if not multicrop else MulticropSTLTransform()
535 | elif dataset in ["imagenet", "imagenet100"]:
536 | return (
537 | ImagenetTransform(**kwargs) if not multicrop else MulticropImagenetTransform(**kwargs)
538 | )
539 | elif dataset == "custom":
540 | return CustomTransform(**kwargs) if not multicrop else MulticropCustomTransform(**kwargs)
541 |
542 |
543 | def prepare_n_crop_transform(
544 | transform: Callable, num_crops: Optional[int] = None
545 | ) -> NCropAugmentation:
546 | """Turns a single crop transformation to an N crops transformation.
547 |
548 | Args:
549 | transform (Callable): a transformation.
550 | num_crops (Optional[int], optional): number of crops. Defaults to None.
551 |
552 | Returns:
553 | NCropAugmentation: an N crop transformation.
554 | """
555 |
556 | return NCropAugmentation(transform, num_crops)
557 |
558 |
559 | def prepare_multicrop_transform(
560 | transform: Callable,
561 | size_crops: Sequence[int],
562 | num_crops: Optional[Sequence[int]] = None,
563 | min_scales: Optional[Sequence[float]] = None,
564 | max_scale_crops: Optional[Sequence[float]] = None,
565 | ) -> MulticropAugmentation:
566 | """Prepares multicrop transformations by creating custom crops given the parameters.
567 |
568 | Args:
569 | transform (Callable): transformation callable without cropping.
570 | size_crops (Sequence[int]): a sequence of sizes of the crops.
571 | num_crops (Optional[Sequence[int]]): list of number of crops per crop size.
572 | min_scales (Optional[Sequence[float]]): sequence of minimum crop scales per crop
573 | size.
574 | max_scale_crops (Optional[Sequence[float]]): sequence of maximum crop scales per crop
575 | size.
576 |
577 | Returns:
578 | MulticropAugmentation: prepared augmentation pipeline that supports multicrop with
579 | different sizes.
580 | """
581 |
582 | if num_crops is None:
583 | num_crops = [2, 6]
584 | if min_scales is None:
585 | min_scales = [0.14, 0.05]
586 | if max_scale_crops is None:
587 | max_scale_crops = [1.0, 0.14]
588 |
589 | return MulticropAugmentation(
590 | transform,
591 | size_crops=size_crops,
592 | num_crops=num_crops,
593 | min_scales=min_scales,
594 | max_scale_crops=max_scale_crops,
595 | )
596 |
597 |
598 | def prepare_datasets(
599 | dataset: str,
600 | transform: Callable,
601 | data_dir: Optional[Union[str, Path]] = None,
602 | train_dir: Optional[Union[str, Path]] = None,
603 | no_labels: Optional[Union[str, Path]] = False,
604 | ) -> Dataset:
605 | """Prepares the desired dataset.
606 |
607 | Args:
608 | dataset (str): the name of the dataset.
609 | transform (Callable): a transformation.
610 | data_dir (Optional[Union[str, Path]], optional): the directory to load data from.
611 | Defaults to None.
612 | train_dir (Optional[Union[str, Path]], optional): training data directory
613 | to be appended to data_dir. Defaults to None.
614 | no_labels (Optional[bool], optional): if the custom dataset has no labels.
615 |
616 | Returns:
617 | Dataset: the desired dataset with transformations.
618 | """
619 |
620 | if data_dir is None:
621 | sandbox_folder = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
622 | data_dir = sandbox_folder / "datasets"
623 |
624 | if train_dir is None:
625 | train_dir = Path(f"{dataset}/train")
626 | else:
627 | train_dir = Path(train_dir)
628 |
629 | if dataset in ["cifar10", "cifar100"]:
630 | DatasetClass = vars(torchvision.datasets)[dataset.upper()]
631 | train_dataset = dataset_with_index(DatasetClass)(
632 | data_dir / train_dir,
633 | train=True,
634 | download=True,
635 | transform=transform,
636 | )
637 |
638 | elif dataset == "stl10":
639 | train_dataset = dataset_with_index(STL10)(
640 | data_dir / train_dir,
641 | split="train+unlabeled",
642 | download=True,
643 | transform=transform,
644 | )
645 |
646 | elif dataset in ["imagenet", "imagenet100"]:
647 | train_dir = data_dir / train_dir
648 | train_dataset = dataset_with_index(ImageFolder)(train_dir, transform)
649 |
650 | elif dataset == "custom":
651 | train_dir = data_dir / train_dir
652 |
653 | if no_labels:
654 | dataset_class = CustomDatasetWithoutLabels
655 | else:
656 | dataset_class = ImageFolder
657 |
658 | train_dataset = dataset_with_index(dataset_class)(train_dir, transform)
659 |
660 | return train_dataset
661 |
662 |
663 | def prepare_dataloader(
664 | train_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
665 | ) -> DataLoader:
666 | """Prepares the training dataloader for pretraining.
667 |
668 | Args:
669 | train_dataset (Dataset): the name of the dataset.
670 | batch_size (int, optional): batch size. Defaults to 64.
671 | num_workers (int, optional): number of workers. Defaults to 4.
672 |
673 | Returns:
674 | DataLoader: the training dataloader with the desired dataset.
675 | """
676 |
677 | train_loader = DataLoader(
678 | train_dataset,
679 | batch_size=batch_size,
680 | shuffle=True,
681 | num_workers=num_workers,
682 | pin_memory=True,
683 | drop_last=True,
684 | )
685 | return train_loader
686 |
--------------------------------------------------------------------------------
/solo/utils/sinkhorn_knopp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | # Adapted from https://github.com/facebookresearch/swav.
21 |
22 | import torch
23 | import torch.distributed as dist
24 |
25 |
26 | class SinkhornKnopp(torch.nn.Module):
27 | def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1):
28 | """Approximates optimal transport using the Sinkhorn-Knopp algorithm.
29 |
30 | A simple iterative method to approach the double stochastic matrix is to alternately rescale
31 | rows and columns of the matrix to sum to 1.
32 |
33 | Args:
34 | num_iters (int, optional): number of times to perform row and column normalization.
35 | Defaults to 3.
36 | epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05.
37 | world_size (int, optional): number of nodes for distributed training. Defaults to 1.
38 | """
39 |
40 | super().__init__()
41 | self.num_iters = num_iters
42 | self.epsilon = epsilon
43 | self.world_size = world_size
44 |
45 | @torch.no_grad()
46 | def forward(self, Q: torch.Tensor) -> torch.Tensor:
47 | """Produces assignments using Sinkhorn-Knopp algorithm.
48 |
49 | Applies the entropy regularization, normalizes the Q matrix and then normalizes rows and
50 | columns in an alternating fashion for num_iter times. Before returning it normalizes again
51 | the columns in order for the output to be an assignment of samples to prototypes.
52 |
53 | Args:
54 | Q (torch.Tensor): cosine similarities between the features of the
55 | samples and the prototypes.
56 |
57 | Returns:
58 | torch.Tensor: assignment of samples to prototypes according to optimal transport.
59 | """
60 |
61 | Q = torch.exp(Q / self.epsilon).t()
62 | B = Q.shape[1] * self.world_size
63 | K = Q.shape[0] # num prototypes
64 |
65 | # make the matrix sums to 1
66 | sum_Q = torch.sum(Q)
67 | if dist.is_available() and dist.is_initialized():
68 | dist.all_reduce(sum_Q)
69 | Q /= sum_Q
70 |
71 | for it in range(self.num_iters):
72 | # normalize each row: total weight per prototype must be 1/K
73 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True)
74 | if dist.is_available() and dist.is_initialized():
75 | dist.all_reduce(sum_of_rows)
76 | Q /= sum_of_rows
77 | Q /= K
78 |
79 | # normalize each column: total weight per sample must be 1/B
80 | Q /= torch.sum(Q, dim=0, keepdim=True)
81 | Q /= B
82 |
83 | Q *= B # the colomns must sum to 1 so that Q is an assignment
84 | return Q.t()
85 |
--------------------------------------------------------------------------------
/solo/utils/whitening.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | import torch
22 | import torch.nn as nn
23 | from torch.cuda.amp import custom_fwd
24 | from torch.nn.functional import conv2d
25 |
26 |
27 | class Whitening2d(nn.Module):
28 | def __init__(self, output_dim: int, eps: float = 0.0):
29 | """Layer that computes hard whitening for W-MSE using the Cholesky decomposition.
30 |
31 | Args:
32 | output_dim (int): number of dimension of projected features.
33 | eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults
34 | to 0.0.
35 | """
36 |
37 | super(Whitening2d, self).__init__()
38 | self.output_dim = output_dim
39 | self.eps = eps
40 |
41 | @custom_fwd(cast_inputs=torch.float32)
42 | def forward(self, x: torch.Tensor) -> torch.Tensor:
43 | """Performs whitening using the Cholesky decomposition.
44 |
45 | Args:
46 | x (torch.Tensor): a batch or slice of projected features.
47 |
48 | Returns:
49 | torch.Tensor: a batch or slice of whitened features.
50 | """
51 |
52 | x = x.unsqueeze(2).unsqueeze(3)
53 | m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1)
54 | xn = x - m
55 |
56 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1)
57 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1)
58 |
59 | eye = torch.eye(self.output_dim).type(f_cov.type())
60 |
61 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye
62 |
63 | inv_sqrt = torch.triangular_solve(eye, torch.cholesky(f_cov_shrinked), upper=False)[0]
64 | inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1)
65 |
66 | decorrelated = conv2d(xn, inv_sqrt)
67 |
68 | return decorrelated.squeeze(2).squeeze(2)
69 |
--------------------------------------------------------------------------------