├── .gitignore ├── LICENSE.md ├── README.md ├── asset └── image.png ├── bash_files ├── moco.sh ├── simco.sh └── simmoco.sh ├── main_pretrain.py └── solo ├── __init__.py ├── args ├── __init__.py ├── dataset.py ├── setup.py └── utils.py ├── losses ├── __init__.py ├── dual_temperature_loss.py └── moco.py ├── methods ├── __init__.py ├── base.py ├── dali.py ├── mocov2.py ├── mocov2plus.py ├── simco_dual_temperature.py └── simmoco_dual_temperature.py └── utils ├── __init__.py ├── auto_umap.py ├── backbones.py ├── checkpointer.py ├── classification_dataloader.py ├── dali_dataloader.py ├── kmeans.py ├── knn.py ├── lars.py ├── metrics.py ├── misc.py ├── momentum.py ├── pretrain_dataloader.py ├── sinkhorn_knopp.py └── whitening.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb 2 | trained_models 3 | code 4 | data 5 | .env 6 | 7 | **/__pycache__/** 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Kang Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dual Temperature Helps Contrastive Learning Without Many Negative Samples: Towards Understanding and Simplifying MoCo (Accepted by CVPR2022) 2 | 3 | 4 | Chaoning Zhang, Kang Zhang, Trung X. Pham, Axi Niu, Zhinan Qiao, Chang D. Yoo, In So Kweon 5 | 6 | Contrastive learning (CL) is widely known to require many negative samples, 65536 in MoCo for instance, for which the performance of a dictionary-free framework is often inferior because the negative sample size (NSS) is limited by its mini-batch size (MBS). To decouple the NSS from the MBS, a dynamic dictionary has been adopted in a large volume of CL frameworks, among which arguably the most popular is the MoCo family. In essence, MoCo adopts a momentum-based queue dictionary, for which we perform a fine-grained analysis of its size and consistency. We point out that the InfoNCE loss used in MoCo implicitly attracts anchors to their corresponding positive samples with varying strengths of penalty, and we identify this inter-anchor hardness-awareness property as a major reason for the necessity of a large dictionary.
Our findings motivate us to simplify MoCo v2 via the removal of its dictionary as well as momentum. Based on an InfoNCE loss with the proposed dual temperature, our simplified frameworks, SimMoCo and SimCo, outperform MoCo v2 by a visible margin. Moreover, our work bridges the gap between CL and non-CL frameworks, contributing to a more unified understanding of these two mainstream frameworks in SSL. 7 | 8 | 9 | This repository is the official implementation of ["Dual Temperature Helps Contrastive Learning Without Many Negative Samples: Towards Understanding and Simplifying MoCo"](https://arxiv.org/abs/2203.17248). 10 | 11 | 12 | 13 | 14 | --- 15 | See also our other works: 16 | 17 | Decoupled Adversarial Contrastive Learning for Self-supervised Adversarial Robustness (Accepted by ECCV2022 as an oral presentation) [code](https://github.com/pantheon5100/DeACL.git) [paper](https://arxiv.org/abs/2207.10899) 18 | 19 | --- 20 | 21 | # Dual Temperature InfoNCE Loss 22 | You can simply replace your original loss with the dual-temperature loss using the following code: 23 | ```python 24 | # q1 is the anchor and k2 is the positive sample 25 | # The intra-anchor hardness-awareness is controlled by the `temperature` parameter. 26 | # The inter-anchor hardness-awareness is controlled by the `dt_m` parameter, 27 | # and the inter-anchor temperature is calculated as dt_m * temperature. 28 | nce_loss = dual_temperature_loss_func(q1, k2, 29 | temperature=temperature, 30 | dt_m=dt_m) 31 | 32 | def dual_temperature_loss_func( 33 | query: torch.Tensor, 34 | key: torch.Tensor, 35 | temperature=0.1, 36 | dt_m=10, 37 | ) -> torch.Tensor: 38 | 39 | """ 40 | query: anchor sample. 41 | key: positive sample. 42 | temperature: intra-anchor hardness-awareness control temperature. 43 | dt_m: the scalar used to obtain the inter-anchor hardness-awareness temperature. 44 | The inter-anchor temperature is calculated as dt_m * temperature. 45 | """ 46 | 47 | # intra-anchor hardness-awareness 48 | b = query.size(0) 49 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1) 50 | 51 | # Select the intra-anchor negative samples by masking out the positive pair on the diagonal. 52 | neg = torch.einsum("nc,ck->nk", [query, key.T]) 53 | mask_neg = torch.ones_like(neg, dtype=bool) 54 | mask_neg.fill_diagonal_(False) 55 | neg = neg[mask_neg].reshape(neg.size(0), neg.size(1)-1) 56 | logits = torch.cat([pos, neg], dim=1) 57 | 58 | logits_intra = logits / temperature 59 | prob_intra = F.softmax(logits_intra, dim=1) 60 | 61 | # inter-anchor hardness-awareness 62 | logits_inter = logits / (temperature*dt_m) 63 | prob_inter = F.softmax(logits_inter, dim=1) 64 | 65 | # hardness-awareness factor 66 | inter_intra = (1 - prob_inter[:, 0]) / (1 - prob_intra[:, 0]) 67 | 68 | loss = -torch.nn.functional.log_softmax(logits_intra, dim=-1)[:, 0] 69 | 70 | # final loss 71 | loss = inter_intra.detach() * loss 72 | loss = loss.mean() 73 | 74 | return loss 75 | 76 | ``` 77 | 78 | # 🔧 Environment 79 | 80 | Please refer to [solo-learn](https://github.com/vturrisi/solo-learn) to install the environment. 81 | 82 | > First clone the repo. 83 | > 84 | > Then, to install solo-learn with Dali and/or UMAP support, use: 85 | > 86 | > `pip3 install .[dali,umap,h5] --extra-index-url https://developer.download.nvidia.com/compute/redist` 87 | 88 | 89 | # Dataset 90 | CIFAR10 and CIFAR100 will be automatically downloaded. 91 | 92 | # ⚡ Training 93 | To train SimCo, SimMoCo, and MoCoV2, use the scripts in the folder `./bash_files`. 94 | 95 | You should change the entity and project name to enable wandb logging.
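As a minimal sketch (the project and entity values below are placeholders, not names defined by this repository), the wandb-related flags in each script would be filled in roughly as follows:
```bash
# Placeholder values -- substitute your own wandb project and entity.
--project dual_temperature \
--entity your_wandb_entity \
--wandb \
```
The flags to edit are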
`--project --entity `. Or you can simply remove `--wandb` to disable wandb logging. 96 | 97 | # Results 98 | 99 | | Batch size | 64 | 128 | 256 | 512 | 1024 | 100 | |------------|-------|-------|----------------|-------|-------| 101 | | MoCo v2 | 52.58 | 54.40 | 53.28 | 51.47 | 48.90 | 102 | | SimMoCo | 54.02 | 54.93 | 54.11 | 52.45 | 49.70 | 103 | | SimCo | 58.04 | 58.29 | **58.35** | 57.08 | 55.34 | 104 | 105 | More result can be found in the paper. 106 | 107 | This code is developed based on [solo-learn](https://github.com/vturrisi/solo-learn). 108 | 109 | # Citation 110 | ``` 111 | @article{zhang2022dual, 112 | title={Dual temperature helps contrastive learning without many negative samples: Towards understanding and simplifying moco}, 113 | author={Zhang, Chaoning and Zhang, Kang and Pham, Trung X and Niu, Axi and Qiao, Zhinan and Yoo, Chang D and Kweon, In So}, 114 | journal={CVPR}, 115 | year={2022} 116 | } 117 | ``` 118 | 119 | 120 | # Acknowledgement 121 | 122 | This work was partly supported by Institute for Information & communications Technology Planning & Evaluation (IITP) grant funded by the Korea government (MSIT) under grant No.2019-0-01396 (Development of framework for analyzing, detecting, mitigating of bias in AI model and training data), No.2021-0-01381 (Development of Causal AI through Video Understanding and Reinforcement Learning, and Its Applications to Real Environments) and No.2021-0-02068 (Artificial Intelligence Innovation Hub). 123 | -------------------------------------------------------------------------------- /asset/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChaoningZhang/Dual-temperature/e8d5e768255721b23643d948631f80c0b8a0851b/asset/image.png -------------------------------------------------------------------------------- /bash_files/moco.sh: -------------------------------------------------------------------------------- 1 | # this is the script to train original moco 2 | python3 main_pretrain.py \ 3 | --dataset cifar100 \ 4 | --encoder resnet18 \ 5 | --data_dir ./data \ 6 | --max_epochs 200 \ 7 | --gpus 0 \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.03 \ 12 | --classifier_lr 0.3 \ 13 | --weight_decay 5e-4 \ 14 | --batch_size 256 \ 15 | --num_workers 3 \ 16 | --brightness 0.4 \ 17 | --contrast 0.4 \ 18 | --saturation 0.4 \ 19 | --hue 0.1 \ 20 | --gaussian_prob 0.0 0.0 \ 21 | --name mocov2 \ 22 | --project \ 23 | --entity \ 24 | --wandb \ 25 | --method mocov2 \ 26 | --proj_hidden_dim 128 \ 27 | --temperature 0.1 \ 28 | --base_tau_momentum 0.99 \ 29 | --final_tau_momentum 0.99 \ 30 | --momentum_classifier 31 | -------------------------------------------------------------------------------- /bash_files/simco.sh: -------------------------------------------------------------------------------- 1 | # This is the script to train simco 2 | python3 main_pretrain.py \ 3 | --dataset cifar100 \ 4 | --encoder resnet18 \ 5 | --data_dir ./data \ 6 | --max_epochs 200 \ 7 | --gpus 1 \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.03 \ 12 | --classifier_lr 0.3 \ 13 | --weight_decay 5e-4 \ 14 | --batch_size 256 \ 15 | --num_workers 3 \ 16 | --brightness 0.4 \ 17 | --contrast 0.4 \ 18 | --saturation 0.4 \ 19 | --hue 0.1 \ 20 | --gaussian_prob 0.0 0.0 \ 21 | --name simco \ 22 | --project \ 23 | --entity \ 24 | --wandb \ 25 | --method simco_dual_temperature \ 26 | --proj_hidden_dim 128 \ 27 | --temperature 
0.1 \ 28 | --dt_m 10 \ 29 | --base_tau_momentum 0 \ 30 | --final_tau_momentum 0 \ 31 | --momentum_classifier 32 | 33 | -------------------------------------------------------------------------------- /bash_files/simmoco.sh: -------------------------------------------------------------------------------- 1 | # This is the script to train simmoco 2 | python3 main_pretrain.py \ 3 | --dataset cifar100 \ 4 | --encoder resnet18 \ 5 | --data_dir ./data \ 6 | --max_epochs 200 \ 7 | --gpus 3 \ 8 | --precision 16 \ 9 | --optimizer sgd \ 10 | --scheduler warmup_cosine \ 11 | --lr 0.03 \ 12 | --classifier_lr 0.3 \ 13 | --weight_decay 5e-4 \ 14 | --batch_size 256 \ 15 | --num_workers 3 \ 16 | --brightness 0.4 \ 17 | --contrast 0.4 \ 18 | --saturation 0.4 \ 19 | --hue 0.1 \ 20 | --gaussian_prob 0.0 0.0 \ 21 | --name simmoco \ 22 | --project \ 23 | --entity \ 24 | --wandb \ 25 | --method simmoco_dual_temperature \ 26 | --proj_hidden_dim 128 \ 27 | --temperature 0.1 \ 28 | --dt_m 10 \ 29 | --base_tau_momentum 0.99 \ 30 | --final_tau_momentum 0.99 \ 31 | --momentum_classifier 32 | 33 | -------------------------------------------------------------------------------- /main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | import os 21 | from pprint import pprint 22 | 23 | from pytorch_lightning import Trainer, seed_everything 24 | from pytorch_lightning.callbacks import LearningRateMonitor 25 | from pytorch_lightning.loggers import WandbLogger 26 | from pytorch_lightning.plugins import DDPPlugin 27 | 28 | from solo.args.setup import parse_args_pretrain 29 | from solo.methods import METHODS 30 | 31 | try: 32 | from solo.methods.dali import PretrainABC 33 | except ImportError: 34 | _dali_avaliable = False 35 | else: 36 | _dali_avaliable = True 37 | 38 | try: 39 | from solo.utils.auto_umap import AutoUMAP 40 | except ImportError: 41 | _umap_available = False 42 | else: 43 | _umap_available = True 44 | 45 | from solo.utils.checkpointer import Checkpointer 46 | from solo.utils.classification_dataloader import prepare_data as prepare_data_classification 47 | from solo.utils.pretrain_dataloader import ( 48 | prepare_dataloader, 49 | prepare_datasets, 50 | prepare_multicrop_transform, 51 | prepare_n_crop_transform, 52 | prepare_transform, 53 | ) 54 | import shutil 55 | import sys 56 | import glob 57 | 58 | 59 | def main(): 60 | # set the seed 61 | seed_everything(15) 62 | 63 | args = parse_args_pretrain() 64 | 65 | assert args.method in METHODS, f"Choose from {METHODS.keys()}" 66 | 67 | MethodClass = METHODS[args.method] 68 | if args.dali: 69 | assert ( 70 | _dali_avaliable 71 | ), "Dali is not currently avaiable, please install it first with [dali]." 72 | MethodClass = type(f"Dali{MethodClass.__name__}", (MethodClass, PretrainABC), {}) 73 | 74 | model = MethodClass(**args.__dict__) 75 | 76 | # contrastive dataloader 77 | if not args.dali: 78 | # asymmetric augmentations 79 | if args.unique_augs > 1: 80 | transform = [ 81 | prepare_transform(args.dataset, multicrop=args.multicrop, **kwargs) 82 | for kwargs in args.transform_kwargs 83 | ] 84 | else: 85 | transform = prepare_transform( 86 | args.dataset, multicrop=args.multicrop, **args.transform_kwargs 87 | ) 88 | 89 | if args.debug_augmentations: 90 | print("Transforms:") 91 | pprint(transform) 92 | 93 | if args.multicrop: 94 | assert not args.unique_augs == 1 95 | 96 | if args.dataset in ["cifar10", "cifar100"]: 97 | size_crops = [32, 24] 98 | elif args.dataset == "stl10": 99 | size_crops = [96, 58] 100 | # imagenet or custom dataset 101 | else: 102 | size_crops = [224, 96] 103 | 104 | transform = prepare_multicrop_transform( 105 | transform, size_crops=size_crops, num_crops=[args.num_crops, args.num_small_crops] 106 | ) 107 | else: 108 | if args.num_crops != 2: 109 | # import pdb; pdb.set_trace() 110 | assert args.method == "wmse" or args.method == "simsiam_eoa" or args.method == "simclr_neg_size" 111 | 112 | transform = prepare_n_crop_transform(transform, num_crops=args.num_crops) 113 | 114 | train_dataset = prepare_datasets( 115 | args.dataset, 116 | transform, 117 | data_dir=args.data_dir, 118 | train_dir=args.train_dir, 119 | no_labels=args.no_labels, 120 | ) 121 | train_loader = prepare_dataloader( 122 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers 123 | ) 124 | 125 | # normal dataloader for when it is available 126 | if args.dataset == "custom" and (args.no_labels or args.val_dir is None): 127 | val_loader = None 128 | elif args.dataset in ["imagenet100", "imagenet"] and args.val_dir is None: 129 | val_loader = None 130 | else: 131 | _, val_loader = prepare_data_classification( 132 | args.dataset, 133 | data_dir=args.data_dir, 134 | train_dir=args.train_dir, 135 | val_dir=args.val_dir, 136 | batch_size=args.batch_size, 
137 | num_workers=args.num_workers, 138 | ) 139 | 140 | callbacks = [] 141 | 142 | # wandb logging 143 | if args.wandb: 144 | wandb_logger = WandbLogger( 145 | name=args.name, 146 | project=args.project, 147 | entity=args.entity, 148 | offline=args.offline, 149 | ) 150 | wandb_logger.watch(model, log="gradients", log_freq=100) 151 | wandb_logger.log_hyperparams(args) 152 | 153 | # lr logging 154 | lr_monitor = LearningRateMonitor(logging_interval="epoch") 155 | callbacks.append(lr_monitor) 156 | 157 | # save checkpoint on last epoch only 158 | ckpt = Checkpointer( 159 | args, 160 | logdir=os.path.join(args.checkpoint_dir, args.method), 161 | frequency=args.checkpoint_frequency, 162 | ) 163 | callbacks.append(ckpt) 164 | 165 | if args.auto_umap: 166 | assert ( 167 | _umap_available 168 | ), "UMAP is not currently avaiable, please install it first with [umap]." 169 | auto_umap = AutoUMAP( 170 | args, 171 | logdir=os.path.join(args.auto_umap_dir, args.method), 172 | frequency=args.auto_umap_frequency, 173 | ) 174 | callbacks.append(auto_umap) 175 | 176 | trainer = Trainer.from_argparse_args( 177 | args, 178 | logger=wandb_logger if args.wandb else None, 179 | callbacks=callbacks, 180 | plugins=DDPPlugin(find_unused_parameters=True), 181 | checkpoint_callback=False, 182 | terminate_on_nan=True, 183 | accelerator="ddp", 184 | log_every_n_steps=args.log_frenquence, 185 | ) 186 | 187 | # save code for each run to make each run reproducible 188 | ################################################################# 189 | if args.wandb: 190 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_{trainer.logger.version}" 191 | args.codepath = experimentdir 192 | else: 193 | experimentdir = f"code/{args.method}_{args.project}_{args.name}_test" 194 | 195 | if not os.path.exists("code"): 196 | os.mkdir("code") 197 | 198 | if os.path.exists(experimentdir): 199 | print(experimentdir + ' : exists. overwrite it.') 200 | shutil.rmtree(experimentdir) 201 | os.mkdir(experimentdir) 202 | else: 203 | os.mkdir(experimentdir) 204 | 205 | shutil.copytree(f"solo", os.path.join(experimentdir, 'solo')) 206 | shutil.copytree(f"bash_files", os.path.join(experimentdir, 'bash_files')) 207 | shutil.copyfile(f"main_pretrain.py", os.path.join(experimentdir, 'main_pretrain.py')) 208 | ################################################################# 209 | 210 | if args.dali: 211 | trainer.fit(model, val_dataloaders=val_loader) 212 | else: 213 | trainer.fit(model, train_loader, val_loader) 214 | 215 | 216 | if __name__ == "__main__": 217 | main() 218 | -------------------------------------------------------------------------------- /solo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | from solo import args, losses, methods, utils 22 | 23 | __all__ = ["args", "losses", "methods", "utils"] 24 | -------------------------------------------------------------------------------- /solo/args/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.args import dataset, setup, utils 21 | 22 | __all__ = ["dataset", "setup", "utils"] 23 | 24 | 25 | -------------------------------------------------------------------------------- /solo/args/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | from argparse import ArgumentParser 21 | from pathlib import Path 22 | 23 | 24 | def dataset_args(parser: ArgumentParser): 25 | """Adds dataset-related arguments to a parser. 26 | 27 | Args: 28 | parser (ArgumentParser): parser to add dataset args to. 29 | """ 30 | 31 | SUPPORTED_DATASETS = [ 32 | "cifar10", 33 | "cifar100", 34 | "stl10", 35 | "imagenet", 36 | "imagenet100", 37 | "custom", 38 | ] 39 | 40 | parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, type=str, required=True) 41 | 42 | # dataset path 43 | parser.add_argument("--data_dir", type=Path, required=True) 44 | parser.add_argument("--train_dir", type=Path, default=None) 45 | parser.add_argument("--val_dir", type=Path, default=None) 46 | 47 | # dali (imagenet-100/imagenet/custom only) 48 | parser.add_argument("--dali", action="store_true") 49 | parser.add_argument("--dali_device", type=str, default="gpu") 50 | 51 | # custom dataset only 52 | parser.add_argument("--no_labels", action="store_true") 53 | 54 | 55 | def augmentations_args(parser: ArgumentParser): 56 | """Adds augmentation-related arguments to a parser. 57 | 58 | Args: 59 | parser (ArgumentParser): parser to add augmentation args to. 60 | """ 61 | 62 | # cropping 63 | parser.add_argument("--multicrop", action="store_true") 64 | parser.add_argument("--num_crops", type=int, default=2) 65 | parser.add_argument("--num_small_crops", type=int, default=0) 66 | 67 | # augmentations 68 | parser.add_argument("--brightness", type=float, required=True, nargs="+") 69 | parser.add_argument("--contrast", type=float, required=True, nargs="+") 70 | parser.add_argument("--saturation", type=float, required=True, nargs="+") 71 | parser.add_argument("--hue", type=float, required=True, nargs="+") 72 | parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+") 73 | parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+") 74 | parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+") 75 | 76 | # for imagenet or custom dataset 77 | parser.add_argument("--size", type=int, default=[224], nargs="+") 78 | 79 | # for custom dataset 80 | parser.add_argument("--mean", type=float, default=[0.485, 0.456, 0.406], nargs="+") 81 | parser.add_argument("--std", type=float, default=[0.228, 0.224, 0.225], nargs="+") 82 | 83 | # debug 84 | parser.add_argument("--debug_augmentations", action="store_true") 85 | -------------------------------------------------------------------------------- /solo/args/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import argparse 21 | 22 | import pytorch_lightning as pl 23 | from solo.args.dataset import augmentations_args, dataset_args 24 | from solo.args.utils import additional_setup_linear, additional_setup_pretrain 25 | from solo.methods import METHODS 26 | from solo.utils.checkpointer import Checkpointer 27 | 28 | try: 29 | from solo.utils.auto_umap import AutoUMAP 30 | except ImportError: 31 | _umap_available = False 32 | else: 33 | _umap_available = True 34 | import os 35 | 36 | 37 | def parse_args_pretrain() -> argparse.Namespace: 38 | """Parses dataset, augmentation, pytorch lightning, model specific and additional args. 39 | 40 | First adds shared args such as dataset, augmentation and pytorch lightning args, then pulls the 41 | model name from the command and proceeds to add model specific args from the desired class. If 42 | wandb is enabled, it adds checkpointer args. Finally, adds additional non-user given parameters. 43 | 44 | Returns: 45 | argparse.Namespace: a namespace containing all args needed for pretraining. 46 | """ 47 | 48 | parser = argparse.ArgumentParser() 49 | 50 | # add current working path 51 | parser.add_argument("--runpath", default=os.getcwd(), type=str) 52 | 53 | # add code saving path 54 | parser.add_argument("--codepath", default=os.getcwd(), type=str) 55 | 56 | # add current working path 57 | parser.add_argument("--log_frenquence", default=50, type=int) 58 | 59 | # add shared arguments 60 | dataset_args(parser) 61 | augmentations_args(parser) 62 | 63 | # add pytorch lightning trainer args 64 | parser = pl.Trainer.add_argparse_args(parser) 65 | 66 | # add method-specific arguments 67 | parser.add_argument("--method", type=str) 68 | 69 | # THIS LINE IS KEY TO PULL THE MODEL NAME 70 | temp_args, _ = parser.parse_known_args() 71 | 72 | # add model specific args 73 | parser = METHODS[temp_args.method].add_model_specific_args(parser) 74 | 75 | # add auto umap args 76 | parser.add_argument("--auto_umap", action="store_true") 77 | 78 | # optionally add checkpointer and AutoUMAP args 79 | temp_args, _ = parser.parse_known_args() 80 | if temp_args.wandb: 81 | parser = Checkpointer.add_checkpointer_args(parser) 82 | 83 | if _umap_available and temp_args.auto_umap: 84 | parser = AutoUMAP.add_auto_umap_args(parser) 85 | 86 | # parse args 87 | args = parser.parse_args() 88 | 89 | 90 | # prepare arguments with additional setup 91 | additional_setup_pretrain(args) 92 | 93 | return args 94 | 95 | 96 | def parse_args_linear() -> argparse.Namespace: 97 | """Parses feature extractor, dataset, pytorch lightning, linear eval specific and additional args. 98 | 99 | First adds and arg for the pretrained feature extractor, then adds dataset, pytorch lightning 100 | and linear eval specific args. If wandb is enabled, it adds checkpointer args. Finally, adds 101 | additional non-user given parameters. 102 | 103 | Returns: 104 | argparse.Namespace: a namespace containing all args needed for pretraining. 
105 | """ 106 | 107 | parser = argparse.ArgumentParser() 108 | 109 | parser.add_argument("--pretrained_feature_extractor", type=str) 110 | 111 | # add shared arguments 112 | dataset_args(parser) 113 | 114 | # add pytorch lightning trainer args 115 | parser = pl.Trainer.add_argparse_args(parser) 116 | 117 | # linear model 118 | parser = METHODS["linear"].add_model_specific_args(parser) 119 | 120 | # THIS LINE IS KEY TO PULL WANDB 121 | temp_args, _ = parser.parse_known_args() 122 | 123 | # add checkpointer args (only if logging is enabled) 124 | if temp_args.wandb: 125 | parser = Checkpointer.add_checkpointer_args(parser) 126 | 127 | # parse args 128 | args = parser.parse_args() 129 | additional_setup_linear(args) 130 | 131 | return args 132 | -------------------------------------------------------------------------------- /solo/args/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import os 21 | from argparse import Namespace 22 | 23 | N_CLASSES_PER_DATASET = { 24 | "cifar10": 10, 25 | "cifar100": 100, 26 | "stl10": 10, 27 | "imagenet": 1000, 28 | "imagenet100": 100, 29 | } 30 | 31 | 32 | def additional_setup_pretrain(args: Namespace): 33 | """Provides final setup for pretraining to non-user given parameters by changing args. 34 | 35 | Parsers arguments to extract the number of classes of a dataset, create 36 | transformations kwargs, correctly parse gpus, identify if a cifar dataset 37 | is being used and adjust the lr. 38 | 39 | Args: 40 | args (Namespace): object that needs to contain, at least: 41 | - dataset: dataset name. 42 | - brightness, contrast, saturation, hue, min_scale: required augmentations 43 | settings. 44 | - multicrop: flag to use multicrop. 45 | - dali: flag to use dali. 46 | - optimizer: optimizer name being used. 47 | - gpus: list of gpus to use. 48 | - lr: learning rate. 49 | 50 | [optional] 51 | - gaussian_prob, solarization_prob: optional augmentations settings. 
52 | """ 53 | 54 | if args.dataset in N_CLASSES_PER_DATASET: 55 | args.num_classes = N_CLASSES_PER_DATASET[args.dataset] 56 | else: 57 | # hack to maintain the current pipeline 58 | # even if the custom dataset doesn't have any labels 59 | dir_path = args.data_dir / args.train_dir 60 | args.num_classes = max( 61 | 1, 62 | len([entry.name for entry in os.scandir(dir_path) if entry.is_dir]), 63 | ) 64 | 65 | unique_augs = max( 66 | len(p) 67 | for p in [ 68 | args.brightness, 69 | args.contrast, 70 | args.saturation, 71 | args.hue, 72 | args.gaussian_prob, 73 | args.solarization_prob, 74 | args.min_scale, 75 | args.size, 76 | ] 77 | ) 78 | # if args.method != "simclr_interintra_neg": 79 | # assert unique_augs == args.num_crops or unique_augs == 1 80 | 81 | # assert that either all unique augmentation pipelines have a unique 82 | # parameter or that a single parameter is replicated to all pipelines 83 | for p in [ 84 | "brightness", 85 | "contrast", 86 | "saturation", 87 | "hue", 88 | "gaussian_prob", 89 | "solarization_prob", 90 | "min_scale", 91 | "size", 92 | ]: 93 | values = getattr(args, p) 94 | n = len(values) 95 | assert n == unique_augs or n == 1 96 | 97 | if n == 1: 98 | setattr(args, p, getattr(args, p) * unique_augs) 99 | 100 | args.unique_augs = unique_augs 101 | 102 | if unique_augs > 1: 103 | args.transform_kwargs = [ 104 | dict( 105 | brightness=brightness, 106 | contrast=contrast, 107 | saturation=saturation, 108 | hue=hue, 109 | gaussian_prob=gaussian_prob, 110 | solarization_prob=solarization_prob, 111 | min_scale=min_scale, 112 | size=size, 113 | ) 114 | for ( 115 | brightness, 116 | contrast, 117 | saturation, 118 | hue, 119 | gaussian_prob, 120 | solarization_prob, 121 | min_scale, 122 | size, 123 | ) in zip( 124 | args.brightness, 125 | args.contrast, 126 | args.saturation, 127 | args.hue, 128 | args.gaussian_prob, 129 | args.solarization_prob, 130 | args.min_scale, 131 | args.size, 132 | ) 133 | ] 134 | 135 | elif not args.multicrop: 136 | args.transform_kwargs = dict( 137 | brightness=args.brightness[0], 138 | contrast=args.contrast[0], 139 | saturation=args.saturation[0], 140 | hue=args.hue[0], 141 | gaussian_prob=args.gaussian_prob[0], 142 | solarization_prob=args.solarization_prob[0], 143 | min_scale=args.min_scale[0], 144 | size=args.size[0], 145 | ) 146 | else: 147 | args.transform_kwargs = dict( 148 | brightness=args.brightness[0], 149 | contrast=args.contrast[0], 150 | saturation=args.saturation[0], 151 | hue=args.hue[0], 152 | gaussian_prob=args.gaussian_prob[0], 153 | solarization_prob=args.solarization_prob[0], 154 | ) 155 | 156 | # add support for custom mean and std 157 | if args.dataset == "custom": 158 | if isinstance(args.transform_kwargs, dict): 159 | args.transform_kwargs["mean"] = args.mean 160 | args.transform_kwargs["std"] = args.std 161 | else: 162 | for kwargs in args.transform_kwargs: 163 | kwargs["mean"] = args.mean 164 | kwargs["std"] = args.std 165 | 166 | if args.dataset in ["cifar10", "cifar100", "stl10"]: 167 | if isinstance(args.transform_kwargs, dict): 168 | del args.transform_kwargs["size"] 169 | else: 170 | for kwargs in args.transform_kwargs: 171 | del kwargs["size"] 172 | 173 | # create backbone-specific arguments 174 | args.backbone_args = {"cifar": True if args.dataset in ["cifar10", "cifar100"] else False} 175 | if "resnet" in args.encoder: 176 | args.backbone_args["zero_init_residual"] = args.zero_init_residual 177 | else: 178 | # dataset related for all transformers 179 | dataset = args.dataset 180 | if "cifar" in dataset: 181 | 
args.backbone_args["img_size"] = 32 182 | elif "stl" in dataset: 183 | args.backbone_args["img_size"] = 96 184 | elif "imagenet" in dataset: 185 | args.backbone_args["img_size"] = 224 186 | elif "custom" in dataset: 187 | transform_kwargs = args.transform_kwargs 188 | if isinstance(transform_kwargs, list): 189 | args.backbone_args["img_size"] = transform_kwargs[0]["size"] 190 | else: 191 | args.backbone_args["img_size"] = transform_kwargs["size"] 192 | 193 | if "vit" in args.encoder: 194 | args.backbone_args["patch_size"] = args.patch_size 195 | 196 | del args.zero_init_residual 197 | del args.patch_size 198 | 199 | if args.dali: 200 | assert args.dataset in ["imagenet100", "imagenet", "custom"] 201 | 202 | args.extra_optimizer_args = {} 203 | if args.optimizer == "sgd": 204 | args.extra_optimizer_args["momentum"] = 0.9 205 | 206 | if isinstance(args.gpus, int): 207 | args.gpus = [args.gpus] 208 | elif isinstance(args.gpus, str): 209 | args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu] 210 | 211 | # adjust lr according to batch size 212 | args.lr = args.lr * args.batch_size * len(args.gpus) / 256 213 | 214 | 215 | def additional_setup_linear(args: Namespace): 216 | """Provides final setup for linear evaluation to non-user given parameters by changing args. 217 | 218 | Parsers arguments to extract the number of classes of a dataset, correctly parse gpus, identify 219 | if a cifar dataset is being used and adjust the lr. 220 | 221 | Args: 222 | args: Namespace object that needs to contain, at least: 223 | - dataset: dataset name. 224 | - optimizer: optimizer name being used. 225 | - gpus: list of gpus to use. 226 | - lr: learning rate. 227 | """ 228 | 229 | assert args.dataset in N_CLASSES_PER_DATASET 230 | args.num_classes = N_CLASSES_PER_DATASET[args.dataset] 231 | 232 | # create backbone-specific arguments 233 | args.backbone_args = {"cifar": True if args.dataset in ["cifar10", "cifar100"] else False} 234 | 235 | if "resnet" not in args.encoder: 236 | # dataset related for all transformers 237 | dataset = args.dataset 238 | if "cifar" in dataset: 239 | args.backbone_args["img_size"] = 32 240 | elif "stl" in dataset: 241 | args.backbone_args["img_size"] = 96 242 | elif "imagenet" in dataset: 243 | args.backbone_args["img_size"] = 224 244 | elif "custom" in dataset: 245 | transform_kwargs = args.transform_kwargs 246 | if isinstance(transform_kwargs, list): 247 | args.backbone_args["img_size"] = transform_kwargs[0]["size"] 248 | else: 249 | args.backbone_args["img_size"] = transform_kwargs["size"] 250 | 251 | if "vit" in args.encoder: 252 | args.backbone_args["patch_size"] = args.patch_size 253 | 254 | del args.patch_size 255 | 256 | if args.dali: 257 | assert args.dataset in ["imagenet100", "imagenet"] 258 | 259 | args.extra_optimizer_args = {} 260 | if args.optimizer == "sgd": 261 | args.extra_optimizer_args["momentum"] = 0.9 262 | 263 | if isinstance(args.gpus, int): 264 | args.gpus = [args.gpus] 265 | elif isinstance(args.gpus, str): 266 | args.gpus = [int(gpu) for gpu in args.gpus.split(",") if gpu] 267 | -------------------------------------------------------------------------------- /solo/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.losses.moco import moco_loss_func 21 | from solo.losses.dual_temperature_loss import dual_temperature_loss_func 22 | 23 | __all__ = [ 24 | "moco_loss_func", 25 | "dual_temperature_loss_func", 26 | ] 27 | -------------------------------------------------------------------------------- /solo/losses/dual_temperature_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | def dual_temperature_loss_func( 24 | query: torch.Tensor, 25 | key: torch.Tensor, 26 | temperature=0.1, 27 | dt_m=10, 28 | ) -> torch.Tensor: 29 | """ 30 | query: anchor sample. 31 | key: positive sample. 32 | temperature: intra-anchor hardness-awareness control temperature. 33 | dt_m: the scalar number to get inter-anchor hardness awareness temperature. 
34 | inter-anchor hardness awareness temperature is calculated by dt_m * temperature 35 | """ 36 | 37 | # intra-anchor hardness-awareness 38 | b = query.size(0) 39 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1) 40 | 41 | # Selecte the intra negative samples according the updata time, 42 | neg = torch.einsum("nc,ck->nk", [query, key.T]) 43 | mask_neg = torch.ones_like(neg, dtype=bool) 44 | mask_neg.fill_diagonal_(False) 45 | neg = neg[mask_neg].reshape(neg.size(0), neg.size(1)-1) 46 | logits = torch.cat([pos, neg], dim=1) 47 | 48 | logits_intra = logits / temperature 49 | prob_intra = F.softmax(logits_intra, dim=1) 50 | 51 | # inter-anchor hardness-awareness 52 | logits_inter = logits / (temperature*dt_m) 53 | prob_inter = F.softmax(logits_inter, dim=1) 54 | 55 | # hardness-awareness factor 56 | inter_intra = (1 - prob_inter[:, 0]) / (1 - prob_intra[:, 0]) 57 | 58 | loss = -torch.nn.functional.log_softmax(logits_intra, dim=-1)[:, 0] 59 | 60 | # final loss 61 | loss = inter_intra.detach() * loss 62 | loss = loss.mean() 63 | 64 | return loss 65 | -------------------------------------------------------------------------------- /solo/losses/moco.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def moco_loss_func( 25 | query: torch.Tensor, key: torch.Tensor, queue: torch.Tensor, temperature=0.1 26 | ) -> torch.Tensor: 27 | """Computes MoCo's loss given a batch of queries from view 1, a batch of keys from view 2 and a 28 | queue of past elements. 29 | 30 | Args: 31 | query (torch.Tensor): NxD Tensor containing the queries from view 1. 32 | key (torch.Tensor): NxD Tensor containing the queries from view 2. 33 | queue (torch.Tensor): a queue of negative samples for the contrastive loss. 34 | temperature (float, optional): [description]. temperature of the softmax in the contrastive 35 | loss. Defaults to 0.1. 36 | 37 | Returns: 38 | torch.Tensor: MoCo loss. 
39 | """ 40 | 41 | pos = torch.einsum("nc,nc->n", [query, key]).unsqueeze(-1) 42 | neg = torch.einsum("nc,ck->nk", [query, queue]) 43 | logits = torch.cat([pos, neg], dim=1) 44 | logits /= temperature 45 | targets = torch.zeros(query.size(0), device=query.device, dtype=torch.long) 46 | return F.cross_entropy(logits, targets) 47 | -------------------------------------------------------------------------------- /solo/methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | from solo.methods.base import BaseMethod 22 | from solo.methods.mocov2plus import MoCoV2Plus 23 | 24 | # dual temperature method 25 | from solo.methods.simco_dual_temperature import SimCo_DualTemperature 26 | from solo.methods.simmoco_dual_temperature import SimMoCo_DualTemperature 27 | from solo.methods.mocov2 import MoCoV2 28 | 29 | METHODS = { 30 | # base classes 31 | "base": BaseMethod, 32 | # methods 33 | "mocov2plus": MoCoV2Plus, 34 | 35 | "simco_dual_temperature": SimCo_DualTemperature, 36 | "simmoco_dual_temperature": SimMoCo_DualTemperature, 37 | "mocov2": MoCoV2, 38 | 39 | } 40 | 41 | __all__ = [ 42 | "BaseMethod", 43 | "MoCoV2Plus", 44 | "SimCo_DualTemperature", 45 | "SimMoCo_DualTemperature", 46 | "MoCoV2", 47 | ] 48 | 49 | try: 50 | from solo.methods import dali # noqa: F401 51 | except ImportError: 52 | pass 53 | else: 54 | __all__.append("dali") 55 | -------------------------------------------------------------------------------- /solo/methods/dali.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import math 21 | from abc import ABC 22 | from pathlib import Path 23 | from typing import List 24 | 25 | import torch 26 | import torch.nn as nn 27 | from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy 28 | from solo.utils.dali_dataloader import ( 29 | CustomNormalPipeline, 30 | CustomTransform, 31 | ImagenetTransform, 32 | MulticropPretrainPipeline, 33 | NormalPipeline, 34 | PretrainPipeline, 35 | ) 36 | 37 | 38 | class BaseWrapper(DALIGenericIterator): 39 | """Temporary fix to handle LastBatchPolicy.DROP.""" 40 | 41 | def __len__(self): 42 | size = ( 43 | self._size_no_pad // self._shards_num 44 | if self._last_batch_policy == LastBatchPolicy.DROP 45 | else self.size 46 | ) 47 | if self._reader_name: 48 | if self._last_batch_policy != LastBatchPolicy.DROP: 49 | return math.ceil(size / self.batch_size) 50 | else: 51 | return size // self.batch_size 52 | else: 53 | if self._last_batch_policy != LastBatchPolicy.DROP: 54 | return math.ceil(size / (self._num_gpus * self.batch_size)) 55 | else: 56 | return size // (self._num_gpus * self.batch_size) 57 | 58 | 59 | class PretrainWrapper(BaseWrapper): 60 | def __init__( 61 | self, 62 | model_batch_size: int, 63 | model_rank: int, 64 | model_device: str, 65 | conversion_map: List[int] = None, 66 | *args, 67 | **kwargs, 68 | ): 69 | """Adds indices to a batch fetched from the parent. 70 | 71 | Args: 72 | model_batch_size (int): batch size. 73 | model_rank (int): rank of the current process. 74 | model_device (str): id of the current device. 75 | conversion_map (List[int], optional): list of integeres that map each index 76 | to a class label. If nothing is passed, no label mapping needs to be done. 77 | Defaults to None. 78 | """ 79 | 80 | super().__init__(*args, **kwargs) 81 | self.model_batch_size = model_batch_size 82 | self.model_rank = model_rank 83 | self.model_device = model_device 84 | self.conversion_map = conversion_map 85 | if self.conversion_map is not None: 86 | self.conversion_map = torch.tensor( 87 | self.conversion_map, dtype=torch.float32, device=self.model_device 88 | ).reshape(-1, 1) 89 | self.conversion_map = nn.Embedding.from_pretrained(self.conversion_map) 90 | 91 | def __next__(self): 92 | batch = super().__next__()[0] 93 | # PyTorch Lightning does double buffering 94 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1316, 95 | # and as DALI owns the tensors it returns the content of it is trashed so the copy needs, 96 | # to be made before returning. 
97 | 98 | if self.conversion_map is not None: 99 | *all_X, indexes = [batch[v] for v in self.output_map] 100 | targets = self.conversion_map(indexes).flatten().long().detach().clone() 101 | indexes = indexes.flatten().long().detach().clone() 102 | else: 103 | *all_X, targets = [batch[v] for v in self.output_map] 104 | targets = targets.squeeze(-1).long().detach().clone() 105 | # creates dummy indexes 106 | indexes = ( 107 | ( 108 | torch.arange(self.model_batch_size, device=self.model_device) 109 | + (self.model_rank * self.model_batch_size) 110 | ) 111 | .detach() 112 | .clone() 113 | ) 114 | 115 | all_X = [x.detach().clone() for x in all_X] 116 | return [indexes, all_X, targets] 117 | 118 | 119 | class Wrapper(BaseWrapper): 120 | def __next__(self): 121 | batch = super().__next__() 122 | x, target = batch[0]["x"], batch[0]["label"] 123 | target = target.squeeze(-1).long() 124 | # PyTorch Lightning does double buffering 125 | # https://github.com/PyTorchLightning/pytorch-lightning/issues/1316, 126 | # and as DALI owns the tensors it returns the content of it is trashed so the copy needs, 127 | # to be made before returning. 128 | x = x.detach().clone() 129 | target = target.detach().clone() 130 | return x, target 131 | 132 | 133 | class PretrainABC(ABC): 134 | """Abstract pretrain class that returns a train_dataloader using dali.""" 135 | 136 | def train_dataloader(self) -> DALIGenericIterator: 137 | """Returns a train dataloader using dali. Supports multi-crop and asymmetric augmentations. 138 | 139 | Returns: 140 | DALIGenericIterator: a train dataloader in the form of a dali pipeline object wrapped 141 | with PretrainWrapper. 142 | """ 143 | 144 | device_id = self.local_rank 145 | shard_id = self.global_rank 146 | num_shards = self.trainer.world_size 147 | 148 | # get data arguments from model 149 | dali_device = self.extra_args["dali_device"] 150 | 151 | # data augmentations 152 | unique_augs = self.extra_args["unique_augs"] 153 | transform_kwargs = self.extra_args["transform_kwargs"] 154 | 155 | num_workers = self.extra_args["num_workers"] 156 | data_dir = Path(self.extra_args["data_dir"]) 157 | train_dir = Path(self.extra_args["train_dir"]) 158 | 159 | # hack to encode image indexes into the labels 160 | self.encode_indexes_into_labels = self.extra_args["encode_indexes_into_labels"] 161 | 162 | # handle custom data by creating the needed pipeline 163 | dataset = self.extra_args["dataset"] 164 | if dataset in ["imagenet100", "imagenet"]: 165 | transform_pipeline = ImagenetTransform 166 | elif dataset == "custom": 167 | transform_pipeline = CustomTransform 168 | else: 169 | raise ValueError(dataset, "is not supported, used [imagenet, imagenet100 or custom]") 170 | 171 | if self.multicrop: 172 | num_crops = [self.num_crops, self.num_small_crops] 173 | size_crops = [224, 96] 174 | min_scales = [0.14, 0.05] 175 | max_scale_crops = [1.0, 0.14] 176 | 177 | transforms = [] 178 | for size, min_scale, max_scale in zip(size_crops, min_scales, max_scale_crops): 179 | transform = transform_pipeline( 180 | device=dali_device, 181 | **transform_kwargs, 182 | size=size, 183 | min_scale=min_scale, 184 | max_scale=max_scale, 185 | ) 186 | transforms.append(transform) 187 | train_pipeline = MulticropPretrainPipeline( 188 | data_dir / train_dir, 189 | batch_size=self.batch_size, 190 | transforms=transforms, 191 | num_crops=num_crops, 192 | device=dali_device, 193 | device_id=device_id, 194 | shard_id=shard_id, 195 | num_shards=num_shards, 196 | num_threads=num_workers, 197 | 
no_labels=self.extra_args["no_labels"], 198 | encode_indexes_into_labels=self.encode_indexes_into_labels, 199 | ) 200 | output_map = [ 201 | *[f"large{i}" for i in range(num_crops[0])], 202 | *[f"small{i}" for i in range(num_crops[1])], 203 | "label", 204 | ] 205 | 206 | else: 207 | if unique_augs > 1: 208 | transform = [ 209 | transform_pipeline( 210 | device=dali_device, 211 | **kwargs, 212 | max_scale=1.0, 213 | ) 214 | for kwargs in transform_kwargs 215 | ] 216 | else: 217 | transform = transform_pipeline( 218 | device=dali_device, 219 | **transform_kwargs, 220 | max_scale=1.0, 221 | ) 222 | 223 | train_pipeline = PretrainPipeline( 224 | data_dir / train_dir, 225 | batch_size=self.batch_size, 226 | transform=transform, 227 | device=dali_device, 228 | device_id=device_id, 229 | shard_id=shard_id, 230 | num_shards=num_shards, 231 | num_threads=num_workers, 232 | no_labels=self.extra_args["no_labels"], 233 | encode_indexes_into_labels=self.encode_indexes_into_labels, 234 | ) 235 | output_map = [f"large{i}" for i in range(self.num_crops)] + ["label"] 236 | 237 | policy = LastBatchPolicy.DROP 238 | conversion_map = train_pipeline.conversion_map if self.encode_indexes_into_labels else None 239 | train_loader = PretrainWrapper( 240 | model_batch_size=self.batch_size, 241 | model_rank=device_id, 242 | model_device=self.device, 243 | conversion_map=conversion_map, 244 | pipelines=train_pipeline, 245 | output_map=output_map, 246 | reader_name="Reader", 247 | last_batch_policy=policy, 248 | auto_reset=True, 249 | ) 250 | 251 | self.dali_epoch_size = train_pipeline.epoch_size("Reader") 252 | 253 | return train_loader 254 | 255 | 256 | class ClassificationABC(ABC): 257 | """Abstract classification class that returns a train_dataloader and val_dataloader using 258 | dali.""" 259 | 260 | def train_dataloader(self) -> DALIGenericIterator: 261 | device_id = self.local_rank 262 | shard_id = self.global_rank 263 | num_shards = self.trainer.world_size 264 | 265 | num_workers = self.extra_args["num_workers"] 266 | dali_device = self.extra_args["dali_device"] 267 | data_dir = Path(self.extra_args["data_dir"]) 268 | train_dir = Path(self.extra_args["train_dir"]) 269 | 270 | # handle custom data by creating the needed pipeline 271 | dataset = self.extra_args["dataset"] 272 | if dataset in ["imagenet100", "imagenet"]: 273 | pipeline_class = NormalPipeline 274 | elif dataset == "custom": 275 | pipeline_class = CustomNormalPipeline 276 | else: 277 | raise ValueError(dataset, "is not supported, used [imagenet, imagenet100 or custom]") 278 | 279 | train_pipeline = pipeline_class( 280 | data_dir / train_dir, 281 | validation=False, 282 | batch_size=self.batch_size, 283 | device=dali_device, 284 | device_id=device_id, 285 | shard_id=shard_id, 286 | num_shards=num_shards, 287 | num_threads=num_workers, 288 | ) 289 | train_loader = Wrapper( 290 | train_pipeline, 291 | output_map=["x", "label"], 292 | reader_name="Reader", 293 | last_batch_policy=LastBatchPolicy.DROP, 294 | auto_reset=True, 295 | ) 296 | return train_loader 297 | 298 | def val_dataloader(self) -> DALIGenericIterator: 299 | device_id = self.local_rank 300 | shard_id = self.global_rank 301 | num_shards = self.trainer.world_size 302 | 303 | num_workers = self.extra_args["num_workers"] 304 | dali_device = self.extra_args["dali_device"] 305 | data_dir = Path(self.extra_args["data_dir"]) 306 | val_dir = Path(self.extra_args["val_dir"]) 307 | 308 | # handle custom data by creating the needed pipeline 309 | dataset = self.extra_args["dataset"] 310 | if dataset 
in ["imagenet100", "imagenet"]: 311 | pipeline_class = NormalPipeline 312 | elif dataset == "custom": 313 | pipeline_class = CustomNormalPipeline 314 | else: 315 | raise ValueError(dataset, "is not supported, used [imagenet, imagenet100 or custom]") 316 | 317 | val_pipeline = pipeline_class( 318 | data_dir / val_dir, 319 | validation=True, 320 | batch_size=self.batch_size, 321 | device=dali_device, 322 | device_id=device_id, 323 | shard_id=shard_id, 324 | num_shards=num_shards, 325 | num_threads=num_workers, 326 | ) 327 | 328 | val_loader = Wrapper( 329 | val_pipeline, 330 | output_map=["x", "label"], 331 | reader_name="Reader", 332 | last_batch_policy=LastBatchPolicy.PARTIAL, 333 | auto_reset=True, 334 | ) 335 | return val_loader 336 | -------------------------------------------------------------------------------- /solo/methods/mocov2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import argparse 21 | from typing import Any, Dict, List, Sequence, Tuple 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from solo.losses.moco import moco_loss_func 27 | from solo.methods.base import BaseMomentumMethod 28 | from solo.utils.momentum import initialize_momentum_params 29 | from solo.utils.misc import gather 30 | 31 | 32 | class MoCoV2(BaseMomentumMethod): 33 | queue: torch.Tensor 34 | 35 | def __init__( 36 | self, 37 | proj_output_dim: int, 38 | proj_hidden_dim: int, 39 | temperature: float, 40 | queue_size: int, 41 | **kwargs 42 | ): 43 | """Implements MoCo. 44 | 45 | Args: 46 | proj_output_dim (int): number of dimensions of projected features. 47 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 48 | temperature (float): temperature for the softmax in the contrastive loss. 49 | queue_size (int): number of samples to keep in the queue. 
50 | """ 51 | 52 | super().__init__(**kwargs) 53 | 54 | self.temperature = temperature 55 | self.queue_size = queue_size 56 | 57 | # projector 58 | self.projector = nn.Sequential( 59 | nn.Linear(self.features_dim, proj_hidden_dim), 60 | nn.ReLU(), 61 | nn.Linear(proj_hidden_dim, proj_output_dim), 62 | ) 63 | 64 | # momentum projector 65 | self.momentum_projector = nn.Sequential( 66 | nn.Linear(self.features_dim, proj_hidden_dim), 67 | nn.ReLU(), 68 | nn.Linear(proj_hidden_dim, proj_output_dim), 69 | ) 70 | initialize_momentum_params(self.projector, self.momentum_projector) 71 | 72 | # create the queue 73 | self.register_buffer("queue", torch.randn(2, proj_output_dim, queue_size)) 74 | self.queue = nn.functional.normalize(self.queue, dim=1) 75 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 76 | 77 | @staticmethod 78 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 79 | parent_parser = super(MoCoV2, MoCoV2).add_model_specific_args(parent_parser) 80 | parser = parent_parser.add_argument_group("mocov2") 81 | 82 | # projector 83 | parser.add_argument("--proj_output_dim", type=int, default=128) 84 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 85 | 86 | # parameters 87 | parser.add_argument("--temperature", type=float, default=0.1) 88 | 89 | # queue settings 90 | parser.add_argument("--queue_size", default=65536, type=int) 91 | 92 | return parent_parser 93 | 94 | @property 95 | def learnable_params(self) -> List[dict]: 96 | """Adds projector parameters together with parent's learnable parameters. 97 | 98 | Returns: 99 | List[dict]: list of learnable parameters. 100 | """ 101 | 102 | extra_learnable_params = [{"params": self.projector.parameters()}] 103 | return super().learnable_params + extra_learnable_params 104 | 105 | @property 106 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 107 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 108 | 109 | Returns: 110 | List[Tuple[Any, Any]]: list of momentum pairs. 111 | """ 112 | 113 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 114 | return super().momentum_pairs + extra_momentum_pairs 115 | 116 | @torch.no_grad() 117 | def _dequeue_and_enqueue(self, keys: torch.Tensor): 118 | """Adds new samples and removes old samples from the queue in a fifo manner. 119 | 120 | Args: 121 | keys (torch.Tensor): output features of the momentum encoder. 122 | """ 123 | 124 | batch_size = keys.shape[1] 125 | ptr = int(self.queue_ptr) # type: ignore 126 | assert self.queue_size % batch_size == 0 # for simplicity 127 | 128 | # replace the keys at ptr (dequeue and enqueue) 129 | keys = keys.permute(0, 2, 1) 130 | self.queue[:, :, ptr : ptr + batch_size] = keys 131 | ptr = (ptr + batch_size) % self.queue_size # move pointer 132 | self.queue_ptr[0] = ptr # type: ignore 133 | 134 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 135 | """Performs the forward pass of the online encoder and the online projection. 136 | 137 | Args: 138 | X (torch.Tensor): a batch of images in the tensor format. 139 | 140 | Returns: 141 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features. 
142 | """ 143 | 144 | out = super().forward(X, *args, **kwargs) 145 | q = F.normalize(self.projector(out["feats"]), dim=-1) 146 | return {**out, "q": q} 147 | 148 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 149 | """ 150 | Training step for MoCo reusing BaseMomentumMethod training step. 151 | 152 | Args: 153 | batch (Sequence[Any]): a batch of data in the 154 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops 155 | containing batches of images. 156 | batch_idx (int): index of the batch. 157 | 158 | Returns: 159 | torch.Tensor: total loss composed of MOCO loss and classification loss. 160 | 161 | """ 162 | 163 | out = super().training_step(batch, batch_idx) 164 | class_loss = out["loss"] 165 | feats1, _ = out["feats"] 166 | _, momentum_feats2 = out["momentum_feats"] 167 | 168 | q1 = self.projector(feats1) 169 | q1 = F.normalize(q1, dim=-1) 170 | 171 | with torch.no_grad(): 172 | k2 = self.momentum_projector(momentum_feats2) 173 | k2 = F.normalize(k2, dim=-1) 174 | 175 | # ------- contrastive loss ------- 176 | # symmetric 177 | queue = self.queue.clone().detach() 178 | nce_loss = moco_loss_func(q1, k2, queue[1], self.temperature) 179 | 180 | # ------- update queue ------- 181 | keys = torch.stack((torch.zeros_like(gather(k2)), gather(k2))) 182 | self._dequeue_and_enqueue(keys) 183 | 184 | self.log("train_nce_loss", nce_loss, on_epoch=True, sync_dist=True) 185 | 186 | return nce_loss + class_loss 187 | -------------------------------------------------------------------------------- /solo/methods/mocov2plus.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import argparse 21 | from typing import Any, Dict, List, Sequence, Tuple 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from solo.losses.moco import moco_loss_func 27 | from solo.methods.base import BaseMomentumMethod 28 | from solo.utils.momentum import initialize_momentum_params 29 | from solo.utils.misc import gather 30 | 31 | 32 | class MoCoV2Plus(BaseMomentumMethod): 33 | queue: torch.Tensor 34 | 35 | def __init__( 36 | self, 37 | proj_output_dim: int, 38 | proj_hidden_dim: int, 39 | temperature: float, 40 | queue_size: int, 41 | **kwargs 42 | ): 43 | """Implements MoCo V2+ (https://arxiv.org/abs/2011.10566). 
44 | 45 | Args: 46 | proj_output_dim (int): number of dimensions of projected features. 47 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 48 | temperature (float): temperature for the softmax in the contrastive loss. 49 | queue_size (int): number of samples to keep in the queue. 50 | """ 51 | 52 | super().__init__(**kwargs) 53 | 54 | self.temperature = temperature 55 | self.queue_size = queue_size 56 | 57 | # projector 58 | self.projector = nn.Sequential( 59 | nn.Linear(self.features_dim, proj_hidden_dim), 60 | nn.ReLU(), 61 | nn.Linear(proj_hidden_dim, proj_output_dim), 62 | ) 63 | 64 | # momentum projector 65 | self.momentum_projector = nn.Sequential( 66 | nn.Linear(self.features_dim, proj_hidden_dim), 67 | nn.ReLU(), 68 | nn.Linear(proj_hidden_dim, proj_output_dim), 69 | ) 70 | initialize_momentum_params(self.projector, self.momentum_projector) 71 | 72 | # create the queue 73 | self.register_buffer("queue", torch.randn(2, proj_output_dim, queue_size)) 74 | self.queue = nn.functional.normalize(self.queue, dim=1) 75 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 76 | 77 | @staticmethod 78 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 79 | parent_parser = super(MoCoV2Plus, MoCoV2Plus).add_model_specific_args(parent_parser) 80 | parser = parent_parser.add_argument_group("mocov2plus") 81 | 82 | # projector 83 | parser.add_argument("--proj_output_dim", type=int, default=128) 84 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 85 | 86 | # parameters 87 | parser.add_argument("--temperature", type=float, default=0.1) 88 | 89 | # queue settings 90 | parser.add_argument("--queue_size", default=65536, type=int) 91 | 92 | return parent_parser 93 | 94 | @property 95 | def learnable_params(self) -> List[dict]: 96 | """Adds projector parameters together with parent's learnable parameters. 97 | 98 | Returns: 99 | List[dict]: list of learnable parameters. 100 | """ 101 | 102 | extra_learnable_params = [{"params": self.projector.parameters()}] 103 | return super().learnable_params + extra_learnable_params 104 | 105 | @property 106 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 107 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 108 | 109 | Returns: 110 | List[Tuple[Any, Any]]: list of momentum pairs. 111 | """ 112 | 113 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 114 | return super().momentum_pairs + extra_momentum_pairs 115 | 116 | @torch.no_grad() 117 | def _dequeue_and_enqueue(self, keys: torch.Tensor): 118 | """Adds new samples and removes old samples from the queue in a fifo manner. 119 | 120 | Args: 121 | keys (torch.Tensor): output features of the momentum encoder. 122 | """ 123 | 124 | batch_size = keys.shape[1] 125 | ptr = int(self.queue_ptr) # type: ignore 126 | assert self.queue_size % batch_size == 0 # for simplicity 127 | 128 | # replace the keys at ptr (dequeue and enqueue) 129 | keys = keys.permute(0, 2, 1) 130 | self.queue[:, :, ptr : ptr + batch_size] = keys 131 | ptr = (ptr + batch_size) % self.queue_size # move pointer 132 | self.queue_ptr[0] = ptr # type: ignore 133 | 134 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 135 | """Performs the forward pass of the online encoder and the online projection. 136 | 137 | Args: 138 | X (torch.Tensor): a batch of images in the tensor format. 
139 | 140 | Returns: 141 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features. 142 | """ 143 | 144 | out = super().forward(X, *args, **kwargs) 145 | q = F.normalize(self.projector(out["feats"]), dim=-1) 146 | return {**out, "q": q} 147 | 148 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 149 | """ 150 | Training step for MoCo reusing BaseMomentumMethod training step. 151 | 152 | Args: 153 | batch (Sequence[Any]): a batch of data in the 154 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops 155 | containing batches of images. 156 | batch_idx (int): index of the batch. 157 | 158 | Returns: 159 | torch.Tensor: total loss composed of MOCO loss and classification loss. 160 | 161 | """ 162 | 163 | out = super().training_step(batch, batch_idx) 164 | class_loss = out["loss"] 165 | feats1, feats2 = out["feats"] 166 | momentum_feats1, momentum_feats2 = out["momentum_feats"] 167 | 168 | q1 = self.projector(feats1) 169 | q2 = self.projector(feats2) 170 | q1 = F.normalize(q1, dim=-1) 171 | q2 = F.normalize(q2, dim=-1) 172 | 173 | with torch.no_grad(): 174 | k1 = self.momentum_projector(momentum_feats1) 175 | k2 = self.momentum_projector(momentum_feats2) 176 | k1 = F.normalize(k1, dim=-1) 177 | k2 = F.normalize(k2, dim=-1) 178 | 179 | # ------- contrastive loss ------- 180 | # symmetric 181 | queue = self.queue.clone().detach() 182 | nce_loss = ( 183 | moco_loss_func(q1, k2, queue[1], self.temperature) 184 | + moco_loss_func(q2, k1, queue[0], self.temperature) 185 | ) / 2 186 | 187 | # ------- update queue ------- 188 | keys = torch.stack((gather(k1), gather(k2))) 189 | self._dequeue_and_enqueue(keys) 190 | 191 | self.log("train_nce_loss", nce_loss, on_epoch=True, sync_dist=True) 192 | 193 | return nce_loss + class_loss 194 | -------------------------------------------------------------------------------- /solo/methods/simco_dual_temperature.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
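# A minimal, self-contained sketch of the queue-based InfoNCE that the
# moco_loss_func(q, k, queue, temperature) calls in MoCoV2 / MoCoV2Plus above rely on.
# This is the standard MoCo formulation (positive logit at index 0, negatives drawn from
# the queue); the exact implementation used by the repo lives in solo/losses/moco.py and
# may differ in details.
import torch
import torch.nn.functional as F


def moco_infonce_sketch(q, k, queue, temperature=0.1):
    # q, k: (batch, dim) L2-normalized online / momentum projections; queue: (dim, queue_size)
    pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)   # (batch, 1) positive logits
    neg = torch.einsum("nc,ck->nk", q, queue)            # (batch, queue_size) negative logits
    logits = torch.cat([pos, neg], dim=1) / temperature
    labels = torch.zeros(q.size(0), dtype=torch.long, device=q.device)  # positives sit at index 0
    return F.cross_entropy(logits, labels)


q = F.normalize(torch.randn(8, 128), dim=-1)
k = F.normalize(torch.randn(8, 128), dim=-1)
queue = F.normalize(torch.randn(128, 1024), dim=0)  # toy stand-in for one view of the queue buffer
print(moco_infonce_sketch(q, k, queue))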
19 | 20 | import argparse 21 | from typing import Any, Dict, List, Sequence, Tuple 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from solo.methods.base import BaseMomentumMethod 27 | from solo.utils.momentum import initialize_momentum_params 28 | from solo.losses.dual_temperature_loss import dual_temperature_loss_func 29 | 30 | 31 | class SimCo_DualTemperature(BaseMomentumMethod): 32 | queue: torch.Tensor 33 | 34 | def __init__( 35 | self, 36 | proj_output_dim: int, 37 | proj_hidden_dim: int, 38 | temperature: float, 39 | dt_m: float, 40 | **kwargs 41 | ): 42 | """Implements SimCo with dual temperature. 43 | 44 | Args: 45 | proj_output_dim (int): number of dimensions of projected features. 46 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 47 | temperature (float): temperature for the softmax in the contrastive loss. 48 | dt_m (float): multiplier applied to temperature to obtain the inter-anchor temperature. 49 | """ 50 | 51 | super().__init__(**kwargs) 52 | 53 | self.temperature = temperature 54 | self.dt_m = dt_m 55 | 56 | # projector 57 | self.projector = nn.Sequential( 58 | nn.Linear(self.features_dim, proj_hidden_dim), 59 | nn.ReLU(), 60 | nn.Linear(proj_hidden_dim, proj_output_dim), 61 | ) 62 | 63 | # momentum projector 64 | self.momentum_projector = nn.Sequential( 65 | nn.Linear(self.features_dim, proj_hidden_dim), 66 | nn.ReLU(), 67 | nn.Linear(proj_hidden_dim, proj_output_dim), 68 | ) 69 | 70 | initialize_momentum_params(self.projector, self.momentum_projector) 71 | 72 | 73 | @staticmethod 74 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 75 | parent_parser = super(SimCo_DualTemperature, SimCo_DualTemperature).add_model_specific_args(parent_parser) 76 | parser = parent_parser.add_argument_group("simco_dual_temperature") 77 | 78 | # projector 79 | parser.add_argument("--proj_output_dim", type=int, default=128) 80 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 81 | 82 | # parameters 83 | parser.add_argument("--temperature", type=float, default=0.1) 84 | parser.add_argument("--dt_m", type=float, default=10) 85 | 86 | return parent_parser 87 | 88 | @property 89 | def learnable_params(self) -> List[dict]: 90 | """Adds projector parameters together with parent's learnable parameters. 91 | 92 | Returns: 93 | List[dict]: list of learnable parameters. 94 | """ 95 | 96 | extra_learnable_params = [{"params": self.projector.parameters()}] 97 | return super().learnable_params + extra_learnable_params 98 | 99 | @property 100 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 101 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 102 | 103 | Returns: 104 | List[Tuple[Any, Any]]: list of momentum pairs. 105 | """ 106 | 107 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 108 | return super().momentum_pairs + extra_momentum_pairs 109 | 110 | 111 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 112 | """Performs the forward pass of the online encoder and the online projection. 113 | 114 | Args: 115 | X (torch.Tensor): a batch of images in the tensor format. 116 | 117 | Returns: 118 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
119 | """ 120 | 121 | out = super().forward(X, *args, **kwargs) 122 | q = F.normalize(self.projector(out["feats"]), dim=-1) 123 | return {**out, "q": q} 124 | 125 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 126 | """ 127 | Training step for MoCo reusing BaseMomentumMethod training step. 128 | 129 | Args: 130 | batch (Sequence[Any]): a batch of data in the 131 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops 132 | containing batches of images. 133 | batch_idx (int): index of the batch. 134 | 135 | Returns: 136 | torch.Tensor: total loss composed of MOCO loss and classification loss. 137 | 138 | """ 139 | 140 | out = super().training_step(batch, batch_idx) 141 | class_loss = out["loss"] 142 | feats1, feats2 = out["feats"] 143 | 144 | q1 = self.projector(feats1) 145 | q2 = self.projector(feats2) 146 | 147 | q1 = F.normalize(q1, dim=-1) 148 | q2 = F.normalize(q2, dim=-1) 149 | 150 | 151 | nce_loss = ( 152 | dual_temperature_loss_func(q1, q2, 153 | temperature=self.temperature, 154 | dt_m=self.dt_m) 155 | + dual_temperature_loss_func(q2, q1, 156 | temperature=self.temperature, 157 | dt_m=self.dt_m) 158 | ) / 2 159 | 160 | # calculate std of features 161 | z1_std = F.normalize(q1, dim=-1).std(dim=0).mean() 162 | z2_std = F.normalize(q2, dim=-1).std(dim=0).mean() 163 | z_std = (z1_std + z2_std) / 2 164 | 165 | metrics = { 166 | "train_nce_loss": nce_loss, 167 | "train_z_std": z_std, 168 | } 169 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 170 | 171 | return nce_loss + class_loss 172 | 173 | -------------------------------------------------------------------------------- /solo/methods/simmoco_dual_temperature.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
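# A minimal usage sketch (assuming the solo package from this repo is installed): the
# symmetric SimCo objective above boils down to two dual-temperature terms over the
# projected views, plus the per-dimension feature std that is logged as a collapse
# indicator. The random tensors below only stand in for the encoder + projector outputs.
import torch
import torch.nn.functional as F

from solo.losses.dual_temperature_loss import dual_temperature_loss_func

q1 = F.normalize(torch.randn(256, 128), dim=-1)  # projections of view 1
q2 = F.normalize(torch.randn(256, 128), dim=-1)  # projections of view 2

nce_loss = (
    dual_temperature_loss_func(q1, q2, temperature=0.1, dt_m=10)
    + dual_temperature_loss_func(q2, q1, temperature=0.1, dt_m=10)
) / 2

z_std = (q1.std(dim=0).mean() + q2.std(dim=0).mean()) / 2  # cheap collapse indicator
print(nce_loss.item(), z_std.item())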
19 | 20 | import argparse 21 | from typing import Any, Dict, List, Sequence, Tuple 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from solo.methods.base import BaseMomentumMethod 27 | from solo.utils.momentum import initialize_momentum_params 28 | from solo.losses.dual_temperature_loss import dual_temperature_loss_func 29 | 30 | 31 | class SimMoCo_DualTemperature(BaseMomentumMethod): 32 | queue: torch.Tensor 33 | 34 | def __init__( 35 | self, 36 | proj_output_dim: int, 37 | proj_hidden_dim: int, 38 | temperature: float, 39 | dt_m: float, 40 | plus_version: bool, 41 | **kwargs 42 | ): 43 | """Implements SimMoCo with dual temperature. 44 | 45 | Args: 46 | proj_output_dim (int): number of dimensions of projected features. 47 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 48 | temperature (float): temperature for the softmax in the contrastive loss. 49 | dt_m (float): multiplier applied to temperature to obtain the inter-anchor temperature. plus_version (bool): whether to use the symmetric (two-view) plus version of the loss. 50 | """ 51 | 52 | super().__init__(**kwargs) 53 | 54 | self.temperature = temperature 55 | self.dt_m = dt_m 56 | self.plus_version = plus_version 57 | 58 | 59 | # projector 60 | self.projector = nn.Sequential( 61 | nn.Linear(self.features_dim, proj_hidden_dim), 62 | nn.ReLU(), 63 | nn.Linear(proj_hidden_dim, proj_output_dim), 64 | ) 65 | 66 | # momentum projector 67 | self.momentum_projector = nn.Sequential( 68 | nn.Linear(self.features_dim, proj_hidden_dim), 69 | nn.ReLU(), 70 | nn.Linear(proj_hidden_dim, proj_output_dim), 71 | ) 72 | 73 | initialize_momentum_params(self.projector, self.momentum_projector) 74 | 75 | 76 | @staticmethod 77 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 78 | parent_parser = super(SimMoCo_DualTemperature, SimMoCo_DualTemperature).add_model_specific_args(parent_parser) 79 | parser = parent_parser.add_argument_group("simmoco_dual_temperature") 80 | 81 | # projector 82 | parser.add_argument("--proj_output_dim", type=int, default=128) 83 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 84 | 85 | # parameters 86 | parser.add_argument("--temperature", type=float, default=0.1) 87 | parser.add_argument("--dt_m", type=float, default=10) 88 | 89 | # train the plus version which uses symmetric loss 90 | parser.add_argument("--plus_version", action="store_true") 91 | 92 | return parent_parser 93 | 94 | @property 95 | def learnable_params(self) -> List[dict]: 96 | """Adds projector parameters together with parent's learnable parameters. 97 | 98 | Returns: 99 | List[dict]: list of learnable parameters. 100 | """ 101 | 102 | extra_learnable_params = [{"params": self.projector.parameters()}] 103 | return super().learnable_params + extra_learnable_params 104 | 105 | @property 106 | def momentum_pairs(self) -> List[Tuple[Any, Any]]: 107 | """Adds (projector, momentum_projector) to the parent's momentum pairs. 108 | 109 | Returns: 110 | List[Tuple[Any, Any]]: list of momentum pairs. 111 | """ 112 | 113 | extra_momentum_pairs = [(self.projector, self.momentum_projector)] 114 | return super().momentum_pairs + extra_momentum_pairs 115 | 116 | 117 | def forward(self, X: torch.Tensor, *args, **kwargs) -> Dict[str, Any]: 118 | """Performs the forward pass of the online encoder and the online projection. 119 | 120 | Args: 121 | X (torch.Tensor): a batch of images in the tensor format. 122 | 123 | Returns: 124 | Dict[str, Any]: a dict containing the outputs of the parent and the projected features.
125 | """ 126 | 127 | out = super().forward(X, *args, **kwargs) 128 | q = F.normalize(self.projector(out["feats"]), dim=-1) 129 | return {**out, "q": q} 130 | 131 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 132 | """ 133 | Training step for MoCo reusing BaseMomentumMethod training step. 134 | 135 | Args: 136 | batch (Sequence[Any]): a batch of data in the 137 | format of [img_indexes, [X], Y], where [X] is a list of size self.num_crops 138 | containing batches of images. 139 | batch_idx (int): index of the batch. 140 | 141 | Returns: 142 | torch.Tensor: total loss composed of MOCO loss and classification loss. 143 | 144 | """ 145 | 146 | if self.plus_version: 147 | out = super().training_step(batch, batch_idx) 148 | class_loss = out["loss"] 149 | feats1, feats2 = out["feats"] 150 | momentum_feats1, momentum_feats2 = out["momentum_feats"] 151 | 152 | q1 = self.projector(feats1) 153 | q2 = self.projector(feats2) 154 | 155 | q1 = F.normalize(q1, dim=-1) 156 | q2 = F.normalize(q2, dim=-1) 157 | with torch.no_grad(): 158 | k1 = self.momentum_projector(momentum_feats1) 159 | k2 = self.momentum_projector(momentum_feats2) 160 | k1 = F.normalize(k1, dim=-1).detach() 161 | k2 = F.normalize(k2, dim=-1).detach() 162 | 163 | 164 | nce_loss = ( 165 | dual_temperature_loss_func(q1, k2, 166 | temperature=self.temperature, 167 | dt_m=self.dt_m) 168 | + dual_temperature_loss_func(q2, k1, 169 | temperature=self.temperature, 170 | dt_m=self.dt_m) 171 | ) / 2 172 | 173 | # calculate std of features 174 | z1_std = F.normalize(q1, dim=-1).std(dim=0).mean() 175 | z2_std = F.normalize(q2, dim=-1).std(dim=0).mean() 176 | z_std = (z1_std + z2_std) / 2 177 | 178 | metrics = { 179 | "train_nce_loss": nce_loss, 180 | "train_z_std": z_std, 181 | } 182 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 183 | 184 | return nce_loss + class_loss 185 | 186 | else: 187 | out = super().training_step(batch, batch_idx) 188 | class_loss = out["loss"] 189 | feats1, _ = out["feats"] 190 | _, momentum_feats2 = out["momentum_feats"] 191 | 192 | q1 = self.projector(feats1) 193 | 194 | q1 = F.normalize(q1, dim=-1) 195 | 196 | with torch.no_grad(): 197 | k2 = self.momentum_projector(momentum_feats2) 198 | k2 = F.normalize(k2, dim=-1).detach() 199 | 200 | nce_loss = dual_temperature_loss_func(q1, k2, 201 | temperature=self.temperature, 202 | dt_m=self.dt_m) 203 | 204 | # calculate std of features 205 | z1_std = F.normalize(q1, dim=-1).std(dim=0).mean() 206 | z_std = z1_std 207 | 208 | metrics = { 209 | "train_nce_loss": nce_loss, 210 | "train_z_std": z_std, 211 | } 212 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 213 | 214 | return nce_loss + class_loss 215 | -------------------------------------------------------------------------------- /solo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from solo.utils import ( 21 | backbones, 22 | checkpointer, 23 | classification_dataloader, 24 | knn, 25 | lars, 26 | metrics, 27 | misc, 28 | momentum, 29 | pretrain_dataloader, 30 | sinkhorn_knopp, 31 | ) 32 | 33 | __all__ = [ 34 | "backbones", 35 | "classification_dataloader", 36 | "pretrain_dataloader", 37 | "checkpointer", 38 | "knn", 39 | "misc", 40 | "lars", 41 | "metrics", 42 | "momentum", 43 | "sinkhorn_knopp", 44 | ] 45 | 46 | try: 47 | from solo.utils import dali_dataloader # noqa: F401 48 | except ImportError: 49 | pass 50 | else: 51 | __all__.append("dali_dataloader") 52 | 53 | try: 54 | from solo.utils import auto_umap # noqa: F401 55 | except ImportError: 56 | pass 57 | else: 58 | __all__.append("auto_umap") 59 | -------------------------------------------------------------------------------- /solo/utils/auto_umap.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
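# Because dali_dataloader and auto_umap are imported optionally above, downstream code can
# probe solo.utils.__all__ (or catch the ImportError itself) instead of assuming the DALI /
# UMAP extras are installed. A small sketch of that check:
from solo import utils

if "dali_dataloader" in utils.__all__:
    from solo.utils import dali_dataloader  # noqa: F401  # DALI extras are available
else:
    print("NVIDIA DALI not installed; falling back to the torchvision dataloaders")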
19 | 20 | import math 21 | import os 22 | from argparse import ArgumentParser, Namespace 23 | from pathlib import Path 24 | from typing import Optional, Union 25 | 26 | import pandas as pd 27 | import pytorch_lightning as pl 28 | import seaborn as sns 29 | import torch 30 | import umap 31 | import wandb 32 | from matplotlib import pyplot as plt 33 | from pytorch_lightning.callbacks import Callback 34 | 35 | from .misc import gather 36 | 37 | 38 | class AutoUMAP(Callback): 39 | def __init__( 40 | self, 41 | args: Namespace, 42 | logdir: Union[str, Path] = Path("auto_umap"), 43 | frequency: int = 1, 44 | keep_previous: bool = False, 45 | color_palette: str = "hls", 46 | ): 47 | """UMAP callback that automatically runs UMAP on the validation dataset and uploads the 48 | figure to wandb. 49 | 50 | Args: 51 | args (Namespace): namespace object containing at least an attribute name. 52 | logdir (Union[str, Path], optional): base directory to store checkpoints. 53 | Defaults to Path("auto_umap"). 54 | frequency (int, optional): number of epochs between each UMAP. Defaults to 1. 55 | color_palette (str, optional): color scheme for the classes. Defaults to "hls". 56 | keep_previous (bool, optional): whether to keep previous plots or not. 57 | Defaults to False. 58 | """ 59 | 60 | super().__init__() 61 | 62 | self.args = args 63 | self.logdir = Path(logdir) 64 | self.frequency = frequency 65 | self.color_palette = color_palette 66 | self.keep_previous = keep_previous 67 | 68 | @staticmethod 69 | def add_auto_umap_args(parent_parser: ArgumentParser): 70 | """Adds user-required arguments to a parser. 71 | 72 | Args: 73 | parent_parser (ArgumentParser): parser to add new args to. 74 | """ 75 | 76 | parser = parent_parser.add_argument_group("auto_umap") 77 | parser.add_argument("--auto_umap_dir", default=Path("auto_umap"), type=Path) 78 | parser.add_argument("--auto_umap_frequency", default=1, type=int) 79 | return parent_parser 80 | 81 | def initial_setup(self, trainer: pl.Trainer): 82 | """Creates the directories and does the initial setup needed. 83 | 84 | Args: 85 | trainer (pl.Trainer): pytorch lightning trainer object. 86 | """ 87 | 88 | if trainer.logger is None: 89 | version = None 90 | else: 91 | version = str(trainer.logger.version) 92 | if version is not None: 93 | self.path = self.logdir / version 94 | self.umap_placeholder = f"{self.args.name}-{version}" + "-ep={}.pdf" 95 | else: 96 | self.path = self.logdir 97 | self.umap_placeholder = f"{self.args.name}" + "-ep={}.pdf" 98 | self.last_ckpt: Optional[str] = None 99 | 100 | # create logging dirs 101 | if trainer.is_global_zero: 102 | os.makedirs(self.path, exist_ok=True) 103 | 104 | def on_train_start(self, trainer: pl.Trainer, _): 105 | """Performs initial setup on training start. 106 | 107 | Args: 108 | trainer (pl.Trainer): pytorch lightning trainer object. 109 | """ 110 | 111 | self.initial_setup(trainer) 112 | 113 | def plot(self, trainer: pl.Trainer, module: pl.LightningModule): 114 | """Produces a UMAP visualization by forwarding all data of the 115 | first validation dataloader through the module. 116 | 117 | Args: 118 | trainer (pl.Trainer): pytorch lightning trainer object. 119 | module (pl.LightningModule): current module object. 
120 | """ 121 | 122 | device = module.device 123 | data = [] 124 | Y = [] 125 | 126 | # set module to eval model and collect all feature representations 127 | module.eval() 128 | with torch.no_grad(): 129 | for x, y in trainer.val_dataloaders[0]: 130 | x = x.to(device, non_blocking=True) 131 | y = y.to(device, non_blocking=True) 132 | 133 | feats = module(x)["feats"] 134 | 135 | feats = gather(feats) 136 | y = gather(y) 137 | 138 | data.append(feats.cpu()) 139 | Y.append(y.cpu()) 140 | module.train() 141 | 142 | if trainer.is_global_zero and len(data): 143 | data = torch.cat(data, dim=0).numpy() 144 | Y = torch.cat(Y, dim=0) 145 | num_classes = len(torch.unique(Y)) 146 | Y = Y.numpy() 147 | 148 | data = umap.UMAP(n_components=2).fit_transform(data) 149 | 150 | # passing to dataframe 151 | df = pd.DataFrame() 152 | df["feat_1"] = data[:, 0] 153 | df["feat_2"] = data[:, 1] 154 | df["Y"] = Y 155 | plt.figure(figsize=(9, 9)) 156 | ax = sns.scatterplot( 157 | x="feat_1", 158 | y="feat_2", 159 | hue="Y", 160 | palette=sns.color_palette(self.color_palette, num_classes), 161 | data=df, 162 | legend="full", 163 | alpha=0.3, 164 | ) 165 | ax.set(xlabel="", ylabel="", xticklabels=[], yticklabels=[]) 166 | ax.tick_params(left=False, right=False, bottom=False, top=False) 167 | 168 | # manually improve quality of imagenet umaps 169 | if num_classes > 100: 170 | anchor = (0.5, 1.8) 171 | else: 172 | anchor = (0.5, 1.35) 173 | 174 | plt.legend(loc="upper center", bbox_to_anchor=anchor, ncol=math.ceil(num_classes / 10)) 175 | plt.tight_layout() 176 | 177 | if isinstance(trainer.logger, pl.loggers.WandbLogger): 178 | wandb.log( 179 | {"validation_umap": wandb.Image(ax)}, 180 | commit=False, 181 | ) 182 | 183 | # save plot locally as well 184 | epoch = trainer.current_epoch # type: ignore 185 | plt.savefig(self.path / self.umap_placeholder.format(epoch)) 186 | plt.close() 187 | 188 | def on_validation_end(self, trainer: pl.Trainer, module: pl.LightningModule): 189 | """Tries to generate an up-to-date UMAP visualization of the features 190 | at the end of each validation epoch. 191 | 192 | Args: 193 | trainer (pl.Trainer): pytorch lightning trainer object. 194 | """ 195 | 196 | epoch = trainer.current_epoch # type: ignore 197 | if epoch % self.frequency == 0 and not trainer.sanity_checking: 198 | self.plot(trainer, module) 199 | -------------------------------------------------------------------------------- /solo/utils/backbones.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | # Copy-pasted from timm (https://github.com/rwightman/pytorch-image-models/blob/master/timm/), 21 | # but allowing different window sizes. 22 | 23 | 24 | from timm.models.swin_transformer import _create_swin_transformer, register_model 25 | from timm.models.vision_transformer import _create_vision_transformer 26 | 27 | 28 | @register_model 29 | def swin_tiny(window_size=7, **kwargs): 30 | model_kwargs = dict( 31 | patch_size=4, 32 | window_size=window_size, 33 | embed_dim=96, 34 | depths=(2, 2, 6, 2), 35 | num_heads=(3, 6, 12, 24), 36 | num_classes=0, 37 | **kwargs, 38 | ) 39 | return _create_swin_transformer("swin_tiny_patch4_window7_224", **model_kwargs) 40 | 41 | 42 | @register_model 43 | def swin_small(window_size=7, **kwargs): 44 | model_kwargs = dict( 45 | patch_size=4, 46 | window_size=window_size, 47 | embed_dim=96, 48 | depths=(2, 2, 18, 2), 49 | num_heads=(3, 6, 12, 24), 50 | num_classes=0, 51 | **kwargs, 52 | ) 53 | return _create_swin_transformer( 54 | "swin_small_patch4_window7_224", pretrained=False, **model_kwargs 55 | ) 56 | 57 | 58 | @register_model 59 | def swin_base(window_size=7, **kwargs): 60 | model_kwargs = dict( 61 | patch_size=4, 62 | window_size=window_size, 63 | embed_dim=128, 64 | depths=(2, 2, 18, 2), 65 | num_heads=(4, 8, 16, 32), 66 | num_classes=0, 67 | **kwargs, 68 | ) 69 | return _create_swin_transformer( 70 | "swin_base_patch4_window7_224", pretrained=False, **model_kwargs 71 | ) 72 | 73 | 74 | @register_model 75 | def swin_large(window_size=7, **kwargs): 76 | model_kwargs = dict( 77 | patch_size=4, 78 | window_size=window_size, 79 | embed_dim=192, 80 | depths=(2, 2, 18, 2), 81 | num_heads=(6, 12, 24, 48), 82 | num_classes=0, 83 | **kwargs, 84 | ) 85 | return _create_swin_transformer( 86 | "swin_large_patch4_window7_224", pretrained=False, **model_kwargs 87 | ) 88 | 89 | 90 | @register_model 91 | def vit_tiny(patch_size=16, **kwargs): 92 | """ViT-Tiny (Vit-Ti/16)""" 93 | model_kwargs = dict( 94 | patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, num_classes=0, **kwargs 95 | ) 96 | model = _create_vision_transformer("vit_tiny_patch16_224", pretrained=False, **model_kwargs) 97 | return model 98 | 99 | 100 | @register_model 101 | def vit_small(patch_size=16, **kwargs): 102 | model_kwargs = dict( 103 | patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, num_classes=0, **kwargs 104 | ) 105 | model = _create_vision_transformer("vit_small_patch16_224", pretrained=False, **model_kwargs) 106 | return model 107 | 108 | 109 | @register_model 110 | def vit_base(patch_size=16, **kwargs): 111 | model_kwargs = dict( 112 | patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, num_classes=0, **kwargs 113 | ) 114 | model = _create_vision_transformer("vit_base_patch16_224", pretrained=False, **model_kwargs) 115 | return model 116 | 117 | 118 | @register_model 119 | def vit_large(patch_size=16, **kwargs): 120 | model_kwargs = dict( 121 | patch_size=patch_size, embed_dim=1024, depth=24, num_heads=16, num_classes=0, **kwargs 122 | ) 123 | model = _create_vision_transformer("vit_large_patch16_224", pretrained=False, **model_kwargs) 124 | return model 125 | -------------------------------------------------------------------------------- /solo/utils/checkpointer.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import json 21 | import os 22 | from argparse import ArgumentParser, Namespace 23 | from pathlib import Path 24 | from typing import Optional, Union 25 | 26 | import pytorch_lightning as pl 27 | from pytorch_lightning.callbacks import Callback 28 | 29 | 30 | class Checkpointer(Callback): 31 | def __init__( 32 | self, 33 | args: Namespace, 34 | logdir: Union[str, Path] = Path("trained_models"), 35 | frequency: int = 1, 36 | keep_previous_checkpoints: bool = False, 37 | ): 38 | """Custom checkpointer callback that stores checkpoints in an easier to access way. 39 | 40 | Args: 41 | args (Namespace): namespace object containing at least an attribute name. 42 | logdir (Union[str, Path], optional): base directory to store checkpoints. 43 | Defaults to "trained_models". 44 | frequency (int, optional): number of epochs between each checkpoint. Defaults to 1. 45 | keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not. 46 | Defaults to False. 47 | """ 48 | 49 | super().__init__() 50 | 51 | self.args = args 52 | self.logdir = Path(logdir) 53 | self.frequency = frequency 54 | self.keep_previous_checkpoints = keep_previous_checkpoints 55 | 56 | @staticmethod 57 | def add_checkpointer_args(parent_parser: ArgumentParser): 58 | """Adds user-required arguments to a parser. 59 | 60 | Args: 61 | parent_parser (ArgumentParser): parser to add new args to. 62 | """ 63 | 64 | parser = parent_parser.add_argument_group("checkpointer") 65 | parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path) 66 | parser.add_argument("--checkpoint_frequency", default=1, type=int) 67 | return parent_parser 68 | 69 | def initial_setup(self, trainer: pl.Trainer): 70 | """Creates the directories and does the initial setup needed. 71 | 72 | Args: 73 | trainer (pl.Trainer): pytorch lightning trainer object. 
74 | """ 75 | 76 | if trainer.logger is None: 77 | version = None 78 | else: 79 | version = str(trainer.logger.version) 80 | if version is not None: 81 | self.path = self.logdir / version 82 | self.ckpt_placeholder = f"{self.args.name}-{version}" + "-ep={}.ckpt" 83 | else: 84 | self.path = self.logdir 85 | self.ckpt_placeholder = f"{self.args.name}" + "-ep={}.ckpt" 86 | self.last_ckpt: Optional[str] = None 87 | 88 | # create logging dirs 89 | if trainer.is_global_zero: 90 | os.makedirs(self.path, exist_ok=True) 91 | 92 | def save_args(self, trainer: pl.Trainer): 93 | """Stores arguments into a json file. 94 | 95 | Args: 96 | trainer (pl.Trainer): pytorch lightning trainer object. 97 | """ 98 | 99 | if trainer.is_global_zero: 100 | args = vars(self.args) 101 | json_path = self.path / "args.json" 102 | json.dump(args, open(json_path, "w"), default=lambda o: "") 103 | 104 | def save(self, trainer: pl.Trainer): 105 | """Saves current checkpoint. 106 | 107 | Args: 108 | trainer (pl.Trainer): pytorch lightning trainer object. 109 | """ 110 | 111 | if trainer.is_global_zero and not trainer.sanity_checking: 112 | epoch = trainer.current_epoch # type: ignore 113 | ckpt = self.path / self.ckpt_placeholder.format(epoch) 114 | trainer.save_checkpoint(ckpt) 115 | 116 | if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints: 117 | os.remove(self.last_ckpt) 118 | self.last_ckpt = ckpt 119 | 120 | def on_train_start(self, trainer: pl.Trainer, _): 121 | """Executes initial setup and saves arguments. 122 | 123 | Args: 124 | trainer (pl.Trainer): pytorch lightning trainer object. 125 | """ 126 | 127 | self.initial_setup(trainer) 128 | self.save_args(trainer) 129 | 130 | def on_validation_end(self, trainer: pl.Trainer, _): 131 | """Tries to save current checkpoint at the end of each validation epoch. 132 | 133 | Args: 134 | trainer (pl.Trainer): pytorch lightning trainer object. 135 | """ 136 | 137 | epoch = trainer.current_epoch # type: ignore 138 | if epoch % self.frequency == 0: 139 | self.save(trainer) 140 | -------------------------------------------------------------------------------- /solo/utils/classification_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
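# The Checkpointer above is a regular PyTorch Lightning Callback, so it is attached to the
# Trainer like any other callback; in this repo the wiring happens inside main_pretrain.py.
# A minimal sketch (the Namespace only needs the `name` attribute that initial_setup reads,
# and the model / dataloaders are assumed to come from the rest of the repo):
from argparse import Namespace
from pathlib import Path

import pytorch_lightning as pl

from solo.utils.checkpointer import Checkpointer

args = Namespace(name="simco-cifar10")
ckpt = Checkpointer(args, logdir=Path("trained_models"), frequency=1)
trainer = pl.Trainer(max_epochs=1000, callbacks=[ckpt])
# trainer.fit(model, train_loader, val_loader)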
19 | 20 | import os 21 | from pathlib import Path 22 | from typing import Callable, Optional, Tuple, Union 23 | 24 | import torchvision 25 | from torch import nn 26 | from torch.utils.data import DataLoader, Dataset 27 | from torchvision import transforms 28 | from torchvision.datasets import STL10, ImageFolder 29 | 30 | 31 | def build_custom_pipeline(): 32 | """Builds augmentation pipelines for custom data. 33 | If you want to do exoteric augmentations, you can just re-write this function. 34 | Needs to return a dict with the same structure. 35 | """ 36 | 37 | pipeline = { 38 | "T_train": transforms.Compose( 39 | [ 40 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)), 41 | transforms.RandomHorizontalFlip(), 42 | transforms.ToTensor(), 43 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)), 44 | ] 45 | ), 46 | "T_val": transforms.Compose( 47 | [ 48 | transforms.Resize(256), # resize shorter 49 | transforms.CenterCrop(224), # take center crop 50 | transforms.ToTensor(), 51 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)), 52 | ] 53 | ), 54 | } 55 | return pipeline 56 | 57 | 58 | def prepare_transforms(dataset: str) -> Tuple[nn.Module, nn.Module]: 59 | """Prepares pre-defined train and test transformation pipelines for some datasets. 60 | 61 | Args: 62 | dataset (str): dataset name. 63 | 64 | Returns: 65 | Tuple[nn.Module, nn.Module]: training and validation transformation pipelines. 66 | """ 67 | 68 | cifar_pipeline = { 69 | "T_train": transforms.Compose( 70 | [ 71 | transforms.RandomResizedCrop(size=32, scale=(0.08, 1.0)), 72 | transforms.RandomHorizontalFlip(), 73 | transforms.ToTensor(), 74 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), 75 | ] 76 | ), 77 | "T_val": transforms.Compose( 78 | [ 79 | transforms.ToTensor(), 80 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), 81 | ] 82 | ), 83 | } 84 | 85 | stl_pipeline = { 86 | "T_train": transforms.Compose( 87 | [ 88 | transforms.RandomResizedCrop(size=96, scale=(0.08, 1.0)), 89 | transforms.RandomHorizontalFlip(), 90 | transforms.ToTensor(), 91 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)), 92 | ] 93 | ), 94 | "T_val": transforms.Compose( 95 | [ 96 | transforms.Resize((96, 96)), 97 | transforms.ToTensor(), 98 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)), 99 | ] 100 | ), 101 | } 102 | 103 | imagenet_pipeline = { 104 | "T_train": transforms.Compose( 105 | [ 106 | transforms.RandomResizedCrop(size=224, scale=(0.08, 1.0)), 107 | transforms.RandomHorizontalFlip(), 108 | transforms.ToTensor(), 109 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)), 110 | ] 111 | ), 112 | "T_val": transforms.Compose( 113 | [ 114 | transforms.Resize(256), # resize shorter 115 | transforms.CenterCrop(224), # take center crop 116 | transforms.ToTensor(), 117 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)), 118 | ] 119 | ), 120 | } 121 | 122 | custom_pipeline = build_custom_pipeline() 123 | 124 | pipelines = { 125 | "cifar10": cifar_pipeline, 126 | "cifar100": cifar_pipeline, 127 | "stl10": stl_pipeline, 128 | "imagenet100": imagenet_pipeline, 129 | "imagenet": imagenet_pipeline, 130 | "custom": custom_pipeline, 131 | } 132 | 133 | assert dataset in pipelines 134 | 135 | pipeline = pipelines[dataset] 136 | T_train = pipeline["T_train"] 137 | T_val = pipeline["T_val"] 138 | 139 | return T_train, T_val 140 | 141 | 142 | def prepare_datasets( 143 | 
dataset: str, 144 | T_train: Callable, 145 | T_val: Callable, 146 | data_dir: Optional[Union[str, Path]] = None, 147 | train_dir: Optional[Union[str, Path]] = None, 148 | val_dir: Optional[Union[str, Path]] = None, 149 | ) -> Tuple[Dataset, Dataset]: 150 | """Prepares train and val datasets. 151 | 152 | Args: 153 | dataset (str): dataset name. 154 | T_train (Callable): pipeline of transformations for training dataset. 155 | T_val (Callable): pipeline of transformations for validation dataset. 156 | data_dir Optional[Union[str, Path]]: path where to download/locate the dataset. 157 | train_dir Optional[Union[str, Path]]: subpath where the training data is located. 158 | val_dir Optional[Union[str, Path]]: subpath where the validation data is located. 159 | 160 | Returns: 161 | Tuple[Dataset, Dataset]: training dataset and validation dataset. 162 | """ 163 | 164 | if data_dir is None: 165 | sandbox_dir = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) 166 | data_dir = sandbox_dir / "datasets" 167 | else: 168 | data_dir = Path(data_dir) 169 | 170 | if train_dir is None: 171 | train_dir = Path(f"{dataset}/train") 172 | else: 173 | train_dir = Path(train_dir) 174 | 175 | if val_dir is None: 176 | val_dir = Path(f"{dataset}/val") 177 | else: 178 | val_dir = Path(val_dir) 179 | 180 | assert dataset in ["cifar10", "cifar100", "stl10", "imagenet", "imagenet100", "custom"] 181 | 182 | if dataset in ["cifar10", "cifar100"]: 183 | DatasetClass = vars(torchvision.datasets)[dataset.upper()] 184 | train_dataset = DatasetClass( 185 | data_dir / train_dir, 186 | train=True, 187 | download=True, 188 | transform=T_train, 189 | ) 190 | 191 | val_dataset = DatasetClass( 192 | data_dir / val_dir, 193 | train=False, 194 | download=True, 195 | transform=T_val, 196 | ) 197 | 198 | elif dataset == "stl10": 199 | train_dataset = STL10( 200 | data_dir / train_dir, 201 | split="train", 202 | download=True, 203 | transform=T_train, 204 | ) 205 | val_dataset = STL10( 206 | data_dir / val_dir, 207 | split="test", 208 | download=True, 209 | transform=T_val, 210 | ) 211 | 212 | elif dataset in ["imagenet", "imagenet100", "custom"]: 213 | train_dir = data_dir / train_dir 214 | val_dir = data_dir / val_dir 215 | 216 | train_dataset = ImageFolder(train_dir, T_train) 217 | val_dataset = ImageFolder(val_dir, T_val) 218 | 219 | return train_dataset, val_dataset 220 | 221 | 222 | def prepare_dataloaders( 223 | train_dataset: Dataset, val_dataset: Dataset, batch_size: int = 64, num_workers: int = 4 224 | ) -> Tuple[DataLoader, DataLoader]: 225 | """Wraps a train and a validation dataset with a DataLoader. 226 | 227 | Args: 228 | train_dataset (Dataset): object containing training data. 229 | val_dataset (Dataset): object containing validation data. 230 | batch_size (int): batch size. 231 | num_workers (int): number of parallel workers. 232 | Returns: 233 | Tuple[DataLoader, DataLoader]: training dataloader and validation dataloader. 
234 | """ 235 | 236 | train_loader = DataLoader( 237 | train_dataset, 238 | batch_size=batch_size, 239 | shuffle=True, 240 | num_workers=num_workers, 241 | pin_memory=True, 242 | drop_last=True, 243 | ) 244 | val_loader = DataLoader( 245 | val_dataset, 246 | batch_size=batch_size, 247 | num_workers=num_workers, 248 | pin_memory=True, 249 | drop_last=False, 250 | ) 251 | return train_loader, val_loader 252 | 253 | 254 | def prepare_data( 255 | dataset: str, 256 | data_dir: Optional[Union[str, Path]] = None, 257 | train_dir: Optional[Union[str, Path]] = None, 258 | val_dir: Optional[Union[str, Path]] = None, 259 | batch_size: int = 64, 260 | num_workers: int = 4, 261 | ) -> Tuple[DataLoader, DataLoader]: 262 | """Prepares transformations, creates dataset objects and wraps them in dataloaders. 263 | 264 | Args: 265 | dataset (str): dataset name. 266 | data_dir (Optional[Union[str, Path]], optional): path where to download/locate the dataset. 267 | Defaults to None. 268 | train_dir (Optional[Union[str, Path]], optional): subpath where the 269 | training data is located. Defaults to None. 270 | val_dir (Optional[Union[str, Path]], optional): subpath where the 271 | validation data is located. Defaults to None. 272 | batch_size (int, optional): batch size. Defaults to 64. 273 | num_workers (int, optional): number of parallel workers. Defaults to 4. 274 | 275 | Returns: 276 | Tuple[DataLoader, DataLoader]: prepared training and validation dataloader;. 277 | """ 278 | 279 | T_train, T_val = prepare_transforms(dataset) 280 | train_dataset, val_dataset = prepare_datasets( 281 | dataset, 282 | T_train, 283 | T_val, 284 | data_dir=data_dir, 285 | train_dir=train_dir, 286 | val_dir=val_dir, 287 | ) 288 | train_loader, val_loader = prepare_dataloaders( 289 | train_dataset, 290 | val_dataset, 291 | batch_size=batch_size, 292 | num_workers=num_workers, 293 | ) 294 | return train_loader, val_loader 295 | -------------------------------------------------------------------------------- /solo/utils/dali_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
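# prepare_data above bundles the transforms, datasets and DataLoaders, so the supervised /
# linear-eval loaders for a supported dataset are one call away. A short usage sketch
# (CIFAR-10 is downloaded automatically if it is not already in data_dir):
from solo.utils.classification_dataloader import prepare_data

train_loader, val_loader = prepare_data(
    "cifar10",
    data_dir="./datasets",
    train_dir=None,   # defaults to <dataset>/train
    val_dir=None,     # defaults to <dataset>/val
    batch_size=256,
    num_workers=4,
)
x, y = next(iter(train_loader))  # x: (256, 3, 32, 32) images, y: integer class labels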
19 | 20 | import os 21 | from pathlib import Path 22 | from typing import Callable, Iterable, List, Sequence, Union 23 | 24 | import nvidia.dali.fn as fn 25 | import nvidia.dali.ops as ops 26 | import nvidia.dali.types as types 27 | from nvidia.dali.pipeline import Pipeline 28 | 29 | 30 | class Mux: 31 | def __init__(self, prob: float): 32 | """Implements mutex operation for dali in order to support probabilitic augmentations. 33 | 34 | Args: 35 | prob (float): probability value 36 | """ 37 | 38 | self.to_bool = ops.Cast(dtype=types.DALIDataType.BOOL) 39 | self.rng = ops.random.CoinFlip(probability=prob) 40 | 41 | def __call__(self, true_case, false_case): 42 | condition = self.to_bool(self.rng()) 43 | neg_condition = condition ^ True 44 | return condition * true_case + neg_condition * false_case 45 | 46 | 47 | class RandomGrayScaleConversion: 48 | def __init__(self, prob: float = 0.2, device: str = "gpu"): 49 | """Converts image to greyscale with probability. 50 | 51 | Args: 52 | prob (float, optional): probability of conversion. Defaults to 0.2. 53 | device (str, optional): device on which the operation will be performed. 54 | Defaults to "gpu". 55 | """ 56 | 57 | self.mux = Mux(prob=prob) 58 | self.grayscale = ops.ColorSpaceConversion( 59 | device=device, image_type=types.RGB, output_type=types.GRAY 60 | ) 61 | 62 | def __call__(self, images): 63 | out = self.grayscale(images) 64 | out = fn.cat(out, out, out, axis=2) 65 | return self.mux(true_case=out, false_case=images) 66 | 67 | 68 | class RandomColorJitter: 69 | def __init__( 70 | self, 71 | brightness: float, 72 | contrast: float, 73 | saturation: float, 74 | hue: float, 75 | prob: float = 0.8, 76 | device: str = "gpu", 77 | ): 78 | """Applies random color jittering with probability. 79 | 80 | Args: 81 | brightness (float): brightness value for samplying uniformly 82 | in [max(0, 1 - brightness), 1 + brightness]. 83 | contrast (float): contrast value for samplying uniformly 84 | in [max(0, 1 - contrast), 1 + contrast]. 85 | saturation (float): saturation value for samplying uniformly 86 | in [max(0, 1 - saturation), 1 + saturation]. 87 | hue (float): hue value for samplying uniformly in [-hue, hue]. 88 | prob (float, optional): probability of applying jitter. Defaults to 0.8. 89 | device (str, optional): device on which the operation will be performed. 90 | Defaults to "gpu". 91 | """ 92 | 93 | assert 0 <= hue <= 0.5 94 | 95 | self.mux = Mux(prob=prob) 96 | 97 | self.color = ops.ColorTwist(device=device) 98 | 99 | # look at torchvision docs to see how colorjitter samples stuff 100 | # for bright, cont and sat, it samples from [1-v, 1+v] 101 | # for hue, it samples from [-hue, hue] 102 | 103 | self.brightness = 1 104 | self.contrast = 1 105 | self.saturation = 1 106 | self.hue = 0 107 | 108 | if brightness: 109 | self.brightness = ops.random.Uniform(range=[max(0, 1 - brightness), 1 + brightness]) 110 | 111 | if contrast: 112 | self.contrast = ops.random.Uniform(range=[max(0, 1 - contrast), 1 + contrast]) 113 | 114 | if saturation: 115 | self.saturation = ops.random.Uniform(range=[max(0, 1 - saturation), 1 + saturation]) 116 | 117 | if hue: 118 | # dali uses hue in degrees for some reason... 
119 | hue = 360 * hue 120 | self.hue = ops.random.Uniform(range=[-hue, hue]) 121 | 122 | def __call__(self, images): 123 | out = self.color( 124 | images, 125 | brightness=self.brightness() if callable(self.brightness) else self.brightness, 126 | contrast=self.contrast() if callable(self.contrast) else self.contrast, 127 | saturation=self.saturation() if callable(self.saturation) else self.saturation, 128 | hue=self.hue() if callable(self.hue) else self.hue, 129 | ) 130 | return self.mux(true_case=out, false_case=images) 131 | 132 | 133 | class RandomGaussianBlur: 134 | def __init__(self, prob: float = 0.5, window_size: int = 23, device: str = "gpu"): 135 | """Applies random gaussian blur with probability. 136 | 137 | Args: 138 | prob (float, optional): probability of applying random gaussian blur. Defaults to 0.5. 139 | window_size (int, optional): window size for gaussian blur. Defaults to 23. 140 | device (str, optional): device on which the operation will be performed. 141 | Defaults to "gpu". 142 | """ 143 | 144 | self.mux = Mux(prob=prob) 145 | # gaussian blur 146 | self.gaussian_blur = ops.GaussianBlur(device=device, window_size=(window_size, window_size)) 147 | self.sigma = ops.random.Uniform(range=[0, 1]) 148 | 149 | def __call__(self, images): 150 | sigma = self.sigma() * 1.9 + 0.1 151 | out = self.gaussian_blur(images, sigma=sigma) 152 | return self.mux(true_case=out, false_case=images) 153 | 154 | 155 | class RandomSolarize: 156 | def __init__(self, threshold: int = 128, prob: float = 0.0): 157 | """Applies random solarization with probability. 158 | 159 | Args: 160 | threshold (int, optional): threshold for inversion. Defaults to 128. 161 | prob (float, optional): probability of solarization. Defaults to 0.0. 162 | """ 163 | 164 | self.mux = Mux(prob=prob) 165 | 166 | self.threshold = threshold 167 | 168 | def __call__(self, images): 169 | inverted_img = 255 - images 170 | mask = images >= self.threshold 171 | out = mask * inverted_img + (True ^ mask) * images 172 | return self.mux(true_case=out, false_case=images) 173 | 174 | 175 | class NormalPipeline(Pipeline): 176 | def __init__( 177 | self, 178 | data_path: str, 179 | batch_size: int, 180 | device: str, 181 | validation: bool = False, 182 | device_id: int = 0, 183 | shard_id: int = 0, 184 | num_shards: int = 1, 185 | num_threads: int = 4, 186 | seed: int = 12, 187 | ): 188 | """Initializes the pipeline for validation or linear eval training. 189 | 190 | If validation is set to True then images will only be resized to 256px and center cropped 191 | to 224px, otherwise a random resized crop and a horizontal flip are applied. In both cases images 192 | are normalized. 193 | 194 | Args: 195 | data_path (str): directory that contains the data. 196 | batch_size (int): batch size. 197 | device (str): device on which the operation will be performed. 198 | validation (bool): whether it is validation or training. Defaults to 199 | False. 200 | device_id (int): id of the device used to initialize the seed and for parent class. 201 | Defaults to 0. 202 | shard_id (int): id of the shard (chunk of samples). Defaults to 0. 203 | num_shards (int): total number of shards. Defaults to 1. 204 | num_threads (int): number of threads to run in parallel. Defaults to 4. 205 | seed (int): seed for random number generation. Defaults to 12.
206 | """ 207 | 208 | seed += device_id 209 | super().__init__(batch_size, num_threads, device_id, seed) 210 | 211 | self.device = device 212 | self.validation = validation 213 | 214 | self.reader = ops.readers.File( 215 | file_root=data_path, 216 | shard_id=shard_id, 217 | num_shards=num_shards, 218 | shuffle_after_epoch=True if not self.validation else False, 219 | ) 220 | decoder_device = "mixed" if self.device == "gpu" else "cpu" 221 | device_memory_padding = 211025920 if decoder_device == "mixed" else 0 222 | host_memory_padding = 140544512 if decoder_device == "mixed" else 0 223 | self.decode = ops.decoders.Image( 224 | device=decoder_device, 225 | output_type=types.RGB, 226 | device_memory_padding=device_memory_padding, 227 | host_memory_padding=host_memory_padding, 228 | ) 229 | 230 | # crop operations 231 | if self.validation: 232 | self.resize = ops.Resize( 233 | device=self.device, 234 | resize_shorter=256, 235 | interp_type=types.INTERP_CUBIC, 236 | ) 237 | # center crop and normalize 238 | self.cmn = ops.CropMirrorNormalize( 239 | device=self.device, 240 | dtype=types.FLOAT, 241 | output_layout=types.NCHW, 242 | crop=(224, 224), 243 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 244 | std=[0.228 * 255, 0.224 * 255, 0.225 * 255], 245 | ) 246 | else: 247 | self.resize = ops.RandomResizedCrop( 248 | device=self.device, 249 | size=224, 250 | random_area=(0.08, 1.0), 251 | interp_type=types.INTERP_CUBIC, 252 | ) 253 | # normalize and horizontal flip 254 | self.cmn = ops.CropMirrorNormalize( 255 | device=self.device, 256 | dtype=types.FLOAT, 257 | output_layout=types.NCHW, 258 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 259 | std=[0.228 * 255, 0.224 * 255, 0.225 * 255], 260 | ) 261 | 262 | self.coin05 = ops.random.CoinFlip(probability=0.5) 263 | self.to_int64 = ops.Cast(dtype=types.INT64, device=device) 264 | 265 | def define_graph(self): 266 | """Defines the computational graph for dali operations.""" 267 | 268 | # read images from memory 269 | inputs, labels = self.reader(name="Reader") 270 | images = self.decode(inputs) 271 | 272 | # crop into large and small images 273 | images = self.resize(images) 274 | 275 | if self.validation: 276 | # crop and normalize 277 | images = self.cmn(images) 278 | else: 279 | # normalize and maybe apply horizontal flip with 0.5 chance 280 | images = self.cmn(images, mirror=self.coin05()) 281 | 282 | if self.device == "gpu": 283 | labels = labels.gpu() 284 | # PyTorch expects labels as INT64 285 | labels = self.to_int64(labels) 286 | 287 | return (images, labels) 288 | 289 | 290 | class CustomNormalPipeline(NormalPipeline): 291 | """Initializes the custom pipeline for validation or linear eval training. 292 | This acts as a placeholder and behaves exactly like NormalPipeline. 293 | If you want to do exoteric augmentations, you can just re-write this class. 294 | """ 295 | 296 | pass 297 | 298 | 299 | class ImagenetTransform: 300 | def __init__( 301 | self, 302 | device: str, 303 | brightness: float, 304 | contrast: float, 305 | saturation: float, 306 | hue: float, 307 | gaussian_prob: float = 0.5, 308 | solarization_prob: float = 0.0, 309 | size: int = 224, 310 | min_scale: float = 0.08, 311 | max_scale: float = 1.0, 312 | ): 313 | """Applies Imagenet transformations to a batch of images. 314 | 315 | Args: 316 | device (str): device on which the operations will be performed. 317 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness]. 318 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast]. 
319 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation]. 320 | hue (float): sampled uniformly in [-hue, hue]. 321 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5. 322 | solarization_prob (float, optional): probability of applying solarization. Defaults 323 | to 0.0. 324 | size (int, optional): size of the side of the image after transformation. Defaults 325 | to 224. 326 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08. 327 | max_scale (float, optional): maximum scale of the crops. Defaults to 1.0. 328 | """ 329 | 330 | # random crop 331 | self.random_crop = ops.RandomResizedCrop( 332 | device=device, 333 | size=size, 334 | random_area=(min_scale, max_scale), 335 | interp_type=types.INTERP_CUBIC, 336 | ) 337 | 338 | # color jitter 339 | self.random_color_jitter = RandomColorJitter( 340 | brightness=brightness, 341 | contrast=contrast, 342 | saturation=saturation, 343 | hue=hue, 344 | prob=0.8, 345 | device=device, 346 | ) 347 | 348 | # grayscale conversion 349 | self.random_grayscale = RandomGrayScaleConversion(prob=0.2, device=device) 350 | 351 | # gaussian blur 352 | self.random_gaussian_blur = RandomGaussianBlur(prob=gaussian_prob, device=device) 353 | 354 | # solarization 355 | self.random_solarization = RandomSolarize(prob=solarization_prob) 356 | 357 | # normalize and horizontal flip 358 | self.cmn = ops.CropMirrorNormalize( 359 | device=device, 360 | dtype=types.FLOAT, 361 | output_layout=types.NCHW, 362 | mean=[0.485 * 255, 0.456 * 255, 0.406 * 255], 363 | std=[0.228 * 255, 0.224 * 255, 0.225 * 255], 364 | ) 365 | self.coin05 = ops.random.CoinFlip(probability=0.5) 366 | 367 | self.str = ( 368 | "ImagenetTransform(" 369 | f"random_crop({min_scale}, {max_scale}), " 370 | f"random_color_jitter(brightness={brightness}, " 371 | f"contrast={contrast}, saturation={saturation}, hue={hue}), " 372 | f"random_gray_scale, random_gaussian_blur({gaussian_prob}), " 373 | f"random_solarization({solarization_prob}), " 374 | "crop_mirror_resize())" 375 | ) 376 | 377 | def __str__(self) -> str: 378 | return self.str 379 | 380 | def __call__(self, images): 381 | out = self.random_crop(images) 382 | out = self.random_color_jitter(out) 383 | out = self.random_grayscale(out) 384 | out = self.random_gaussian_blur(out) 385 | out = self.random_solarization(out) 386 | out = self.cmn(out, mirror=self.coin05()) 387 | return out 388 | 389 | 390 | class CustomTransform: 391 | def __init__( 392 | self, 393 | device: str, 394 | brightness: float, 395 | contrast: float, 396 | saturation: float, 397 | hue: float, 398 | gaussian_prob: float = 0.5, 399 | solarization_prob: float = 0.0, 400 | size: int = 224, 401 | min_scale: float = 0.08, 402 | max_scale: float = 1.0, 403 | mean: Sequence[float] = (0.485, 0.456, 0.406), 404 | std: Sequence[float] = (0.228, 0.224, 0.225), 405 | ): 406 | """Applies Custom transformations. 407 | If you want to do exoteric augmentations, you can just re-write this class. 408 | 409 | Args: 410 | device (str): device on which the operations will be performed. 411 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness]. 412 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast]. 413 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation]. 414 | hue (float): sampled uniformly in [-hue, hue]. 415 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5. 
416 | solarization_prob (float, optional): probability of applying solarization. Defaults 417 | to 0.0. 418 | size (int, optional): size of the side of the image after transformation. Defaults 419 | to 224. 420 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08. 421 | max_scale (float, optional): maximum scale of the crops. Defaults to 1.0. 422 | mean (Sequence[float], optional): mean values for normalization. 423 | Defaults to (0.485, 0.456, 0.406). 424 | std (Sequence[float], optional): std values for normalization. 425 | Defaults to (0.228, 0.224, 0.225). 426 | """ 427 | 428 | # random crop 429 | self.random_crop = ops.RandomResizedCrop( 430 | device=device, 431 | size=size, 432 | random_area=(min_scale, max_scale), 433 | interp_type=types.INTERP_CUBIC, 434 | ) 435 | 436 | # color jitter 437 | self.random_color_jitter = RandomColorJitter( 438 | brightness=brightness, 439 | contrast=contrast, 440 | saturation=saturation, 441 | hue=hue, 442 | prob=0.8, 443 | device=device, 444 | ) 445 | 446 | # grayscale conversion 447 | self.random_grayscale = RandomGrayScaleConversion(prob=0.2, device=device) 448 | 449 | # gaussian blur 450 | self.random_gaussian_blur = RandomGaussianBlur(prob=gaussian_prob, device=device) 451 | 452 | # solarization 453 | self.random_solarization = RandomSolarize(prob=solarization_prob) 454 | 455 | # normalize and horizontal flip 456 | self.cmn = ops.CropMirrorNormalize( 457 | device=device, 458 | dtype=types.FLOAT, 459 | output_layout=types.NCHW, 460 | mean=[v * 255 for v in mean], 461 | std=[v * 255 for v in std], 462 | ) 463 | self.coin05 = ops.random.CoinFlip(probability=0.5) 464 | 465 | self.str = ( 466 | "CustomTransform(" 467 | f"random_crop({min_scale}, {max_scale}), " 468 | f"random_color_jitter(brightness={brightness}, " 469 | f"contrast={contrast}, saturation={saturation}, hue={hue}), " 470 | f"random_gray_scale, random_gaussian_blur({gaussian_prob}), " 471 | f"random_solarization({solarization_prob}), " 472 | "crop_mirror_resize())" 473 | ) 474 | 475 | def __call__(self, images): 476 | out = self.random_crop(images) 477 | out = self.random_color_jitter(out) 478 | out = self.random_grayscale(out) 479 | out = self.random_gaussian_blur(out) 480 | out = self.random_solarization(out) 481 | out = self.cmn(out, mirror=self.coin05()) 482 | return out 483 | 484 | def __str__(self): 485 | return self.str 486 | 487 | 488 | class PretrainPipeline(Pipeline): 489 | def __init__( 490 | self, 491 | data_path: Union[str, Path], 492 | batch_size: int, 493 | device: str, 494 | transform: Union[Callable, Iterable], 495 | num_crops: int = 2, 496 | random_shuffle: bool = True, 497 | device_id: int = 0, 498 | shard_id: int = 0, 499 | num_shards: int = 1, 500 | num_threads: int = 4, 501 | seed: int = 12, 502 | no_labels: bool = False, 503 | encode_indexes_into_labels: bool = False, 504 | ): 505 | """Initializes the pipeline for pretraining. 506 | 507 | Args: 508 | data_path (str): directory that contains the data. 509 | batch_size (int): batch size. 510 | device (str): device on which the operation will be performed. 511 | transform (Union[Callable, Iterable]): a transformation or a sequence 512 | of transformations to be applied. 513 | num_crops (int, optional): number of crops. Defaults to 2. 514 | random_shuffle (bool, optional): whether to randomly shuffle the samples. 515 | Defaults to True. 516 | device_id (int, optional): id of the device used to initialize the seed and 517 | for parent class. Defaults to 0. 
518 | shard_id (int, optional): id of the shard (chuck of samples). Defaults to 0. 519 | num_shards (int, optional): total number of shards. Defaults to 1. 520 | num_threads (int, optional): number of threads to run in parallel. Defaults to 4. 521 | seed (int, optional): seed for random number generation. Defaults to 12. 522 | no_labels (bool, optional): if the data has no labels. Defaults to False. 523 | encode_indexes_into_labels (bool, optional): uses sample indexes as labels 524 | and then gets the labels from a lookup table. This may use more CPU memory, 525 | so just use when needed. Defaults to False. 526 | """ 527 | 528 | seed += device_id 529 | super().__init__( 530 | batch_size=batch_size, 531 | num_threads=num_threads, 532 | device_id=device_id, 533 | seed=seed, 534 | ) 535 | 536 | self.device = device 537 | 538 | data_path = Path(data_path) 539 | if no_labels: 540 | files = [data_path / f for f in sorted(os.listdir(data_path))] 541 | labels = [-1] * len(files) 542 | self.reader = ops.readers.File( 543 | files=files, 544 | shard_id=shard_id, 545 | num_shards=num_shards, 546 | shuffle_after_epoch=random_shuffle, 547 | labels=labels, 548 | ) 549 | elif encode_indexes_into_labels: 550 | labels = sorted(Path(entry.name) for entry in os.scandir(data_path) if entry.is_dir()) 551 | 552 | data = [ 553 | (data_path / label / file, label_idx) 554 | for label_idx, label in enumerate(labels) 555 | for file in sorted(os.listdir(data_path / label)) 556 | ] 557 | 558 | files = [] 559 | labels = [] 560 | # for debugging 561 | true_labels = [] 562 | 563 | self.conversion_map = [] 564 | for file_idx, (file, label_idx) in enumerate(data): 565 | files.append(file) 566 | labels.append(file_idx) 567 | true_labels.append(label_idx) 568 | self.conversion_map.append(label_idx) 569 | 570 | # debugging 571 | for file, file_idx, label_idx in zip(files, labels, true_labels): 572 | assert self.conversion_map[file_idx] == label_idx 573 | 574 | self.reader = ops.readers.File( 575 | files=files, 576 | shard_id=shard_id, 577 | num_shards=num_shards, 578 | shuffle_after_epoch=random_shuffle, 579 | ) 580 | else: 581 | self.reader = ops.readers.File( 582 | file_root=data_path, 583 | shard_id=shard_id, 584 | num_shards=num_shards, 585 | shuffle_after_epoch=random_shuffle, 586 | ) 587 | 588 | decoder_device = "mixed" if self.device == "gpu" else "cpu" 589 | device_memory_padding = 211025920 if decoder_device == "mixed" else 0 590 | host_memory_padding = 140544512 if decoder_device == "mixed" else 0 591 | self.decode = ops.decoders.Image( 592 | device=decoder_device, 593 | output_type=types.RGB, 594 | device_memory_padding=device_memory_padding, 595 | host_memory_padding=host_memory_padding, 596 | ) 597 | self.to_int64 = ops.Cast(dtype=types.INT64, device=device) 598 | 599 | self.num_crops = num_crops 600 | 601 | # transformations 602 | self.transform = transform 603 | 604 | if isinstance(transform, Iterable): 605 | self.one_transform_per_crop = True 606 | else: 607 | self.one_transform_per_crop = False 608 | self.num_crops = num_crops 609 | 610 | def define_graph(self): 611 | """Defines the computational graph for dali operations.""" 612 | 613 | # read images from memory 614 | inputs, labels = self.reader(name="Reader") 615 | 616 | images = self.decode(inputs) 617 | 618 | if self.one_transform_per_crop: 619 | crops = [transform(images) for transform in self.transform] 620 | else: 621 | crops = [self.transform(images) for i in range(self.num_crops)] 622 | 623 | if self.device == "gpu": 624 | labels = labels.gpu() 625 
| # PyTorch expects labels as INT64 626 | labels = self.to_int64(labels) 627 | 628 | return (*crops, labels) 629 | 630 | 631 | class MulticropPretrainPipeline(Pipeline): 632 | def __init__( 633 | self, 634 | data_path: Union[str, Path], 635 | batch_size: int, 636 | device: str, 637 | transforms: List, 638 | num_crops: List[int], 639 | random_shuffle: bool = True, 640 | device_id: int = 0, 641 | shard_id: int = 0, 642 | num_shards: int = 1, 643 | num_threads: int = 4, 644 | seed: int = 12, 645 | no_labels: bool = False, 646 | encode_indexes_into_labels: bool = False, 647 | ): 648 | """Initializes the pipeline for pretraining with multicrop. 649 | 650 | Args: 651 | data_path (str): directory that contains the data. 652 | batch_size (int): batch size. 653 | device (str): device on which the operation will be performed. 654 | transforms (List): list of transformations to be applied. 655 | num_crops (List[int]): number of crops. 656 | random_shuffle (bool, optional): whether to randomly shuffle the samples. 657 | Defaults to True. 658 | device_id (int, optional): id of the device used to initialize the seed and 659 | for parent class. Defaults to 0. 660 | shard_id (int, optional): id of the shard (chunk of samples). Defaults to 0. 661 | num_shards (int, optional): total number of shards. Defaults to 1. 662 | num_threads (int, optional): number of threads to run in parallel. Defaults to 4. 663 | seed (int, optional): seed for random number generation. Defaults to 12. 664 | no_labels (bool, optional): if the data has no labels. Defaults to False. 665 | encode_indexes_into_labels (bool, optional): uses sample indexes as labels 666 | and then gets the labels from a lookup table. This may use more CPU memory, 667 | so just use when needed. Defaults to False.
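Example (illustrative; the transform objects would typically be ImagenetTransform
instances defined earlier in this file):
    pipeline = MulticropPretrainPipeline(
        "datasets/imagenet100/train", batch_size=64, device="gpu",
        transforms=[large_crop_transform, small_crop_transform],
        num_crops=[2, 6],
    )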
668 | """ 669 | 670 | seed += device_id 671 | super().__init__( 672 | batch_size=batch_size, 673 | num_threads=num_threads, 674 | device_id=device_id, 675 | seed=seed, 676 | ) 677 | 678 | self.device = device 679 | 680 | data_path = Path(data_path) 681 | if no_labels: 682 | files = [data_path / f for f in sorted(os.listdir(data_path))] 683 | labels = [-1] * len(files) 684 | self.reader = ops.readers.File( 685 | files=files, 686 | shard_id=shard_id, 687 | num_shards=num_shards, 688 | shuffle_after_epoch=random_shuffle, 689 | labels=labels, 690 | ) 691 | elif encode_indexes_into_labels: 692 | labels = sorted(Path(entry.name) for entry in os.scandir(data_path) if entry.is_dir()) 693 | 694 | data = [ 695 | (data_path / label / file, label_idx) 696 | for label_idx, label in enumerate(labels) 697 | for file in sorted(os.listdir(data_path / label)) 698 | ] 699 | 700 | files = [] 701 | labels = [] 702 | # for debugging 703 | true_labels = [] 704 | 705 | self.conversion_map = [] 706 | for file_idx, (file, label_idx) in enumerate(data): 707 | files.append(file) 708 | labels.append(file_idx) 709 | true_labels.append(label_idx) 710 | self.conversion_map.append(label_idx) 711 | 712 | # debugging 713 | for file, file_idx, label_idx in zip(files, labels, true_labels): 714 | assert self.conversion_map[file_idx] == label_idx 715 | 716 | self.reader = ops.readers.File( 717 | files=files, 718 | shard_id=shard_id, 719 | num_shards=num_shards, 720 | shuffle_after_epoch=random_shuffle, 721 | ) 722 | else: 723 | self.reader = ops.readers.File( 724 | file_root=data_path, 725 | shard_id=shard_id, 726 | num_shards=num_shards, 727 | shuffle_after_epoch=random_shuffle, 728 | ) 729 | 730 | decoder_device = "mixed" if self.device == "gpu" else "cpu" 731 | device_memory_padding = 211025920 if decoder_device == "mixed" else 0 732 | host_memory_padding = 140544512 if decoder_device == "mixed" else 0 733 | self.decode = ops.decoders.Image( 734 | device=decoder_device, 735 | output_type=types.RGB, 736 | device_memory_padding=device_memory_padding, 737 | host_memory_padding=host_memory_padding, 738 | ) 739 | self.to_int64 = ops.Cast(dtype=types.INT64, device=device) 740 | 741 | self.num_crops = num_crops 742 | self.transforms = transforms 743 | 744 | assert len(transforms) == len(num_crops) 745 | 746 | def define_graph(self): 747 | """Defines the computational graph for dali operations.""" 748 | 749 | # read images from memory 750 | inputs, labels = self.reader(name="Reader") 751 | images = self.decode(inputs) 752 | 753 | # crop into large and small images 754 | crops = [] 755 | for i, transform in enumerate(self.transforms): 756 | for _ in range(self.num_crops[i]): 757 | crop = transform(images) 758 | crops.append(crop) 759 | 760 | if self.device == "gpu": 761 | labels = labels.gpu() 762 | # PyTorch expects labels as INT64 763 | labels = self.to_int64(labels) 764 | 765 | return (*crops, labels) 766 | -------------------------------------------------------------------------------- /solo/utils/kmeans.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Any, Sequence 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributed as dist 25 | import torch.nn.functional as F 26 | from scipy.sparse import csr_matrix 27 | 28 | 29 | class KMeans: 30 | def __init__( 31 | self, 32 | world_size: int, 33 | rank: int, 34 | num_crops: int, 35 | dataset_size: int, 36 | proj_features_dim: int, 37 | num_prototypes: int, 38 | kmeans_iters: int = 10, 39 | ): 40 | """Class that performs K-Means on the hypersphere. 41 | 42 | Args: 43 | world_size (int): world size. 44 | rank (int): rank of the current process. 45 | num_crops (int): number of crops. 46 | dataset_size (int): total size of the dataset (number of samples). 47 | proj_features_dim (int): number of dimensions of the projected features. 48 | num_prototypes (int): number of prototypes. 49 | kmeans_iters (int, optional): number of iterations for the k-means clustering. 50 | Defaults to 10. 51 | """ 52 | self.world_size = world_size 53 | self.rank = rank 54 | self.num_crops = num_crops 55 | self.dataset_size = dataset_size 56 | self.proj_features_dim = proj_features_dim 57 | self.num_prototypes = num_prototypes 58 | self.kmeans_iters = kmeans_iters 59 | 60 | @staticmethod 61 | def get_indices_sparse(data: np.ndarray): 62 | cols = np.arange(data.size) 63 | M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size)) 64 | return [np.unravel_index(row.data, data.shape) for row in M] 65 | 66 | def cluster_memory( 67 | self, 68 | local_memory_index: torch.Tensor, 69 | local_memory_embeddings: torch.Tensor, 70 | ) -> Sequence[Any]: 71 | """Performs K-Means clustering on the hypersphere and returns centroids and 72 | assignments for each sample. 73 | 74 | Args: 75 | local_memory_index (torch.Tensor): memory bank cointaining indices of the 76 | samples. 77 | local_memory_embeddings (torch.Tensor): memory bank cointaining embeddings 78 | of the samples. 79 | 80 | Returns: 81 | Sequence[Any]: assignments and centroids. 
82 | """ 83 | j = 0 84 | device = local_memory_embeddings.device 85 | assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long() 86 | centroids_list = [] 87 | with torch.no_grad(): 88 | for i_K, K in enumerate(self.num_prototypes): 89 | # run distributed k-means 90 | 91 | # init centroids with elements from memory bank of rank 0 92 | centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True) 93 | if self.rank == 0: 94 | random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K] 95 | assert len(random_idx) >= K, "please reduce the number of centroids" 96 | centroids = local_memory_embeddings[j][random_idx] 97 | if dist.is_available() and dist.is_initialized(): 98 | dist.broadcast(centroids, 0) 99 | 100 | for n_iter in range(self.kmeans_iters + 1): 101 | 102 | # E step 103 | dot_products = torch.mm(local_memory_embeddings[j], centroids.t()) 104 | _, local_assignments = dot_products.max(dim=1) 105 | 106 | # finish 107 | if n_iter == self.kmeans_iters: 108 | break 109 | 110 | # M step 111 | where_helper = self.get_indices_sparse(local_assignments.cpu().numpy()) 112 | counts = torch.zeros(K).to(device, non_blocking=True).int() 113 | emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True) 114 | for k in range(len(where_helper)): 115 | if len(where_helper[k][0]) > 0: 116 | emb_sums[k] = torch.sum( 117 | local_memory_embeddings[j][where_helper[k][0]], 118 | dim=0, 119 | ) 120 | counts[k] = len(where_helper[k][0]) 121 | if dist.is_available() and dist.is_initialized(): 122 | dist.all_reduce(counts) 123 | dist.all_reduce(emb_sums) 124 | mask = counts > 0 125 | centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1) 126 | 127 | # normalize centroids 128 | centroids = F.normalize(centroids, dim=1, p=2) 129 | 130 | centroids_list.append(centroids) 131 | 132 | if dist.is_available() and dist.is_initialized(): 133 | # gather the assignments 134 | assignments_all = torch.empty( 135 | self.world_size, 136 | local_assignments.size(0), 137 | dtype=local_assignments.dtype, 138 | device=local_assignments.device, 139 | ) 140 | assignments_all = list(assignments_all.unbind(0)) 141 | 142 | dist_process = dist.all_gather( 143 | assignments_all, local_assignments, async_op=True 144 | ) 145 | dist_process.wait() 146 | assignments_all = torch.cat(assignments_all).cpu() 147 | 148 | # gather the indexes 149 | indexes_all = torch.empty( 150 | self.world_size, 151 | local_memory_index.size(0), 152 | dtype=local_memory_index.dtype, 153 | device=local_memory_index.device, 154 | ) 155 | indexes_all = list(indexes_all.unbind(0)) 156 | dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True) 157 | dist_process.wait() 158 | indexes_all = torch.cat(indexes_all).cpu() 159 | 160 | else: 161 | assignments_all = local_assignments 162 | indexes_all = local_memory_index 163 | 164 | # log assignments 165 | assignments[i_K][indexes_all] = assignments_all 166 | 167 | # next memory bank to use 168 | j = (j + 1) % self.num_crops 169 | 170 | return assignments, centroids_list 171 | -------------------------------------------------------------------------------- /solo/utils/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Sequence 21 | 22 | import torch 23 | from torchmetrics.metric import Metric 24 | 25 | 26 | class WeightedKNNClassifier(Metric): 27 | def __init__( 28 | self, 29 | k: int = 20, 30 | T: float = 0.07, 31 | num_chunks: int = 100, 32 | distance_fx: str = "cosine", 33 | epsilon: float = 0.00001, 34 | dist_sync_on_step: bool = False, 35 | ): 36 | """Implements the weighted k-NN classifier used for evaluation. 37 | 38 | Args: 39 | k (int, optional): number of neighbors. Defaults to 20. 40 | T (float, optional): temperature for the exponential. Only used with cosine 41 | distance. Defaults to 0.07. 42 | num_chunks (int, optional): number of chunks of test features. Defaults to 100. 43 | distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or 44 | "euclidean". Defaults to "cosine". 45 | epsilon (float, optional): Small value for numerical stability. Only used with 46 | euclidean distance. Defaults to 0.00001. 47 | dist_sync_on_step (bool, optional): whether to sync distributed values at every 48 | step. Defaults to False. 49 | """ 50 | 51 | super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) 52 | 53 | self.k = k 54 | self.T = T 55 | self.num_chunks = num_chunks 56 | self.distance_fx = distance_fx 57 | self.epsilon = epsilon 58 | 59 | self.add_state("train_features", default=[], persistent=False) 60 | self.add_state("train_targets", default=[], persistent=False) 61 | self.add_state("test_features", default=[], persistent=False) 62 | self.add_state("test_targets", default=[], persistent=False) 63 | 64 | def update( 65 | self, 66 | train_features: torch.Tensor = None, 67 | train_targets: torch.Tensor = None, 68 | test_features: torch.Tensor = None, 69 | test_targets: torch.Tensor = None, 70 | ): 71 | """Updates the memory banks. If train (test) features are passed as input, the 72 | corresponding train (test) targets must be passed as well. 73 | 74 | Args: 75 | train_features (torch.Tensor, optional): a batch of train features. Defaults to None. 76 | train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None. 77 | test_features (torch.Tensor, optional): a batch of test features. Defaults to None. 78 | test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None. 
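Example (illustrative; features are expected to be L2-normalized when the cosine
distance is used):
    knn = WeightedKNNClassifier(k=20, T=0.07)
    knn.update(train_features=train_feats, train_targets=train_y)
    knn.update(test_features=test_feats, test_targets=test_y)
    top1, top5 = knn.compute()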
79 | """ 80 | assert (train_features is None) == (train_targets is None) 81 | assert (test_features is None) == (test_targets is None) 82 | 83 | if train_features is not None: 84 | assert train_features.size(0) == train_targets.size(0) 85 | self.train_features.append(train_features) 86 | self.train_targets.append(train_targets) 87 | 88 | if test_features is not None: 89 | assert test_features.size(0) == test_targets.size(0) 90 | self.test_features.append(test_features) 91 | self.test_targets.append(test_targets) 92 | 93 | @torch.no_grad() 94 | def compute(self) -> Sequence[float]: 95 | """Computes weighted k-NN accuracy @1 and @5. If cosine distance is selected, 96 | the weight is computed using the exponential of the temperature scaled cosine 97 | distance of the samples. If euclidean distance is selected, the weight corresponds 98 | to the inverse of the euclidean distance. 99 | 100 | Returns: 101 | Sequence[float]: k-NN accuracy @1 and @5. 102 | """ 103 | train_features = torch.cat(self.train_features) 104 | train_targets = torch.cat(self.train_targets) 105 | test_features = torch.cat(self.test_features) 106 | test_targets = torch.cat(self.test_targets) 107 | 108 | top1, top5, total = 0.0, 0.0, 0 109 | num_classes = torch.unique(test_targets).numel() 110 | num_test_images = test_targets.size(0) 111 | chunk_size = max(1, num_test_images // self.num_chunks) 112 | k = min(self.k, train_targets.size(0)) 113 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) 114 | for idx in range(0, num_test_images, chunk_size): 115 | # get the features for test images 116 | features = test_features[idx : min((idx + chunk_size), num_test_images), :] 117 | targets = test_targets[idx : min((idx + chunk_size), num_test_images)] 118 | batch_size = targets.size(0) 119 | 120 | # calculate the dot product and compute top-k neighbors 121 | if self.distance_fx == "cosine": 122 | similarity = torch.mm(features, train_features.t()) 123 | elif self.distance_fx == "euclidean": 124 | similarity = 1 / (torch.cdist(features, train_features) + self.epsilon) 125 | else: 126 | raise NotImplementedError 127 | 128 | distances, indices = similarity.topk(k, largest=True, sorted=True) 129 | candidates = train_targets.view(1, -1).expand(batch_size, -1) 130 | retrieved_neighbors = torch.gather(candidates, 1, indices) 131 | 132 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 133 | # import pdb; pdb.set_trace() 134 | 135 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 136 | 137 | if self.distance_fx == "cosine": 138 | distances = distances.clone().div_(self.T).exp_() 139 | 140 | probs = torch.sum( 141 | torch.mul( 142 | retrieval_one_hot.view(batch_size, -1, num_classes), 143 | distances.view(batch_size, -1, 1), 144 | ), 145 | 1, 146 | ) 147 | _, predictions = probs.sort(1, True) 148 | 149 | # find the predictions that match the target 150 | correct = predictions.eq(targets.data.view(-1, 1)) 151 | # import pdb; pdb.set_trace() 152 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 153 | top5 = ( 154 | top5 + correct.narrow(1, 0, min(5, k)).sum().item() 155 | ) # top5 does not make sense if k < 5 156 | total += targets.size(0) 157 | 158 | top1 = top1 * 100.0 / total 159 | top5 = top5 * 100.0 / total 160 | 161 | self.reset() 162 | 163 | return top1, top5 164 | -------------------------------------------------------------------------------- /solo/utils/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn 
development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | # Copied from Pytorch Lightning (https://github.com/PyTorchLightning/pytorch-lightning/) 21 | # with extra documentations. 22 | 23 | 24 | import torch 25 | from torch.optim import Optimizer 26 | 27 | 28 | class LARSWrapper: 29 | def __init__( 30 | self, 31 | optimizer: Optimizer, 32 | eta: float = 1e-3, 33 | clip: bool = False, 34 | eps: float = 1e-8, 35 | exclude_bias_n_norm: bool = False, 36 | ): 37 | """Wrapper that adds LARS scheduling to any optimizer. 38 | This helps stability with huge batch sizes. 39 | 40 | Args: 41 | optimizer (Optimizer): torch optimizer. 42 | eta (float, optional): trust coefficient. Defaults to 1e-3. 43 | clip (bool, optional): clip gradient values. Defaults to False. 44 | eps (float, optional): adaptive_lr stability coefficient. Defaults to 1e-8. 45 | exclude_bias_n_norm (bool, optional): exclude bias and normalization layers from lars. 46 | Defaults to False. 
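Example (illustrative learning rate and trust coefficient):
    base_optimizer = torch.optim.SGD(model.parameters(), lr=0.3, momentum=0.9)
    optimizer = LARSWrapper(base_optimizer, eta=0.02, exclude_bias_n_norm=True)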
47 | """ 48 | 49 | self.optim = optimizer 50 | self.eta = eta 51 | self.eps = eps 52 | self.clip = clip 53 | self.exclude_bias_n_norm = exclude_bias_n_norm 54 | 55 | # transfer optim methods 56 | self.state_dict = self.optim.state_dict 57 | self.load_state_dict = self.optim.load_state_dict 58 | self.zero_grad = self.optim.zero_grad 59 | self.add_param_group = self.optim.add_param_group 60 | 61 | self.__setstate__ = self.optim.__setstate__ # type: ignore 62 | self.__getstate__ = self.optim.__getstate__ # type: ignore 63 | self.__repr__ = self.optim.__repr__ # type: ignore 64 | 65 | @property 66 | def defaults(self): 67 | return self.optim.defaults 68 | 69 | @defaults.setter 70 | def defaults(self, defaults): 71 | self.optim.defaults = defaults 72 | 73 | @property # type: ignore 74 | def __class__(self): 75 | return Optimizer 76 | 77 | @property 78 | def state(self): 79 | return self.optim.state 80 | 81 | @state.setter 82 | def state(self, state): 83 | self.optim.state = state 84 | 85 | @property 86 | def param_groups(self): 87 | return self.optim.param_groups 88 | 89 | @param_groups.setter 90 | def param_groups(self, value): 91 | self.optim.param_groups = value 92 | 93 | @torch.no_grad() 94 | def step(self, closure=None): 95 | weight_decays = [] 96 | 97 | for group in self.optim.param_groups: 98 | weight_decay = group.get("weight_decay", 0) 99 | weight_decays.append(weight_decay) 100 | 101 | # reset weight decay 102 | group["weight_decay"] = 0 103 | 104 | # update the parameters 105 | for p in group["params"]: 106 | if p.grad is not None and (p.ndim != 1 or not self.exclude_bias_n_norm): 107 | self.update_p(p, group, weight_decay) 108 | 109 | # update the optimizer 110 | self.optim.step(closure=closure) 111 | 112 | # return weight decay control to optimizer 113 | for group_idx, group in enumerate(self.optim.param_groups): 114 | group["weight_decay"] = weight_decays[group_idx] 115 | 116 | def update_p(self, p, group, weight_decay): 117 | # calculate new norms 118 | p_norm = torch.norm(p.data) 119 | g_norm = torch.norm(p.grad.data) 120 | 121 | if p_norm != 0 and g_norm != 0: 122 | # calculate new lr 123 | new_lr = (self.eta * p_norm) / (g_norm + p_norm * weight_decay + self.eps) 124 | 125 | # clip lr 126 | if self.clip: 127 | new_lr = min(new_lr / group["lr"], 1) 128 | 129 | # update params with clipped lr 130 | p.grad.data += weight_decay * p.data 131 | p.grad.data *= new_lr 132 | -------------------------------------------------------------------------------- /solo/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Dict, List, Sequence 21 | 22 | import torch 23 | 24 | 25 | def accuracy_at_k( 26 | outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5) 27 | ) -> Sequence[int]: 28 | """Computes the accuracy over the k top predictions for the specified values of k. 29 | 30 | Args: 31 | outputs (torch.Tensor): output of a classifier (logits or probabilities). 32 | targets (torch.Tensor): ground truth labels. 33 | top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over. 34 | Defaults to (1, 5). 35 | 36 | Returns: 37 | Sequence[int]: accuracies at the desired k. 38 | """ 39 | 40 | with torch.no_grad(): 41 | maxk = max(top_k) 42 | batch_size = targets.size(0) 43 | 44 | _, pred = outputs.topk(maxk, 1, True, True) 45 | pred = pred.t() 46 | correct = pred.eq(targets.view(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | for k in top_k: 50 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 51 | res.append(correct_k.mul_(100.0 / batch_size)) 52 | return res 53 | 54 | 55 | def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float: 56 | """Computes the mean of the values of a key weighted by the batch size. 57 | 58 | Args: 59 | outputs (List[Dict]): list of dicts containing the outputs of a validation step. 60 | key (str): key of the metric of interest. 61 | batch_size_key (str): key of batch size values. 62 | 63 | Returns: 64 | float: weighted mean of the values of a key 65 | """ 66 | 67 | value = 0 68 | n = 0 69 | for out in outputs: 70 | value += out[batch_size_key] * out[key] 71 | n += out[batch_size_key] 72 | value = value / n 73 | return value.squeeze(0) 74 | -------------------------------------------------------------------------------- /solo/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
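# Usage sketch for the helpers below (tensor names are illustrative):
#
#     z_all = gather(z)              # all-gather with gradient support; returns z
#                                    # unchanged when torch.distributed is not initialized
#     z_clean = filter_inf_n_nan(z)  # drops 2d rows (or 1d entries) containing inf/nan
#
# trunc_normal_ replicates PyTorch's truncated-normal initializer, as noted in its docstring.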
19 | 20 | import math 21 | import warnings 22 | from typing import List, Tuple 23 | 24 | import torch 25 | import torch.distributed as dist 26 | import torch.nn as nn 27 | 28 | 29 | def _1d_filter(tensor: torch.Tensor) -> torch.Tensor: 30 | return tensor.isfinite() 31 | 32 | 33 | def _2d_filter(tensor: torch.Tensor) -> torch.Tensor: 34 | return tensor.isfinite().all(dim=1) 35 | 36 | 37 | def _single_input_filter(tensor: torch.Tensor) -> Tuple[torch.Tensor]: 38 | if len(tensor.size()) == 1: 39 | filter_func = _1d_filter 40 | elif len(tensor.size()) == 2: 41 | filter_func = _2d_filter 42 | else: 43 | raise RuntimeError("Only 1d and 2d tensors are supported.") 44 | 45 | selected = filter_func(tensor) 46 | tensor = tensor[selected] 47 | 48 | return tensor, selected 49 | 50 | 51 | def _multi_input_filter(tensors: List[torch.Tensor]) -> Tuple[torch.Tensor]: 52 | if len(tensors[0].size()) == 1: 53 | filter_func = _1d_filter 54 | elif len(tensors[0].size()) == 2: 55 | filter_func = _2d_filter 56 | else: 57 | raise RuntimeError("Only 1d and 2d tensors are supported.") 58 | 59 | selected = filter_func(tensors[0]) 60 | for tensor in tensors[1:]: 61 | selected = torch.logical_and(selected, filter_func(tensor)) 62 | tensors = [tensor[selected] for tensor in tensors] 63 | 64 | return tensors, selected 65 | 66 | 67 | def filter_inf_n_nan(tensors: List[torch.Tensor], return_indexes: bool = False): 68 | """Filters out inf and nans from any tensor. 69 | This is usefull when there are instability issues, 70 | which cause a small number of values to go bad. 71 | 72 | Args: 73 | tensor (List): tensor to remove nans and infs from. 74 | 75 | Returns: 76 | torch.Tensor: filtered view of the tensor without nans or infs. 77 | """ 78 | 79 | if isinstance(tensors, torch.Tensor): 80 | tensors, selected = _single_input_filter(tensors) 81 | else: 82 | tensors, selected = _multi_input_filter(tensors) 83 | 84 | if return_indexes: 85 | return tensors, selected 86 | return tensors 87 | 88 | 89 | class FilterInfNNan(nn.Module): 90 | def __init__(self, module): 91 | """Layer that filters out inf and nans from any tensor. 92 | This is usefull when there are instability issues, 93 | which cause a small number of values to go bad. 94 | 95 | Args: 96 | tensor (List): tensor to remove nans and infs from. 97 | 98 | Returns: 99 | torch.Tensor: filtered view of the tensor without nans or infs. 100 | """ 101 | super().__init__() 102 | 103 | self.module = module 104 | 105 | def forward(self, x: torch.Tensor) -> torch.Tensor: 106 | out = self.module(x) 107 | out = filter_inf_n_nan(out) 108 | return out 109 | 110 | def __getattr__(self, name): 111 | try: 112 | return super().__getattr__(name) 113 | except AttributeError: 114 | if name == "module": 115 | raise AttributeError() 116 | return getattr(self.module, name) 117 | 118 | 119 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 120 | """Copy & paste from PyTorch official master until it's in a few official releases - RW 121 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 122 | """ 123 | 124 | def norm_cdf(x): 125 | """Computes standard normal cumulative distribution function""" 126 | 127 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 128 | 129 | if (mean < a - 2 * std) or (mean > b + 2 * std): 130 | warnings.warn( 131 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. 
" 132 | "The distribution of values may be incorrect.", 133 | stacklevel=2, 134 | ) 135 | 136 | with torch.no_grad(): 137 | # Values are generated by using a truncated uniform distribution and 138 | # then using the inverse CDF for the normal distribution. 139 | # Get upper and lower cdf values 140 | l = norm_cdf((a - mean) / std) 141 | u = norm_cdf((b - mean) / std) 142 | 143 | # Uniformly fill tensor with values from [l, u], then translate to 144 | # [2l-1, 2u-1]. 145 | tensor.uniform_(2 * l - 1, 2 * u - 1) 146 | 147 | # Use inverse cdf transform for normal distribution to get truncated 148 | # standard normal 149 | tensor.erfinv_() 150 | 151 | # Transform to proper mean, std 152 | tensor.mul_(std * math.sqrt(2.0)) 153 | tensor.add_(mean) 154 | 155 | # Clamp to ensure it's in the proper range 156 | tensor.clamp_(min=a, max=b) 157 | return tensor 158 | 159 | 160 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 161 | """Copy & paste from PyTorch official master until it's in a few official releases - RW 162 | Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 163 | """ 164 | 165 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 166 | 167 | 168 | class GatherLayer(torch.autograd.Function): 169 | """Gathers tensors from all processes, supporting backward propagation.""" 170 | 171 | @staticmethod 172 | def forward(ctx, input): 173 | ctx.save_for_backward(input) 174 | if dist.is_available() and dist.is_initialized(): 175 | output = [torch.zeros_like(input) for _ in range(dist.get_world_size())] 176 | dist.all_gather(output, input) 177 | else: 178 | output = [input] 179 | return tuple(output) 180 | 181 | @staticmethod 182 | def backward(ctx, *grads): 183 | (input,) = ctx.saved_tensors 184 | if dist.is_available() and dist.is_initialized(): 185 | grad_out = torch.zeros_like(input) 186 | grad_out[:] = grads[dist.get_rank()] 187 | else: 188 | grad_out = grads[0] 189 | return grad_out 190 | 191 | 192 | def gather(X, dim=0): 193 | """Gathers tensors from all processes, supporting backward propagation.""" 194 | return torch.cat(GatherLayer.apply(X), dim=dim) 195 | -------------------------------------------------------------------------------- /solo/utils/momentum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | import math 21 | 22 | import torch 23 | from torch import nn 24 | 25 | 26 | @torch.no_grad() 27 | def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module): 28 | """Copies the parameters of the online network to the momentum network. 29 | 30 | Args: 31 | online_net (nn.Module): online network (e.g. online encoder, online projection, etc...). 32 | momentum_net (nn.Module): momentum network (e.g. momentum encoder, 33 | momentum projection, etc...). 34 | """ 35 | 36 | params_online = online_net.parameters() 37 | params_momentum = momentum_net.parameters() 38 | for po, pm in zip(params_online, params_momentum): 39 | pm.data.copy_(po.data) 40 | pm.requires_grad = False 41 | 42 | 43 | class MomentumUpdater: 44 | def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0): 45 | """Updates momentum parameters using exponential moving average. 46 | 47 | Args: 48 | base_tau (float, optional): base value of the weight decrease coefficient 49 | (should be in [0,1]). Defaults to 0.996. 50 | final_tau (float, optional): final value of the weight decrease coefficient 51 | (should be in [0,1]). Defaults to 1.0. 52 | """ 53 | 54 | super().__init__() 55 | 56 | assert 0 <= base_tau <= 1 57 | assert 0 <= final_tau <= 1 and base_tau <= final_tau 58 | 59 | self.base_tau = base_tau 60 | self.cur_tau = base_tau 61 | self.final_tau = final_tau 62 | 63 | @torch.no_grad() 64 | def update(self, online_net: nn.Module, momentum_net: nn.Module): 65 | """Performs the momentum update for each param group. 66 | 67 | Args: 68 | online_net (nn.Module): online network (e.g. online encoder, online projection, etc...). 69 | momentum_net (nn.Module): momentum network (e.g. momentum encoder, 70 | momentum projection, etc...). 71 | """ 72 | 73 | for op, mp in zip(online_net.parameters(), momentum_net.parameters()): 74 | mp.data = self.cur_tau * mp.data + (1 - self.cur_tau) * op.data 75 | 76 | def update_tau(self, cur_step: int, max_steps: int): 77 | """Computes the next value for the weighting decrease coefficient tau using cosine annealing. 78 | 79 | Args: 80 | cur_step (int): number of gradient steps so far. 81 | max_steps (int): overall number of gradient steps in the whole training. 82 | """ 83 | 84 | self.cur_tau = ( 85 | self.final_tau 86 | - (self.final_tau - self.base_tau) * (math.cos(math.pi * cur_step / max_steps) + 1) / 2 87 | ) 88 | -------------------------------------------------------------------------------- /solo/utils/pretrain_dataloader.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import os 21 | import random 22 | from pathlib import Path 23 | from typing import Any, Callable, Iterable, List, Optional, Sequence, Type, Union 24 | 25 | import torch 26 | import torchvision 27 | from PIL import Image, ImageFilter, ImageOps 28 | from torch.utils.data import DataLoader 29 | from torch.utils.data.dataset import Dataset 30 | from torchvision import transforms 31 | from torchvision.datasets import STL10, ImageFolder 32 | 33 | 34 | def dataset_with_index(DatasetClass: Type[Dataset]) -> Type[Dataset]: 35 | """Factory for datasets that also returns the data index. 36 | 37 | Args: 38 | DatasetClass (Type[Dataset]): Dataset class to be wrapped. 39 | 40 | Returns: 41 | Type[Dataset]: dataset with index. 42 | """ 43 | 44 | class DatasetWithIndex(DatasetClass): 45 | def __getitem__(self, index): 46 | data = super().__getitem__(index) 47 | return (index, *data) 48 | 49 | return DatasetWithIndex 50 | 51 | 52 | class CustomDatasetWithoutLabels(Dataset): 53 | def __init__(self, root, transform=None): 54 | self.root = Path(root) 55 | self.transform = transform 56 | self.images = os.listdir(root) 57 | 58 | def __getitem__(self, index): 59 | path = self.root / self.images[index] 60 | x = Image.open(path).convert("RGB") 61 | if self.transform is not None: 62 | x = self.transform(x) 63 | return x, -1 64 | 65 | def __len__(self): 66 | return len(self.images) 67 | 68 | 69 | class GaussianBlur: 70 | def __init__(self, sigma: Sequence[float] = [0.1, 2.0]): 71 | """Gaussian blur as a callable object. 72 | 73 | Args: 74 | sigma (Sequence[float]): range to sample the radius of the gaussian blur filter. 75 | Defaults to [0.1, 2.0]. 76 | """ 77 | 78 | self.sigma = sigma 79 | 80 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 81 | """Applies gaussian blur to an input image. 82 | 83 | Args: 84 | x (torch.Tensor): an image in the tensor format. 85 | 86 | Returns: 87 | torch.Tensor: returns a blurred image. 88 | """ 89 | 90 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 91 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 92 | return x 93 | 94 | 95 | class Solarization: 96 | """Solarization as a callable object.""" 97 | 98 | def __call__(self, img: Image) -> Image: 99 | """Applies solarization to an input image. 100 | 101 | Args: 102 | img (Image): an image in the PIL.Image format. 103 | 104 | Returns: 105 | Image: a solarized image. 106 | """ 107 | 108 | return ImageOps.solarize(img) 109 | 110 | 111 | class NCropAugmentation: 112 | def __init__(self, transform: Union[Callable, Sequence], num_crops: Optional[int] = None): 113 | """Creates a pipeline that apply a transformation pipeline multiple times. 114 | 115 | Args: 116 | transform (Union[Callable, Sequence]): transformation pipeline or list of 117 | transformation pipelines. 118 | num_crops: if transformation pipeline is not a list, applies the same 119 | pipeline num_crops times, if it is a list, this is ignored and each 120 | element of the list is applied once. 
121 | """ 122 | 123 | self.transform = transform 124 | 125 | if isinstance(transform, Iterable): 126 | self.one_transform_per_crop = True 127 | assert num_crops == len(transform) 128 | else: 129 | self.one_transform_per_crop = False 130 | self.num_crops = num_crops 131 | 132 | def __call__(self, x: Image) -> List[torch.Tensor]: 133 | """Applies transforms n times to generate n crops. 134 | 135 | Args: 136 | x (Image): an image in the PIL.Image format. 137 | 138 | Returns: 139 | List[torch.Tensor]: an image in the tensor format. 140 | """ 141 | 142 | if self.one_transform_per_crop: 143 | return [transform(x) for transform in self.transform] 144 | else: 145 | return [self.transform(x) for _ in range(self.num_crops)] 146 | 147 | 148 | class BaseTransform: 149 | """Adds callable base class to implement different transformation pipelines.""" 150 | 151 | def __call__(self, x: Image) -> torch.Tensor: 152 | return self.transform(x) 153 | 154 | def __repr__(self) -> str: 155 | return str(self.transform) 156 | 157 | 158 | class CifarTransform(BaseTransform): 159 | def __init__( 160 | self, 161 | brightness: float, 162 | contrast: float, 163 | saturation: float, 164 | hue: float, 165 | gaussian_prob: float = 0.0, 166 | solarization_prob: float = 0.0, 167 | min_scale: float = 0.08, 168 | ): 169 | """Applies cifar transformations. 170 | 171 | Args: 172 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness]. 173 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast]. 174 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation]. 175 | hue (float): sampled uniformly in [-hue, hue]. 176 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0. 177 | solarization_prob (float, optional): probability of applying solarization. Defaults 178 | to 0.0. 179 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08. 180 | """ 181 | 182 | super().__init__() 183 | 184 | self.transform = transforms.Compose( 185 | [ 186 | transforms.RandomResizedCrop( 187 | (32, 32), 188 | scale=(min_scale, 1.0), 189 | interpolation=transforms.InterpolationMode.BICUBIC, 190 | ), 191 | transforms.RandomApply( 192 | [transforms.ColorJitter(brightness, contrast, saturation, hue)], p=0.8 193 | ), 194 | transforms.RandomGrayscale(p=0.2), 195 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob), 196 | transforms.RandomApply([Solarization()], p=solarization_prob), 197 | transforms.RandomHorizontalFlip(p=0.5), 198 | transforms.ToTensor(), 199 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)), 200 | ] 201 | ) 202 | 203 | 204 | class STLTransform(BaseTransform): 205 | def __init__( 206 | self, 207 | brightness: float, 208 | contrast: float, 209 | saturation: float, 210 | hue: float, 211 | gaussian_prob: float = 0.0, 212 | solarization_prob: float = 0.0, 213 | min_scale: float = 0.08, 214 | ): 215 | """Applies STL10 transformations. 216 | 217 | Args: 218 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness]. 219 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast]. 220 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation]. 221 | hue (float): sampled uniformly in [-hue, hue]. 222 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0. 223 | solarization_prob (float, optional): probability of applying solarization. Defaults 224 | to 0.0. 
225 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08. 226 | """ 227 | 228 | super().__init__() 229 | self.transform = transforms.Compose( 230 | [ 231 | transforms.RandomResizedCrop( 232 | (96, 96), 233 | scale=(min_scale, 1.0), 234 | interpolation=transforms.InterpolationMode.BICUBIC, 235 | ), 236 | transforms.RandomApply( 237 | [transforms.ColorJitter(brightness, contrast, saturation, hue)], p=0.8 238 | ), 239 | transforms.RandomGrayscale(p=0.2), 240 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob), 241 | transforms.RandomApply([Solarization()], p=solarization_prob), 242 | transforms.RandomHorizontalFlip(p=0.5), 243 | transforms.ToTensor(), 244 | transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)), 245 | ] 246 | ) 247 | 248 | 249 | class ImagenetTransform(BaseTransform): 250 | def __init__( 251 | self, 252 | brightness: float, 253 | contrast: float, 254 | saturation: float, 255 | hue: float, 256 | gaussian_prob: float = 0.5, 257 | solarization_prob: float = 0.0, 258 | size: int = 224, 259 | min_scale: float = 0.08, 260 | ): 261 | """Class that applies Imagenet transformations. 262 | 263 | Args: 264 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness]. 265 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast]. 266 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation]. 267 | hue (float): sampled uniformly in [-hue, hue]. 268 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0. 269 | solarization_prob (float, optional): probability of applying solarization. Defaults 270 | to 0.0. 271 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08. 272 | size (int, optional): size of the crop. Defaults to 224. 273 | """ 274 | 275 | super().__init__() 276 | self.transform = transforms.Compose( 277 | [ 278 | transforms.RandomResizedCrop( 279 | size, 280 | scale=(min_scale, 1.0), 281 | interpolation=transforms.InterpolationMode.BICUBIC, 282 | ), 283 | transforms.RandomApply( 284 | [transforms.ColorJitter(brightness, contrast, saturation, hue)], 285 | p=0.8, 286 | ), 287 | transforms.RandomGrayscale(p=0.2), 288 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob), 289 | transforms.RandomApply([Solarization()], p=solarization_prob), 290 | transforms.RandomHorizontalFlip(p=0.5), 291 | transforms.ToTensor(), 292 | transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)), 293 | ] 294 | ) 295 | 296 | 297 | class CustomTransform(BaseTransform): 298 | def __init__( 299 | self, 300 | brightness: float, 301 | contrast: float, 302 | saturation: float, 303 | hue: float, 304 | gaussian_prob: float = 0.5, 305 | solarization_prob: float = 0.0, 306 | min_scale: float = 0.08, 307 | size: int = 224, 308 | mean: Sequence[float] = (0.485, 0.456, 0.406), 309 | std: Sequence[float] = (0.228, 0.224, 0.225), 310 | ): 311 | """Class that applies Custom transformations. 312 | If you want to do exoteric augmentations, you can just re-write this class. 313 | 314 | Args: 315 | brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness]. 316 | contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast]. 317 | saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation]. 318 | hue (float): sampled uniformly in [-hue, hue]. 319 | gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.0. 
320 | solarization_prob (float, optional): probability of applying solarization. Defaults 321 | to 0.0. 322 | min_scale (float, optional): minimum scale of the crops. Defaults to 0.08. 323 | size (int, optional): size of the crop. Defaults to 224. 324 | mean (Sequence[float], optional): mean values for normalization. 325 | Defaults to (0.485, 0.456, 0.406). 326 | std (Sequence[float], optional): std values for normalization. 327 | Defaults to (0.228, 0.224, 0.225). 328 | """ 329 | 330 | super().__init__() 331 | self.transform = transforms.Compose( 332 | [ 333 | transforms.RandomResizedCrop( 334 | size, 335 | scale=(min_scale, 1.0), 336 | interpolation=transforms.InterpolationMode.BICUBIC, 337 | ), 338 | transforms.RandomApply( 339 | [transforms.ColorJitter(brightness, contrast, saturation, hue)], 340 | p=0.8, 341 | ), 342 | transforms.RandomGrayscale(p=0.2), 343 | transforms.RandomApply([GaussianBlur()], p=gaussian_prob), 344 | transforms.RandomApply([Solarization()], p=solarization_prob), 345 | transforms.RandomHorizontalFlip(p=0.5), 346 | transforms.ToTensor(), 347 | transforms.Normalize(mean=mean, std=std), 348 | ] 349 | ) 350 | 351 | 352 | class MulticropAugmentation: 353 | def __init__( 354 | self, 355 | transform: Callable, 356 | size_crops: Sequence[int], 357 | num_crops: Sequence[int], 358 | min_scales: Sequence[float], 359 | max_scale_crops: Sequence[float], 360 | ): 361 | """Class that applies multi crop augmentation. 362 | 363 | Args: 364 | transform (Callable): transformation callable without cropping. 365 | size_crops (Sequence[int]): a sequence of sizes of the crops. 366 | num_crops (Sequence[int]): a sequence number of crops per crop size. 367 | min_scales (Sequence[float]): sequence of minimum crop scales per crop 368 | size. 369 | max_scale_crops (Sequence[float]): sequence of maximum crop scales per crop 370 | size. 371 | """ 372 | 373 | self.size_crops = size_crops 374 | self.num_crops = num_crops 375 | self.min_scales = min_scales 376 | self.max_scale_crops = max_scale_crops 377 | 378 | self.transforms = [] 379 | for i in range(len(size_crops)): 380 | rrc = transforms.RandomResizedCrop( 381 | size_crops[i], 382 | scale=(min_scales[i], max_scale_crops[i]), 383 | interpolation=transforms.InterpolationMode.BICUBIC, 384 | ) 385 | full_transform = transforms.Compose([rrc, transform]) 386 | self.transforms.append(full_transform) 387 | 388 | def __call__(self, x: Image) -> List[torch.Tensor]: 389 | """Applies multi crop augmentations. 390 | 391 | Args: 392 | x (Image): an image in the PIL.Image format. 393 | 394 | Returns: 395 | List[torch.Tensor]: a list of crops in the tensor format. 
396 |         """
397 | 
398 |         imgs = []
399 |         for n, transform in zip(self.num_crops, self.transforms):
400 |             imgs.extend([transform(x) for i in range(n)])
401 |         return imgs
402 | 
403 | 
404 | class MulticropCifarTransform(BaseTransform):
405 |     def __init__(self):
406 |         """Class that applies multicrop transform for CIFAR"""
407 | 
408 |         super().__init__()
409 | 
410 |         self.transform = transforms.Compose(
411 |             [
412 |                 transforms.RandomHorizontalFlip(p=0.5),
413 |                 transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
414 |                 transforms.RandomGrayscale(p=0.2),
415 |                 transforms.ToTensor(),
416 |                 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
417 |             ]
418 |         )
419 | 
420 | 
421 | class MulticropSTLTransform(BaseTransform):
422 |     def __init__(self):
423 |         """Class that applies multicrop transform for STL10"""
424 | 
425 |         super().__init__()
426 |         self.transform = transforms.Compose(
427 |             [
428 |                 transforms.RandomHorizontalFlip(p=0.5),
429 |                 transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
430 |                 transforms.RandomGrayscale(p=0.2),
431 |                 transforms.ToTensor(),
432 |                 transforms.Normalize((0.4914, 0.4823, 0.4466), (0.247, 0.243, 0.261)),
433 |             ]
434 |         )
435 | 
436 | 
437 | class MulticropImagenetTransform(BaseTransform):
438 |     def __init__(
439 |         self,
440 |         brightness: float,
441 |         contrast: float,
442 |         saturation: float,
443 |         hue: float,
444 |         gaussian_prob: float = 0.5,
445 |         solarization_prob: float = 0.0,
446 |     ):
447 |         """Class that applies multicrop transform for Imagenet.
448 | 
449 |         Args:
450 |             brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
451 |             contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
452 |             saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
453 |             hue (float): sampled uniformly in [-hue, hue].
454 |             gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5.
455 |             solarization_prob (float, optional): probability of applying solarization. Defaults to 0.0.
456 |         """
457 | 
458 |         super().__init__()
459 |         self.transform = transforms.Compose(
460 |             [
461 |                 transforms.RandomApply(
462 |                     [transforms.ColorJitter(brightness, contrast, saturation, hue)],
463 |                     p=0.8,
464 |                 ),
465 |                 transforms.RandomGrayscale(p=0.2),
466 |                 transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
467 |                 transforms.RandomApply([Solarization()], p=solarization_prob),
468 |                 transforms.RandomHorizontalFlip(p=0.5),
469 |                 transforms.ToTensor(),
470 |                 transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.228, 0.224, 0.225)),
471 |             ]
472 |         )
473 | 
474 | 
475 | class MulticropCustomTransform(BaseTransform):
476 |     def __init__(
477 |         self,
478 |         brightness: float,
479 |         contrast: float,
480 |         saturation: float,
481 |         hue: float,
482 |         gaussian_prob: float = 0.5,
483 |         solarization_prob: float = 0.0,
484 |         mean: Sequence[float] = (0.485, 0.456, 0.406),
485 |         std: Sequence[float] = (0.228, 0.224, 0.225),
486 |     ):
487 |         """Class that applies multicrop transform for Custom Datasets.
488 |         If you want to do exotic augmentations, you can just re-write this class.
489 | 
490 |         Args:
491 |             brightness (float): sampled uniformly in [max(0, 1 - brightness), 1 + brightness].
492 |             contrast (float): sampled uniformly in [max(0, 1 - contrast), 1 + contrast].
493 |             saturation (float): sampled uniformly in [max(0, 1 - saturation), 1 + saturation].
494 |             hue (float): sampled uniformly in [-hue, hue].
495 |             gaussian_prob (float, optional): probability of applying gaussian blur. Defaults to 0.5.
496 |             solarization_prob (float, optional): probability of applying solarization. Defaults to 0.0.
497 |             mean (Sequence[float], optional): mean values for normalization.
498 |                 Defaults to (0.485, 0.456, 0.406).
499 |             std (Sequence[float], optional): std values for normalization.
500 |                 Defaults to (0.228, 0.224, 0.225).
501 |         """
502 | 
503 |         super().__init__()
504 |         self.transform = transforms.Compose(
505 |             [
506 |                 transforms.RandomApply(
507 |                     [transforms.ColorJitter(brightness, contrast, saturation, hue)],
508 |                     p=0.8,
509 |                 ),
510 |                 transforms.RandomGrayscale(p=0.2),
511 |                 transforms.RandomApply([GaussianBlur()], p=gaussian_prob),
512 |                 transforms.RandomApply([Solarization()], p=solarization_prob),
513 |                 transforms.RandomHorizontalFlip(p=0.5),
514 |                 transforms.ToTensor(),
515 |                 transforms.Normalize(mean=mean, std=std),
516 |             ]
517 |         )
518 | 
519 | 
520 | def prepare_transform(dataset: str, multicrop: bool = False, **kwargs) -> Any:
521 |     """Prepares transforms for a specific dataset. Optionally uses multi crop.
522 | 
523 |     Args:
524 |         dataset (str): name of the dataset.
525 |         multicrop (bool, optional): whether or not to use multi crop. Defaults to False.
526 | 
527 |     Returns:
528 |         Any: a transformation for a specific dataset.
529 |     """
530 | 
531 |     if dataset in ["cifar10", "cifar100"]:
532 |         return CifarTransform(**kwargs) if not multicrop else MulticropCifarTransform()
533 |     elif dataset == "stl10":
534 |         return STLTransform(**kwargs) if not multicrop else MulticropSTLTransform()
535 |     elif dataset in ["imagenet", "imagenet100"]:
536 |         return (
537 |             ImagenetTransform(**kwargs) if not multicrop else MulticropImagenetTransform(**kwargs)
538 |         )
539 |     elif dataset == "custom":
540 |         return CustomTransform(**kwargs) if not multicrop else MulticropCustomTransform(**kwargs)
541 | 
542 | 
543 | def prepare_n_crop_transform(
544 |     transform: Callable, num_crops: Optional[int] = None
545 | ) -> NCropAugmentation:
546 |     """Turns a single crop transformation into an N crops transformation.
547 | 
548 |     Args:
549 |         transform (Callable): a transformation.
550 |         num_crops (Optional[int], optional): number of crops. Defaults to None.
551 | 
552 |     Returns:
553 |         NCropAugmentation: an N crop transformation.
554 |     """
555 | 
556 |     return NCropAugmentation(transform, num_crops)
557 | 
558 | 
559 | def prepare_multicrop_transform(
560 |     transform: Callable,
561 |     size_crops: Sequence[int],
562 |     num_crops: Optional[Sequence[int]] = None,
563 |     min_scales: Optional[Sequence[float]] = None,
564 |     max_scale_crops: Optional[Sequence[float]] = None,
565 | ) -> MulticropAugmentation:
566 |     """Prepares multicrop transformations by creating custom crops given the parameters.
567 | 
568 |     Args:
569 |         transform (Callable): transformation callable without cropping.
570 |         size_crops (Sequence[int]): a sequence of sizes of the crops.
571 |         num_crops (Optional[Sequence[int]]): list of number of crops per crop size.
572 |         min_scales (Optional[Sequence[float]]): sequence of minimum crop scales per crop
573 |             size.
574 |         max_scale_crops (Optional[Sequence[float]]): sequence of maximum crop scales per crop
575 |             size.
576 | 
577 |     Returns:
578 |         MulticropAugmentation: prepared augmentation pipeline that supports multicrop with
579 |             different sizes.
580 |     """
581 | 
582 |     if num_crops is None:
583 |         num_crops = [2, 6]
584 |     if min_scales is None:
585 |         min_scales = [0.14, 0.05]
586 |     if max_scale_crops is None:
587 |         max_scale_crops = [1.0, 0.14]
588 | 
589 |     return MulticropAugmentation(
590 |         transform,
591 |         size_crops=size_crops,
592 |         num_crops=num_crops,
593 |         min_scales=min_scales,
594 |         max_scale_crops=max_scale_crops,
595 |     )
596 | 
597 | 
598 | def prepare_datasets(
599 |     dataset: str,
600 |     transform: Callable,
601 |     data_dir: Optional[Union[str, Path]] = None,
602 |     train_dir: Optional[Union[str, Path]] = None,
603 |     no_labels: Optional[bool] = False,
604 | ) -> Dataset:
605 |     """Prepares the desired dataset.
606 | 
607 |     Args:
608 |         dataset (str): the name of the dataset.
609 |         transform (Callable): a transformation.
610 |         data_dir (Optional[Union[str, Path]], optional): the directory to load data from.
611 |             Defaults to None.
612 |         train_dir (Optional[Union[str, Path]], optional): training data directory
613 |             to be appended to data_dir. Defaults to None.
614 |         no_labels (Optional[bool], optional): if the custom dataset has no labels.
615 | 
616 |     Returns:
617 |         Dataset: the desired dataset with transformations.
618 |     """
619 | 
620 |     if data_dir is None:
621 |         sandbox_folder = Path(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
622 |         data_dir = sandbox_folder / "datasets"
623 | 
624 |     if train_dir is None:
625 |         train_dir = Path(f"{dataset}/train")
626 |     else:
627 |         train_dir = Path(train_dir)
628 | 
629 |     if dataset in ["cifar10", "cifar100"]:
630 |         DatasetClass = vars(torchvision.datasets)[dataset.upper()]
631 |         train_dataset = dataset_with_index(DatasetClass)(
632 |             data_dir / train_dir,
633 |             train=True,
634 |             download=True,
635 |             transform=transform,
636 |         )
637 | 
638 |     elif dataset == "stl10":
639 |         train_dataset = dataset_with_index(STL10)(
640 |             data_dir / train_dir,
641 |             split="train+unlabeled",
642 |             download=True,
643 |             transform=transform,
644 |         )
645 | 
646 |     elif dataset in ["imagenet", "imagenet100"]:
647 |         train_dir = data_dir / train_dir
648 |         train_dataset = dataset_with_index(ImageFolder)(train_dir, transform)
649 | 
650 |     elif dataset == "custom":
651 |         train_dir = data_dir / train_dir
652 | 
653 |         if no_labels:
654 |             dataset_class = CustomDatasetWithoutLabels
655 |         else:
656 |             dataset_class = ImageFolder
657 | 
658 |         train_dataset = dataset_with_index(dataset_class)(train_dir, transform)
659 | 
660 |     return train_dataset
661 | 
662 | 
663 | def prepare_dataloader(
664 |     train_dataset: Dataset, batch_size: int = 64, num_workers: int = 4
665 | ) -> DataLoader:
666 |     """Prepares the training dataloader for pretraining.
667 | 
668 |     Args:
669 |         train_dataset (Dataset): the training dataset.
670 |         batch_size (int, optional): batch size. Defaults to 64.
671 |         num_workers (int, optional): number of workers. Defaults to 4.
672 | 
673 |     Returns:
674 |         DataLoader: the training dataloader with the desired dataset.
675 |     """
676 | 
677 |     train_loader = DataLoader(
678 |         train_dataset,
679 |         batch_size=batch_size,
680 |         shuffle=True,
681 |         num_workers=num_workers,
682 |         pin_memory=True,
683 |         drop_last=True,
684 |     )
685 |     return train_loader
686 | 
--------------------------------------------------------------------------------
/solo/utils/sinkhorn_knopp.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | # Adapted from https://github.com/facebookresearch/swav. 21 | 22 | import torch 23 | import torch.distributed as dist 24 | 25 | 26 | class SinkhornKnopp(torch.nn.Module): 27 | def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1): 28 | """Approximates optimal transport using the Sinkhorn-Knopp algorithm. 29 | 30 | A simple iterative method to approach the double stochastic matrix is to alternately rescale 31 | rows and columns of the matrix to sum to 1. 32 | 33 | Args: 34 | num_iters (int, optional): number of times to perform row and column normalization. 35 | Defaults to 3. 36 | epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05. 37 | world_size (int, optional): number of nodes for distributed training. Defaults to 1. 38 | """ 39 | 40 | super().__init__() 41 | self.num_iters = num_iters 42 | self.epsilon = epsilon 43 | self.world_size = world_size 44 | 45 | @torch.no_grad() 46 | def forward(self, Q: torch.Tensor) -> torch.Tensor: 47 | """Produces assignments using Sinkhorn-Knopp algorithm. 48 | 49 | Applies the entropy regularization, normalizes the Q matrix and then normalizes rows and 50 | columns in an alternating fashion for num_iter times. Before returning it normalizes again 51 | the columns in order for the output to be an assignment of samples to prototypes. 52 | 53 | Args: 54 | Q (torch.Tensor): cosine similarities between the features of the 55 | samples and the prototypes. 56 | 57 | Returns: 58 | torch.Tensor: assignment of samples to prototypes according to optimal transport. 
59 | """ 60 | 61 | Q = torch.exp(Q / self.epsilon).t() 62 | B = Q.shape[1] * self.world_size 63 | K = Q.shape[0] # num prototypes 64 | 65 | # make the matrix sums to 1 66 | sum_Q = torch.sum(Q) 67 | if dist.is_available() and dist.is_initialized(): 68 | dist.all_reduce(sum_Q) 69 | Q /= sum_Q 70 | 71 | for it in range(self.num_iters): 72 | # normalize each row: total weight per prototype must be 1/K 73 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 74 | if dist.is_available() and dist.is_initialized(): 75 | dist.all_reduce(sum_of_rows) 76 | Q /= sum_of_rows 77 | Q /= K 78 | 79 | # normalize each column: total weight per sample must be 1/B 80 | Q /= torch.sum(Q, dim=0, keepdim=True) 81 | Q /= B 82 | 83 | Q *= B # the colomns must sum to 1 so that Q is an assignment 84 | return Q.t() 85 | -------------------------------------------------------------------------------- /solo/utils/whitening.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | import torch 22 | import torch.nn as nn 23 | from torch.cuda.amp import custom_fwd 24 | from torch.nn.functional import conv2d 25 | 26 | 27 | class Whitening2d(nn.Module): 28 | def __init__(self, output_dim: int, eps: float = 0.0): 29 | """Layer that computes hard whitening for W-MSE using the Cholesky decomposition. 30 | 31 | Args: 32 | output_dim (int): number of dimension of projected features. 33 | eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults 34 | to 0.0. 35 | """ 36 | 37 | super(Whitening2d, self).__init__() 38 | self.output_dim = output_dim 39 | self.eps = eps 40 | 41 | @custom_fwd(cast_inputs=torch.float32) 42 | def forward(self, x: torch.Tensor) -> torch.Tensor: 43 | """Performs whitening using the Cholesky decomposition. 44 | 45 | Args: 46 | x (torch.Tensor): a batch or slice of projected features. 47 | 48 | Returns: 49 | torch.Tensor: a batch or slice of whitened features. 
50 | """ 51 | 52 | x = x.unsqueeze(2).unsqueeze(3) 53 | m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1) 54 | xn = x - m 55 | 56 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1) 57 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) 58 | 59 | eye = torch.eye(self.output_dim).type(f_cov.type()) 60 | 61 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye 62 | 63 | inv_sqrt = torch.triangular_solve(eye, torch.cholesky(f_cov_shrinked), upper=False)[0] 64 | inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1) 65 | 66 | decorrelated = conv2d(xn, inv_sqrt) 67 | 68 | return decorrelated.squeeze(2).squeeze(2) 69 | --------------------------------------------------------------------------------
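The pretraining data utilities in `solo/utils/pretrain_dataloader.py` (`prepare_transform`, `prepare_n_crop_transform`, `prepare_datasets`, `prepare_dataloader`) compose into the two-view pipeline used for contrastive pretraining. The snippet below is a minimal usage sketch, not the exact configuration used in the paper: the augmentation strengths, batch size, and the `./datasets` directory are illustrative assumptions; see `main_pretrain.py` and the scripts in `bash_files/` for the actual settings.

```python
# Minimal sketch of the two-crop pretraining data pipeline on CIFAR-10.
# Augmentation strengths and paths below are placeholders, not the paper's values.
from pathlib import Path

from solo.utils.pretrain_dataloader import (
    prepare_dataloader,
    prepare_datasets,
    prepare_n_crop_transform,
    prepare_transform,
)

# Single-view transform for CIFAR-10 (CifarTransform under the hood).
transform = prepare_transform(
    "cifar10",
    brightness=0.4,
    contrast=0.4,
    saturation=0.4,
    hue=0.1,
    gaussian_prob=0.0,
    solarization_prob=0.0,
)

# Wrap it so every image yields two augmented views.
two_crop_transform = prepare_n_crop_transform(transform, num_crops=2)

# CIFAR-10 is downloaded automatically; each item is (index, [view_1, view_2], label).
train_dataset = prepare_datasets("cifar10", two_crop_transform, data_dir=Path("./datasets"))
train_loader = prepare_dataloader(train_dataset, batch_size=256, num_workers=4)

for idx, (x1, x2), target in train_loader:
    # x1 and x2 are two differently augmented batches of the same images,
    # ready to be fed to the dual-temperature contrastive loss.
    break
```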