├── .gitignore
├── README.md
├── cifar100_datamodule.py
├── environment.yaml
├── moco
│   ├── README.md
│   └── moco.py
├── model_checkpoint.py
├── random_search.py
├── simclr.py
├── simclr_finetune.py
├── simclr_module.py
├── ssl_online.py
├── tiny_imagenet_datamodule.py
└── transforms.py

/.gitignore:
--------------------------------------------------------------------------------
1 | cifar-*/
2 | *.tar.gz
3 | lightning_logs/
4 | data/
5 | SimCLR/
6 | __pycache__/
7 | tiny-imagenet-200/
8 | tiny-imagenet-200.zip
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 | 
3 | Official implementation of the ICLR 2024 paper "Contrastive Learning Is Spectral Clustering On Similarity Graph" (https://arxiv.org/abs/2303.15103).
4 | 
5 | 
6 | ## Installation
7 | 
8 | Requirements:
9 | - Conda
10 | 
11 | Once Conda is installed, you can create the `contrastive` environment with
12 | `conda env create -f environment.yaml`.
13 | 
14 | 
15 | 
16 | ## Random Search
17 | Just run
18 | `python random_search.py`
19 | 
20 | You can override any pretraining argument during the random search. For example, to search hyperparameters for CIFAR-100 with the LARS optimizer over 100 epochs, run `python random_search.py --dataset cifar100 --optimizer lars --max_epochs 100`.
21 | 
22 | For more details, see the argument help of `random_search.py`.
23 | 
24 | ## Pretraining
25 | 
26 | Once the random search has found the best hyperparameters, run `python simclr_module.py [args]` to pretrain.
27 | 
28 | For more details, see the argument help of `simclr_module.py`.
29 | 
30 | 
31 | ## Linear Probing
32 | 
33 | For linear probing, run `python simclr_finetune.py --ckpt_path [path/to/your/ckpt] [args]`
34 | 
35 | For more details, see the argument help of `simclr_finetune.py`. In most cases, you only need to set the three arguments `dataset`, `data_dir`, and `ckpt_path`.
36 | 
37 | ## Acknowledgement
38 | 
39 | This repo is mainly based on [PyTorch Lightning](https://github.com/Lightning-AI/lightning). Many thanks for their wonderful work!
40 | 
41 | 
42 | ## Citations
43 | Please cite the paper and star this repo if you use Kernel-InfoNCE and find it interesting/useful, thanks! Feel free to contact zhangyif21@mails.tsinghua.edu.cn | yangjq21@mails.tsinghua.edu.cn or open an issue if you have any questions.
44 | 
45 | ```bibtex
46 | @article{tan2023contrastive,
47 |   title={Contrastive Learning Is Spectral Clustering On Similarity Graph},
48 |   author={Tan, Zhiquan and Zhang, Yifan and Yang, Jingqin and Yuan, Yang},
49 |   journal={arXiv preprint arXiv:2303.15103},
50 |   year={2023}
51 | }
52 | ```
53 | 
--------------------------------------------------------------------------------
/cifar100_datamodule.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from typing import Any, Callable, Optional, Sequence, Union
3 | 
4 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule
5 | from pl_bolts.datasets import TrialCIFAR10
6 | from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
7 | from pl_bolts.utils import _TORCHVISION_AVAILABLE
8 | from pl_bolts.utils.stability import under_review
9 | from pl_bolts.utils.warnings import warn_missing_pkg
10 | 
11 | if _TORCHVISION_AVAILABLE:
12 |     from torchvision import transforms as transform_lib
13 |     from torchvision.datasets import CIFAR100
14 | else:  # pragma: no cover
15 |     warn_missing_pkg("torchvision")
16 |     CIFAR100 = None
17 | 
18 | 
19 | class CIFAR100DataModule(VisionDataModule):
20 |     """
21 |     CIFAR-100 datamodule, adapted from the ``pl_bolts`` ``CIFAR10DataModule``.
22 | 
23 |     Specs:
24 |         - 100 classes (600 images per class)
25 |         - Each image is (3 x 32 x 32)
26 | 
27 |     Standard CIFAR100 train, val, test splits and transforms
28 |     Transforms::
29 |         transforms = transform_lib.Compose([
30 |             transform_lib.ToTensor(),
31 |             transforms.Normalize(
32 |                 mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
33 |                 std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
34 |             )
35 |         ])
36 | 
37 |     Example::
38 |         from cifar100_datamodule import CIFAR100DataModule
39 |         dm = CIFAR100DataModule(PATH)
40 |         model = LitModel()
41 |         Trainer().fit(model, datamodule=dm)
42 |     Or you can set your own transforms
43 |     Example::
44 |         dm.train_transforms = ...
45 |         dm.test_transforms = ...
46 |         dm.val_transforms = ...
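    Note that ``default_transforms`` below reuses the CIFAR-10 normalization statistics
    shown above. If you prefer CIFAR-100-specific normalization, override the transforms
    yourself; the statistics here are the commonly quoted CIFAR-100 values, an assumption
    from common practice rather than something computed in this repo::

        dm.train_transforms = transform_lib.Compose([
            transform_lib.ToTensor(),
            transform_lib.Normalize(
                mean=(0.5071, 0.4865, 0.4409),
                std=(0.2673, 0.2564, 0.2762),
            ),
        ])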
47 | """ 48 | 49 | name = "cifar100" 50 | dataset_cls = CIFAR100 51 | dims = (3, 32, 32) 52 | 53 | def __init__( 54 | self, 55 | data_dir: Optional[str] = None, 56 | val_split: Union[int, float] = 0.1, 57 | num_workers: int = 0, 58 | normalize: bool = False, 59 | batch_size: int = 32, 60 | seed: int = 42, 61 | shuffle: bool = True, 62 | pin_memory: bool = True, 63 | drop_last: bool = False, 64 | *args: Any, 65 | **kwargs: Any, 66 | ) -> None: 67 | """ 68 | Args: 69 | data_dir: Where to save/load the data 70 | val_split: Percent (float) or number (int) of samples to use for the validation split 71 | num_workers: How many workers to use for loading data 72 | normalize: If true applies image normalize 73 | batch_size: How many samples per batch to load 74 | seed: Random seed to be used for train/val/test splits 75 | shuffle: If true shuffles the train data every epoch 76 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before 77 | returning them 78 | drop_last: If true drops the last incomplete batch 79 | """ 80 | super().__init__( # type: ignore[misc] 81 | data_dir=data_dir, 82 | val_split=val_split, 83 | num_workers=num_workers, 84 | normalize=normalize, 85 | batch_size=batch_size, 86 | seed=seed, 87 | shuffle=shuffle, 88 | pin_memory=pin_memory, 89 | drop_last=drop_last, 90 | *args, 91 | **kwargs, 92 | ) 93 | 94 | @property 95 | def num_samples(self) -> int: 96 | train_len, _ = self._get_splits(len_dataset=50_000) 97 | return train_len 98 | 99 | @property 100 | def num_classes(self) -> int: 101 | """ 102 | Return: 103 | 100 104 | """ 105 | return 100 106 | 107 | def default_transforms(self) -> Callable: 108 | if self.normalize: 109 | cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) 110 | else: 111 | cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()]) 112 | 113 | return cf10_transforms 114 | 115 | @staticmethod 116 | def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: 117 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 118 | 119 | parser.add_argument("--data_dir", type=str, default=".") 120 | parser.add_argument("--num_workers", type=int, default=0) 121 | parser.add_argument("--batch_size", type=int, default=32) 122 | 123 | return parser -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: contrastive 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _openmp_mutex=5.1=1_gnu 7 | - ca-certificates=2022.10.11=h06a4308_0 8 | - certifi=2022.12.7=py38h06a4308_0 9 | - ld_impl_linux-64=2.38=h1181459_1 10 | - libffi=3.4.2=h6a678d5_6 11 | - libgcc-ng=11.2.0=h1234567_1 12 | - libgomp=11.2.0=h1234567_1 13 | - libstdcxx-ng=11.2.0=h1234567_1 14 | - ncurses=6.3=h5eee18b_3 15 | - openssl=1.1.1s=h7f8727e_0 16 | - pip=22.3.1=py38h06a4308_0 17 | - python=3.8.15=h7a1cb2a_2 18 | - readline=8.2=h5eee18b_0 19 | - setuptools=65.6.3=py38h06a4308_0 20 | - sqlite=3.40.1=h5082296_0 21 | - tk=8.6.12=h1ccaba5_0 22 | - wheel=0.37.1=pyhd3eb1b0_0 23 | - xz=5.2.8=h5eee18b_0 24 | - zlib=1.2.13=h5eee18b_0 25 | - pip: 26 | - absl-py==1.4.0 27 | - aiohttp==3.8.3 28 | - aiohttp-cors==0.7.0 29 | - aiorwlock==1.3.0 30 | - aiosignal==1.3.1 31 | - anyio==3.6.2 32 | - async-timeout==4.0.2 33 | - attrs==22.2.0 34 | - blessed==1.19.1 35 | - byol-pytorch==0.6.0 36 | - cachetools==5.2.1 37 | - charset-normalizer==2.1.1 38 | - 
click==8.1.3
39 |       - colorful==0.5.5
40 |       - contourpy==1.0.7
41 |       - cycler==0.11.0
42 |       - distlib==0.3.6
43 |       - fastapi==0.89.1
44 |       - filelock==3.9.0
45 |       - fonttools==4.38.0
46 |       - frozenlist==1.3.3
47 |       - fsspec==2022.11.0
48 |       - google-api-core==2.11.0
49 |       - google-auth==2.16.0
50 |       - google-auth-oauthlib==0.4.6
51 |       - googleapis-common-protos==1.58.0
52 |       - gpustat==1.0.0
53 |       - grpcio==1.51.1
54 |       - h11==0.14.0
55 |       - idna==3.4
56 |       - importlib-metadata==6.0.0
57 |       - importlib-resources==5.10.2
58 |       - joblib==1.2.0
59 |       - jsonschema==4.17.3
60 |       - kiwisolver==1.4.4
61 |       - lightning-bolts==0.6.0.post1
62 |       - lightning-utilities==0.5.0
63 |       - markdown==3.4.1
64 |       - markupsafe==2.1.1
65 |       - matplotlib==3.6.3
66 |       - msgpack==1.0.4
67 |       - multidict==6.0.4
68 |       - numpy==1.24.1
69 |       - nvidia-cublas-cu11==11.10.3.66
70 |       - nvidia-cuda-nvrtc-cu11==11.7.99
71 |       - nvidia-cuda-runtime-cu11==11.7.99
72 |       - nvidia-cudnn-cu11==8.5.0.96
73 |       - nvidia-ml-py==11.495.46
74 |       - oauthlib==3.2.2
75 |       - opencensus==0.11.0
76 |       - opencensus-context==0.1.3
77 |       - packaging==23.0
78 |       - pandas==1.5.2
79 |       - pathtools==0.1.2
80 |       - pillow==9.4.0
81 |       - pkgutil-resolve-name==1.3.10
82 |       - platformdirs==2.6.2
83 |       - prometheus-client==0.13.1
84 |       - protobuf==3.19.6
85 |       - psutil==5.9.4
86 |       - py-spy==0.3.14
87 |       - pyarrow==7.0.0
88 |       - pyasn1==0.4.8
89 |       - pyasn1-modules==0.2.8
90 |       - pydantic==1.10.4
91 |       - pyparsing==3.0.9
92 |       - pyrsistent==0.19.3
93 |       - python-dateutil==2.8.2
94 |       - pytorch-lightning==1.8.6
95 |       - pytz==2022.7.1
96 |       - pyyaml==6.0
97 |       - ray==2.2.0
98 |       - requests==2.28.2
99 |       - requests-oauthlib==1.3.1
100 |       - rsa==4.9
101 |       - scikit-learn==1.2.2
102 |       - scipy==1.10.1
103 |       - six==1.16.0
104 |       - sklearn==0.0.post5
105 |       - smart-open==6.3.0
106 |       - sniffio==1.3.0
107 |       - starlette==0.22.0
108 |       - tabulate==0.9.0
109 |       - tensorboard==2.11.2
110 |       - tensorboard-data-server==0.6.1
111 |       - tensorboard-plugin-wit==1.8.1
112 |       - tensorboardx==2.5.1
113 |       - thop==0.1.1-2209072238
114 |       - threadpoolctl==3.1.0
115 |       - torch==1.12.1+cu113
116 |       - torchaudio==0.12.1+cu113
117 |       - torchmetrics==0.10.3
118 |       - torchvision==0.13.1+cu113
119 |       - tqdm==4.64.1
120 |       - typing-extensions==4.4.0
121 |       - urllib3==1.26.14
122 |       - uvicorn==0.20.0
123 |       - virtualenv==20.17.1
124 |       - wcwidth==0.2.6
125 |       - werkzeug==2.2.2
126 |       - yarl==1.8.2
127 |       - zipp==3.11.0
128 | prefix: /home/yjq/anaconda3/envs/contrastive
129 | 
--------------------------------------------------------------------------------
/moco/README.md:
--------------------------------------------------------------------------------
1 | # Kernel-InfoNCE on MoCo
2 | 
3 | ## Installation
4 | 
5 | Starting from the Kernel-InfoNCE environment, run `pip install lightly`.
6 | 
7 | ## Run Experiment
8 | 
9 | `python moco.py [args]`
10 | 
11 | The arguments have the same meaning as Kernel-InfoNCE's arguments.
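In particular, the `origin` loss implemented in `gamma_loss` in `moco.py` below is, up to the memory-bank bookkeeping, the kernelized InfoNCE objective: for a normalized query $z$, its positive $z^{+}$, and negatives $n$ drawn from the memory bank,

$$
\mathcal{L}_{\gamma,\tau}(z, z^{+}) = -\log \frac{\exp\!\left(-\lVert z - z^{+} \rVert_2^{\gamma} / \tau\right)}{\sum_{n} \exp\!\left(-\lVert z - n \rVert_2^{\gamma} / \tau\right)},
$$

and `--loss_type sum` mixes two such terms, $\lambda\,\mathcal{L}_{\gamma,\tau} + (1-\lambda)\,\mathcal{L}_{2,\tau_2}$, where $\lambda$ is `--gamma_lambd` and $\tau_2$ is `--temperature2`.

As an illustration (all flag names exist in `moco.py`; the values here are just the script defaults plus one non-default gamma, not tuned settings):

`python moco.py --dataset cifar100 --loss_type sum --gamma 0.5 --temperature 0.1 --temperature2 0.3 --gamma_lambd 0.2 --epochs 200`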
12 | 
13 | 
14 | 
15 | 
--------------------------------------------------------------------------------
/moco/moco.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import argparse
3 | import pytorch_lightning as pl
4 | import torch
5 | import torch.nn as nn
6 | import torchvision
7 | import torch.distributed as torch_dist  # needed by the gather_distributed check below
8 | from lightly.data import LightlyDataset
9 | # from lightly.loss import NTXentLoss  (replaced by the kernel variant defined below)
10 | from lightly.loss.memory_bank import MemoryBankModule
11 | from lightly.models import ResNetGenerator
12 | from lightly.models.modules.heads import MoCoProjectionHead
13 | from lightly.models.utils import (
14 |     batch_shuffle,
15 |     batch_unshuffle,
16 |     deactivate_requires_grad,
17 |     update_momentum,
18 | )
19 | from lightly.transforms import MoCoV2Transform, utils
20 | 
21 | 
22 | parser = argparse.ArgumentParser(description='Kernel-InfoNCE MoCo training')
23 | 
24 | parser.add_argument('--loss_type', default='origin', type=str)
25 | parser.add_argument('--dataset', default='cifar10', type=str)
26 | parser.add_argument('--epochs', default=200, type=int)
27 | parser.add_argument('--gamma', default=1.0, type=float, metavar='M',
28 |                     help='gamma exponent of the kernel loss')
29 | parser.add_argument('--temperature', default=0.1, type=float)
30 | parser.add_argument('--temperature2', default=0.3, type=float)
31 | parser.add_argument('--gamma_lambd', default=0.2, type=float)
32 | 
33 | args = parser.parse_args()
34 | 
35 | num_workers = 8
36 | batch_size = 512
37 | memory_bank_size = 4096
38 | seed = 1
39 | max_epochs = args.epochs
40 | 
41 | 
42 | class NTXentLoss(MemoryBankModule):
43 |     def __init__(
44 |         self,
45 |         temperature: float = 0.5,
46 |         memory_bank_size: int = 0,
47 |         gather_distributed: bool = False,
48 |     ):
49 |         super(NTXentLoss, self).__init__(size=memory_bank_size)
50 |         self.temperature = temperature
51 |         self.gather_distributed = gather_distributed
52 |         self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")
53 |         self.eps = 1e-8
54 |         # the CLI arguments override the constructor defaults
55 |         self.gamma_lambd = args.gamma_lambd
56 |         self.temperature = args.temperature
57 |         self.gamma = args.gamma
58 |         self.temperature2 = args.temperature2
59 |         self.loss_type = args.loss_type
60 | 
61 |         if abs(self.temperature) < self.eps:
62 |             raise ValueError(
63 |                 "Illegal temperature: abs({}) < 1e-8".format(self.temperature)
64 |             )
65 |         if gather_distributed and not torch_dist.is_available():
66 |             raise ValueError(
67 |                 "gather_distributed is True but torch.distributed is not available. "
68 |                 "Please set gather_distributed=False or install a torch version with "
69 |                 "distributed support."
70 |             )
71 | 
72 |     def gamma_loss(self, out_1, out_2, gamma, temperature, eps=1e-6, negative=None):
73 |         # kernel similarity: exp(-||x - y||_2^gamma / temperature)
74 |         cov = torch.pow(torch.cdist(out_1, negative.T, p=2), gamma) * -1.
75 |         sim = torch.exp(cov / temperature)
76 |         neg = torch.clamp(sim.sum(dim=-1), min=eps)
77 |         sim_adj = torch.pow(torch.norm(out_1 - out_2, dim=-1, p=2.), gamma) * -1.
78 |         pos = torch.exp(sim_adj / temperature)
79 |         loss = -torch.log(pos / (neg + eps)).mean()
80 | 
81 |         return loss
82 | 
83 |     def nt_xent_loss(self, out_1, out_2, negative):
84 |         """
85 |         assume out_1 and out_2 are normalized
86 |         out_1: [batch_size, dim]
87 |         out_2: [batch_size, dim]
88 |         """
89 | 
90 |         if self.loss_type == "sum":
91 |             loss = self.gamma_loss(out_1=out_1, out_2=out_2, gamma=self.gamma,
92 |                                    temperature=self.temperature, negative=negative) * self.gamma_lambd + self.gamma_loss(
93 |                 out_1=out_1,
94 |                 out_2=out_2,
95 |                 gamma=2.0,
96 |                 temperature=self.temperature2, negative=negative) * (
97 |                     1. - self.gamma_lambd)
98 |         elif self.loss_type == "origin":
99 |             loss = self.gamma_loss(out_1=out_1, out_2=out_2, gamma=self.gamma, temperature=self.temperature, negative=negative)
100 |         else:
101 |             raise NotImplementedError
102 | 
103 |         return loss
104 | 
105 |     def forward(self, out0: torch.Tensor, out1: torch.Tensor):
106 | 
107 |         device = out0.device
108 |         batch_size, _ = out0.shape
109 | 
110 |         out0 = nn.functional.normalize(out0, dim=1)
111 |         out1 = nn.functional.normalize(out1, dim=1)
112 | 
113 |         out1, negatives = super(NTXentLoss, self).forward(
114 |             out1, update=out0.requires_grad
115 |         )
116 |         return self.nt_xent_loss(out0, out1, negative=negatives.to(device))
117 | 
118 | # %%
119 | # Replace the path with the location of your CIFAR-10 dataset.
120 | # We assume we have a train folder with subfolders
121 | # for each class and .png images inside.
122 | #
123 | # You can download CIFAR-10 organized in class folders from Kaggle.
124 | #
125 | 
126 | # The dataset structure should be like this:
127 | # cifar10/train/
128 | #  L airplane/
129 | #    L 10008_airplane.png
130 | #    L ...
131 | #  L automobile/
132 | #  L bird/
133 | #  L cat/
134 | #  L deer/
135 | #  L dog/
136 | #  L frog/
137 | #  L horse/
138 | #  L ship/
139 | #  L truck/
140 | path_to_train = "/data/cifar10/cifar10/train/"
141 | path_to_test = "/data/cifar10/cifar10/test/"
142 | pl.seed_everything(seed)
143 | 
144 | transform = MoCoV2Transform(
145 |     input_size=32,
146 |     gaussian_blur=0.0,
147 | )
148 | train_classifier_transforms = torchvision.transforms.Compose(
149 |     [
150 |         torchvision.transforms.RandomCrop(32, padding=4),
151 |         torchvision.transforms.RandomHorizontalFlip(),
152 |         torchvision.transforms.ToTensor(),
153 |         torchvision.transforms.Normalize(
154 |             mean=utils.IMAGENET_NORMALIZE["mean"],
155 |             std=utils.IMAGENET_NORMALIZE["std"],
156 |         ),
157 |     ]
158 | )
159 | 
160 | test_transforms = torchvision.transforms.Compose(
161 |     [
162 |         torchvision.transforms.Resize((32, 32)),
163 |         torchvision.transforms.ToTensor(),
164 |         torchvision.transforms.Normalize(
165 |             mean=utils.IMAGENET_NORMALIZE["mean"],
166 |             std=utils.IMAGENET_NORMALIZE["std"],
167 |         ),
168 |     ]
169 | )
170 | 
171 | 
172 | if args.dataset == 'cifar10':
173 |     # We use the moco augmentations for training moco
174 |     dataset_train_moco = LightlyDataset(input_dir=path_to_train, transform=transform)
175 | 
176 |     dataset_train_classifier = LightlyDataset(
177 |         input_dir=path_to_train, transform=train_classifier_transforms
178 |     )
179 | 
180 |     dataset_test = LightlyDataset(input_dir=path_to_test, transform=test_transforms)
181 | elif args.dataset == 'cifar100':
182 |     dataset_train_moco = LightlyDataset.from_torch_dataset(
183 |         torchvision.datasets.cifar.CIFAR100(root='/data', transform=transform, train=True, download=False), transform=transform)
184 |     dataset_train_classifier = LightlyDataset.from_torch_dataset(
185 |         torchvision.datasets.cifar.CIFAR100(root='/data', transform=train_classifier_transforms, train=True, download=False), transform=train_classifier_transforms)
186 | 
187 |     dataset_test = LightlyDataset.from_torch_dataset(
188 |         torchvision.datasets.cifar.CIFAR100(root='/data', transform=test_transforms, train=False, download=False), transform=test_transforms)
189 | elif args.dataset == 'tiny':
190 |     dataset_train_moco = LightlyDataset(input_dir='/data/tiny-imagenet-200/train', transform=transform)
191 |     dataset_train_classifier = LightlyDataset(
192 |         input_dir='/data/tiny-imagenet-200/train', transform=train_classifier_transforms
193 |     )
194 | 
195 |     dataset_test =
LightlyDataset(input_dir='/data/tiny-imagenet-200/val', transform=test_transforms) 196 | else: 197 | raise NotImplementedError 198 | 199 | dataloader_train_moco = torch.utils.data.DataLoader( 200 | dataset_train_moco, 201 | batch_size=batch_size, 202 | shuffle=True, 203 | drop_last=True, 204 | num_workers=num_workers, 205 | ) 206 | 207 | dataloader_train_classifier = torch.utils.data.DataLoader( 208 | dataset_train_classifier, 209 | batch_size=batch_size, 210 | shuffle=True, 211 | drop_last=True, 212 | num_workers=num_workers, 213 | ) 214 | 215 | dataloader_test = torch.utils.data.DataLoader( 216 | dataset_test, 217 | batch_size=batch_size, 218 | shuffle=False, 219 | drop_last=False, 220 | num_workers=num_workers, 221 | ) 222 | 223 | class MocoModel(pl.LightningModule): 224 | def __init__(self): 225 | super().__init__() 226 | 227 | # create a ResNet backbone and remove the classification head 228 | resnet = ResNetGenerator("resnet-18", 1, num_splits=8) 229 | self.backbone = nn.Sequential( 230 | *list(resnet.children())[:-1], 231 | nn.AdaptiveAvgPool2d(1), 232 | ) 233 | 234 | # create a moco model based on ResNet 235 | self.projection_head = MoCoProjectionHead(512, 512, 128) 236 | self.backbone_momentum = copy.deepcopy(self.backbone) 237 | self.projection_head_momentum = copy.deepcopy(self.projection_head) 238 | deactivate_requires_grad(self.backbone_momentum) 239 | deactivate_requires_grad(self.projection_head_momentum) 240 | 241 | # create our loss with the optional memory bank 242 | self.criterion = NTXentLoss(temperature=0.1, memory_bank_size=memory_bank_size) 243 | 244 | def training_step(self, batch, batch_idx): 245 | (x_q, x_k), _, _ = batch 246 | 247 | # update momentum 248 | update_momentum(self.backbone, self.backbone_momentum, 0.99) 249 | update_momentum(self.projection_head, self.projection_head_momentum, 0.99) 250 | 251 | # get queries 252 | q = self.backbone(x_q).flatten(start_dim=1) 253 | q = self.projection_head(q) 254 | 255 | # get keys 256 | k, shuffle = batch_shuffle(x_k) 257 | k = self.backbone_momentum(k).flatten(start_dim=1) 258 | k = self.projection_head_momentum(k) 259 | k = batch_unshuffle(k, shuffle) 260 | 261 | loss = self.criterion(q, k) 262 | self.log("train_loss_ssl", loss) 263 | return loss 264 | 265 | def on_train_epoch_end(self): 266 | self.custom_histogram_weights() 267 | 268 | # We provide a helper method to log weights in tensorboard 269 | # which is useful for debugging. 
270 | def custom_histogram_weights(self): 271 | for name, params in self.named_parameters(): 272 | self.logger.experiment.add_histogram(name, params, self.current_epoch) 273 | 274 | def configure_optimizers(self): 275 | optim = torch.optim.SGD( 276 | self.parameters(), 277 | lr=6e-2, 278 | momentum=0.9, 279 | weight_decay=5e-4, 280 | ) 281 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, max_epochs) 282 | return [optim], [scheduler] 283 | 284 | class Classifier(pl.LightningModule): 285 | def __init__(self, backbone): 286 | super().__init__() 287 | # use the pretrained ResNet backbone 288 | self.backbone = backbone 289 | 290 | # freeze the backbone 291 | deactivate_requires_grad(backbone) 292 | 293 | # create a linear layer for our downstream classification model 294 | if args.dataset == 'cifar10': self.fc = nn.Linear(512, 10) 295 | elif args.dataset == 'cifar100': self.fc = nn.Linear(512, 100) 296 | else: self.fc = nn.Linear(512, 200) 297 | 298 | self.criterion = nn.CrossEntropyLoss() 299 | self.validation_step_outputs = [] 300 | 301 | def forward(self, x): 302 | y_hat = self.backbone(x).flatten(start_dim=1) 303 | y_hat = self.fc(y_hat) 304 | return y_hat 305 | 306 | def training_step(self, batch, batch_idx): 307 | x, y, _ = batch 308 | y_hat = self.forward(x) 309 | loss = self.criterion(y_hat, y) 310 | self.log("train_loss_fc", loss) 311 | return loss 312 | 313 | def on_train_epoch_end(self): 314 | self.custom_histogram_weights() 315 | 316 | # We provide a helper method to log weights in tensorboard 317 | # which is useful for debugging. 318 | def custom_histogram_weights(self): 319 | for name, params in self.named_parameters(): 320 | self.logger.experiment.add_histogram(name, params, self.current_epoch) 321 | 322 | def validation_step(self, batch, batch_idx): 323 | x, y, _ = batch 324 | y_hat = self.forward(x) 325 | y_hat = torch.nn.functional.softmax(y_hat, dim=1) 326 | 327 | # calculate number of correct predictions 328 | _, predicted = torch.max(y_hat, 1) 329 | num = predicted.shape[0] 330 | correct = (predicted == y).float().sum() 331 | self.validation_step_outputs.append((num, correct)) 332 | return num, correct 333 | 334 | def on_validation_epoch_end(self): 335 | # calculate and log top1 accuracy 336 | if self.validation_step_outputs: 337 | total_num = 0 338 | total_correct = 0 339 | for num, correct in self.validation_step_outputs: 340 | total_num += num 341 | total_correct += correct 342 | acc = total_correct / total_num 343 | self.log("val_acc", acc, on_epoch=True, prog_bar=True) 344 | self.validation_step_outputs.clear() 345 | 346 | def configure_optimizers(self): 347 | optim = torch.optim.SGD(self.fc.parameters(), lr=30.0) 348 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, 100) 349 | return [optim], [scheduler] 350 | 351 | 352 | model = MocoModel() 353 | trainer = pl.Trainer(max_epochs=max_epochs, devices=1, accelerator="gpu") 354 | trainer.fit(model, dataloader_train_moco) 355 | 356 | model.eval() 357 | classifier = Classifier(model.backbone) 358 | trainer = pl.Trainer(max_epochs=100, devices=1, accelerator="gpu") 359 | trainer.fit(classifier, dataloader_train_classifier, dataloader_test) 360 | -------------------------------------------------------------------------------- /model_checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright The PyTorch Lightning team. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Model Checkpointing 16 | =================== 17 | 18 | Automatically save model checkpoints during training. 19 | 20 | """ 21 | import logging 22 | import os 23 | import re 24 | import time 25 | import warnings 26 | from copy import deepcopy 27 | from datetime import timedelta 28 | from typing import Any, Dict, Optional, Set 29 | from weakref import proxy 30 | 31 | import numpy as np 32 | import torch 33 | import yaml 34 | from lightning_utilities.core.rank_zero import WarningCache 35 | from torch import Tensor 36 | 37 | import pytorch_lightning as pl 38 | from lightning_lite.utilities.cloud_io import get_filesystem 39 | from lightning_lite.utilities.types import _PATH 40 | from pytorch_lightning.callbacks import Checkpoint 41 | from pytorch_lightning.utilities.exceptions import MisconfigurationException 42 | from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn 43 | from pytorch_lightning.utilities.types import STEP_OUTPUT 44 | 45 | log = logging.getLogger(__name__) 46 | warning_cache = WarningCache() 47 | 48 | 49 | class ModelCheckpoint(Checkpoint): 50 | r""" 51 | Save the model periodically by monitoring a quantity. Every metric logged with 52 | :meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in 53 | LightningModule is a candidate for the monitor key. For more information, see 54 | :ref:`checkpointing`. 55 | 56 | After training finishes, use :attr:`best_model_path` to retrieve the path to the 57 | best checkpoint file and :attr:`best_model_score` to retrieve its score. 58 | 59 | Args: 60 | dirpath: directory to save the model file. 61 | 62 | Example:: 63 | 64 | # custom path 65 | # saves a file like: my/path/epoch=0-step=10.ckpt 66 | >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') 67 | 68 | By default, dirpath is ``None`` and will be set at runtime to the location 69 | specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s 70 | :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` argument, 71 | and if the Trainer uses a logger, the path will also contain logger name and version. 72 | 73 | filename: checkpoint filename. Can contain named formatting options to be auto-filled. 74 | 75 | Example:: 76 | 77 | # save any arbitrary metrics like `val_loss`, etc. in name 78 | # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt 79 | >>> checkpoint_callback = ModelCheckpoint( 80 | ... dirpath='my/path', 81 | ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}' 82 | ... ) 83 | 84 | By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``. 85 | monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch. 86 | verbose: verbosity mode. Default: ``False``. 87 | save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint 88 | file gets saved. 
This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``. 89 | save_top_k: if ``save_top_k == k``, 90 | the best k models according to the quantity monitored will be saved. 91 | if ``save_top_k == 0``, no models are saved. 92 | if ``save_top_k == -1``, all models are saved. 93 | Please note that the monitors are checked every ``every_n_epochs`` epochs. 94 | if ``save_top_k >= 2`` and the callback is called multiple 95 | times inside an epoch, the name of the saved file will be 96 | appended with a version count starting with ``v1``. 97 | mode: one of {min, max}. 98 | If ``save_top_k != 0``, the decision to overwrite the current save file is made 99 | based on either the maximization or the minimization of the monitored quantity. 100 | For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc. 101 | auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name. 102 | For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}`` with epoch ``1`` and acc ``1.12`` will resolve 103 | to ``checkpoint_epoch=01-acc=01.ckpt``. Is useful to set it to ``False`` when metric names contain ``/`` 104 | as this will result in extra folders. 105 | For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False`` 106 | save_weights_only: if ``True``, then only the model's weights will be 107 | saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too. 108 | every_n_train_steps: Number of training steps between checkpoints. 109 | If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training. 110 | To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative. 111 | This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``. 112 | train_time_interval: Checkpoints are monitored at the specified time interval. 113 | For all practical purposes, this cannot be smaller than the amount 114 | of time it takes to process a single training batch. This is not 115 | guaranteed to execute at the exact time specified, but should be close. 116 | This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``. 117 | every_n_epochs: Number of epochs between checkpoints. 118 | This value must be ``None`` or non-negative. 119 | To disable saving top-k checkpoints, set ``every_n_epochs = 0``. 120 | This argument does not impact the saving of ``save_last=True`` checkpoints. 121 | If all of ``every_n_epochs``, ``every_n_train_steps`` and 122 | ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch 123 | (equivalent to ``every_n_epochs = 1``). 124 | If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or ``train_time_interval != None``, 125 | saving at the end of each epoch is disabled 126 | (equivalent to ``every_n_epochs = 0``). 127 | This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``. 128 | Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and 129 | ``Trainer(max_epochs=N, check_val_every_n_epoch=M)`` 130 | will only save checkpoints at epochs 0 < E <= N 131 | where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. 132 | save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch. 133 | If this is ``False``, then the check runs at the end of the validation. 
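        In this repository's copy of the class, ``on_train_epoch_end`` additionally clears
        ``best_k_models`` whenever ``trainer.current_epoch`` is a multiple of 200 (see the
        method body below), so the top-k selection effectively restarts for each
        200-epoch window.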
134 | 135 | Note: 136 | For extra customization, ModelCheckpoint includes the following attributes: 137 | 138 | - ``CHECKPOINT_JOIN_CHAR = "-"`` 139 | - ``CHECKPOINT_NAME_LAST = "last"`` 140 | - ``FILE_EXTENSION = ".ckpt"`` 141 | - ``STARTING_VERSION = 1`` 142 | 143 | For example, you can change the default last checkpoint name by doing 144 | ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"`` 145 | 146 | If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, 147 | then you should create multiple ``ModelCheckpoint`` callbacks. 148 | 149 | If the checkpoint's ``dirpath`` changed from what it was before while resuming the training, 150 | only ``best_model_path`` will be reloaded and a warning will be issued. 151 | 152 | Raises: 153 | MisconfigurationException: 154 | If ``save_top_k`` is smaller than ``-1``, 155 | if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or 156 | if ``mode`` is none of ``"min"`` or ``"max"``. 157 | ValueError: 158 | If ``trainer.save_checkpoint`` is ``None``. 159 | 160 | Example:: 161 | 162 | >>> from pytorch_lightning import Trainer 163 | >>> from pytorch_lightning.callbacks import ModelCheckpoint 164 | 165 | # saves checkpoints to 'my/path/' at every epoch 166 | >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/') 167 | >>> trainer = Trainer(callbacks=[checkpoint_callback]) 168 | 169 | # save epoch and val_loss in name 170 | # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt 171 | >>> checkpoint_callback = ModelCheckpoint( 172 | ... monitor='val_loss', 173 | ... dirpath='my/path/', 174 | ... filename='sample-mnist-{epoch:02d}-{val_loss:.2f}' 175 | ... ) 176 | 177 | # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard 178 | # or Neptune, due to the presence of characters like '=' or '/') 179 | # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt 180 | >>> checkpoint_callback = ModelCheckpoint( 181 | ... monitor='val/loss', 182 | ... dirpath='my/path/', 183 | ... filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}', 184 | ... auto_insert_metric_name=False 185 | ... ) 186 | 187 | # retrieve the best checkpoint after training 188 | checkpoint_callback = ModelCheckpoint(dirpath='my/path/') 189 | trainer = Trainer(callbacks=[checkpoint_callback]) 190 | model = ... 191 | trainer.fit(model) 192 | checkpoint_callback.best_model_path 193 | 194 | .. 
tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the 195 | following arguments: 196 | 197 | *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end* 198 | 199 | Read more: :ref:`Persisting Callback State ` 200 | """ 201 | 202 | CHECKPOINT_JOIN_CHAR = "-" 203 | CHECKPOINT_NAME_LAST = "last" 204 | FILE_EXTENSION = ".ckpt" 205 | STARTING_VERSION = 1 206 | 207 | def __init__( 208 | self, 209 | dirpath: Optional[_PATH] = None, 210 | filename: Optional[str] = None, 211 | monitor: Optional[str] = None, 212 | verbose: bool = False, 213 | save_last: Optional[bool] = None, 214 | save_top_k: int = 1, 215 | save_weights_only: bool = False, 216 | mode: str = "min", 217 | auto_insert_metric_name: bool = True, 218 | every_n_train_steps: Optional[int] = None, 219 | train_time_interval: Optional[timedelta] = None, 220 | every_n_epochs: Optional[int] = None, 221 | save_on_train_epoch_end: Optional[bool] = None, 222 | ): 223 | super().__init__() 224 | self.monitor = monitor 225 | self.verbose = verbose 226 | self.save_last = save_last 227 | self.save_top_k = save_top_k 228 | self.save_weights_only = save_weights_only 229 | self.auto_insert_metric_name = auto_insert_metric_name 230 | self._save_on_train_epoch_end = save_on_train_epoch_end 231 | self._last_global_step_saved = 0 # no need to save when no steps were taken 232 | self._last_time_checked: Optional[float] = None 233 | self.current_score: Optional[Tensor] = None 234 | self.best_k_models: Dict[str, Tensor] = {} 235 | self.kth_best_model_path = "" 236 | self.best_model_score: Optional[Tensor] = None 237 | self.best_model_path = "" 238 | self.last_model_path = "" 239 | 240 | self.kth_value: Tensor 241 | self.dirpath: Optional[_PATH] 242 | self.__init_monitor_mode(mode) 243 | self.__init_ckpt_dir(dirpath, filename) 244 | self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval) 245 | self.__validate_init_configuration() 246 | 247 | @property 248 | def state_key(self) -> str: 249 | return self._generate_state_key( 250 | monitor=self.monitor, 251 | mode=self.mode, 252 | every_n_train_steps=self._every_n_train_steps, 253 | every_n_epochs=self._every_n_epochs, 254 | train_time_interval=self._train_time_interval, 255 | save_on_train_epoch_end=self._save_on_train_epoch_end, 256 | ) 257 | 258 | def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None: 259 | dirpath = self.__resolve_ckpt_dir(trainer) 260 | dirpath = trainer.strategy.broadcast(dirpath) 261 | self.dirpath = dirpath 262 | if trainer.is_global_zero and stage == "fit": 263 | self.__warn_if_dir_not_empty(self.dirpath) 264 | 265 | # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states, 266 | # because the attributes are part of the state_key which needs to be fully defined before reloading. 
267 |         if self._save_on_train_epoch_end is None:
268 |             # if the user runs validation multiple times per training epoch or multiple training epochs without
269 |             # validation, then we run after validation instead of on train epoch end
270 |             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
271 | 
272 |     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
273 |         self._last_time_checked = time.monotonic()
274 | 
275 |     def on_train_batch_end(
276 |         self,
277 |         trainer: "pl.Trainer",
278 |         pl_module: "pl.LightningModule",
279 |         outputs: STEP_OUTPUT,
280 |         batch: Any,
281 |         batch_idx: int,
282 |     ) -> None:
283 |         """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
284 |         if self._should_skip_saving_checkpoint(trainer):
285 |             return
286 |         skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
287 | 
288 |         train_time_interval = self._train_time_interval
289 |         skip_time = True
290 |         now = time.monotonic()
291 |         if train_time_interval:
292 |             prev_time_check = self._last_time_checked
293 |             skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
294 |             # in case we have time differences across ranks
295 |             # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
296 |             skip_time = trainer.strategy.broadcast(skip_time)
297 | 
298 |         if skip_batch and skip_time:
299 |             return
300 |         if not skip_time:
301 |             self._last_time_checked = now
302 | 
303 |         monitor_candidates = self._monitor_candidates(trainer)
304 |         self._save_topk_checkpoint(trainer, monitor_candidates)
305 |         self._save_last_checkpoint(trainer, monitor_candidates)
306 | 
307 |     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
308 |         """Save a checkpoint at the end of the training epoch, restarting the top-k tracking every 200 epochs."""
309 |         if trainer.current_epoch % 200 == 0:
310 |             self.best_k_models = {}
311 |         # proceed with the standard end-of-epoch checkpointing
312 |         if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
313 |             monitor_candidates = self._monitor_candidates(trainer)
314 |             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
315 |                 self._save_topk_checkpoint(trainer, monitor_candidates)
316 |             self._save_last_checkpoint(trainer, monitor_candidates)
317 | 
318 |     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
319 |         """Save a checkpoint at the end of the validation stage."""
320 |         if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
321 |             monitor_candidates = self._monitor_candidates(trainer)
322 |             if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
323 |                 self._save_topk_checkpoint(trainer, monitor_candidates)
324 |             self._save_last_checkpoint(trainer, monitor_candidates)
325 | 
326 |     def state_dict(self) -> Dict[str, Any]:
327 |         return {
328 |             "monitor": self.monitor,
329 |             "best_model_score": self.best_model_score,
330 |             "best_model_path": self.best_model_path,
331 |             "current_score": self.current_score,
332 |             "dirpath": self.dirpath,
333 |             "best_k_models": self.best_k_models,
334 |             "kth_best_model_path": self.kth_best_model_path,
335 |             "kth_value": self.kth_value,
336 |             "last_model_path": self.last_model_path,
337 |         }
338 | 
339 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
340 |         dirpath_from_ckpt =
state_dict.get("dirpath", self.dirpath) 341 | 342 | if self.dirpath == dirpath_from_ckpt: 343 | self.best_model_score = state_dict["best_model_score"] 344 | self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path) 345 | self.kth_value = state_dict.get("kth_value", self.kth_value) 346 | self.best_k_models = state_dict.get("best_k_models", self.best_k_models) 347 | self.last_model_path = state_dict.get("last_model_path", self.last_model_path) 348 | else: 349 | warnings.warn( 350 | f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r}," 351 | " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and" 352 | " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded." 353 | ) 354 | 355 | self.best_model_path = state_dict["best_model_path"] 356 | 357 | def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: 358 | if self.save_top_k == 0: 359 | return 360 | 361 | # validate metric 362 | if self.monitor is not None: 363 | if self.monitor not in monitor_candidates: 364 | m = ( 365 | f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned" 366 | f" metrics: {list(monitor_candidates)}." 367 | f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?" 368 | ) 369 | if trainer.fit_loop.epoch_loop.val_loop._has_run: 370 | raise MisconfigurationException(m) 371 | warning_cache.warn(m) 372 | self._save_monitor_checkpoint(trainer, monitor_candidates) 373 | else: 374 | self._save_none_monitor_checkpoint(trainer, monitor_candidates) 375 | 376 | def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 377 | trainer.save_checkpoint(filepath, self.save_weights_only) 378 | 379 | self._last_global_step_saved = trainer.global_step 380 | 381 | # notify loggers 382 | if trainer.is_global_zero: 383 | for logger in trainer.loggers: 384 | logger.after_save_checkpoint(proxy(self)) 385 | 386 | def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: 387 | from pytorch_lightning.trainer.states import TrainerFn 388 | 389 | return ( 390 | bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run 391 | or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit 392 | or trainer.sanity_checking # don't save anything during sanity check 393 | or self._last_global_step_saved == trainer.global_step # already saved at the last step 394 | ) 395 | 396 | def __validate_init_configuration(self) -> None: 397 | if self.save_top_k < -1: 398 | raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1") 399 | if self._every_n_train_steps < 0: 400 | raise MisconfigurationException( 401 | f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0" 402 | ) 403 | if self._every_n_epochs < 0: 404 | raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. 
Must be >= 0") 405 | 406 | every_n_train_steps_triggered = self._every_n_train_steps >= 1 407 | every_n_epochs_triggered = self._every_n_epochs >= 1 408 | train_time_interval_triggered = self._train_time_interval is not None 409 | if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1: 410 | raise MisconfigurationException( 411 | f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, " 412 | f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} " 413 | "should be mutually exclusive." 414 | ) 415 | 416 | if self.monitor is None: 417 | # -1: save all epochs, 0: nothing is saved, 1: save last epoch 418 | if self.save_top_k not in (-1, 0, 1): 419 | raise MisconfigurationException( 420 | f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid" 421 | " configuration. No quantity for top_k to track." 422 | ) 423 | 424 | if self.save_top_k == -1 and self.save_last: 425 | rank_zero_info( 426 | "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)" 427 | " will duplicate the last checkpoint saved." 428 | ) 429 | 430 | def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None: 431 | self._fs = get_filesystem(dirpath if dirpath else "") 432 | 433 | if dirpath and self._fs.protocol == "file": 434 | dirpath = os.path.realpath(dirpath) 435 | 436 | self.dirpath = dirpath 437 | self.filename = filename 438 | 439 | def __init_monitor_mode(self, mode: str) -> None: 440 | torch_inf = torch.tensor(np.Inf) 441 | mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")} 442 | 443 | if mode not in mode_dict: 444 | raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}") 445 | 446 | self.kth_value, self.mode = mode_dict[mode] 447 | 448 | def __init_triggers( 449 | self, 450 | every_n_train_steps: Optional[int], 451 | every_n_epochs: Optional[int], 452 | train_time_interval: Optional[timedelta], 453 | ) -> None: 454 | 455 | # Default to running once after each validation epoch if neither 456 | # every_n_train_steps nor every_n_epochs is set 457 | if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None: 458 | every_n_epochs = 1 459 | every_n_train_steps = 0 460 | log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1") 461 | else: 462 | every_n_epochs = every_n_epochs or 0 463 | every_n_train_steps = every_n_train_steps or 0 464 | 465 | self._train_time_interval: Optional[timedelta] = train_time_interval 466 | self._every_n_epochs: int = every_n_epochs 467 | self._every_n_train_steps: int = every_n_train_steps 468 | 469 | @property 470 | def every_n_epochs(self) -> Optional[int]: 471 | return self._every_n_epochs 472 | 473 | def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool: 474 | if current is None: 475 | return False 476 | 477 | if self.save_top_k == -1: 478 | return True 479 | 480 | less_than_k_models = len(self.best_k_models) < self.save_top_k 481 | if less_than_k_models: 482 | return True 483 | 484 | monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode] 485 | should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path]) 486 | 487 | # If using multiple devices, make sure all processes are unanimous on the decision. 
488 | should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save)) 489 | 490 | return should_update_best_and_save 491 | 492 | @classmethod 493 | def _format_checkpoint_name( 494 | cls, 495 | filename: Optional[str], 496 | metrics: Dict[str, Tensor], 497 | prefix: str = "", 498 | auto_insert_metric_name: bool = True, 499 | ) -> str: 500 | if not filename: 501 | # filename is not set, use default name 502 | filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}" 503 | 504 | # check and parse user passed keys in the string 505 | groups = re.findall(r"(\{.*?)[:\}]", filename) 506 | if len(groups) >= 0: 507 | for group in groups: 508 | name = group[1:] 509 | 510 | if auto_insert_metric_name: 511 | filename = filename.replace(group, name + "={" + name) 512 | 513 | # support for dots: https://stackoverflow.com/a/7934969 514 | filename = filename.replace(group, f"{{0[{name}]") 515 | 516 | if name not in metrics: 517 | metrics[name] = torch.tensor(0) 518 | filename = filename.format(metrics) 519 | 520 | if prefix: 521 | filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) 522 | 523 | return filename 524 | 525 | def format_checkpoint_name( 526 | self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None 527 | ) -> str: 528 | """Generate a filename according to the defined template. 529 | 530 | Example:: 531 | 532 | >>> tmpdir = os.path.dirname(__file__) 533 | >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}') 534 | >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0))) 535 | 'epoch=0.ckpt' 536 | >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}') 537 | >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5))) 538 | 'epoch=005.ckpt' 539 | >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}') 540 | >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456))) 541 | 'epoch=2-val_loss=0.12.ckpt' 542 | >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}')) 543 | 'epoch=2.ckpt' 544 | >>> ckpt = ModelCheckpoint(dirpath=tmpdir, 545 | ... filename='epoch={epoch}-validation_loss={val_loss:.2f}', 546 | ... auto_insert_metric_name=False) 547 | >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456))) 548 | 'epoch=2-validation_loss=0.12.ckpt' 549 | >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') 550 | >>> os.path.basename(ckpt.format_checkpoint_name({})) 551 | 'missing=0.ckpt' 552 | >>> ckpt = ModelCheckpoint(filename='{step}') 553 | >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0))) 554 | 'step=0.ckpt' 555 | """ 556 | filename = filename or self.filename 557 | filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name) 558 | 559 | if ver is not None: 560 | filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) 561 | 562 | ckpt_name = f"{filename}{self.FILE_EXTENSION}" 563 | return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name 564 | 565 | def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH: 566 | """Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to 567 | determine where to save checkpoints. The path for saving weights is set in this priority: 568 | 569 | 1. The ``ModelCheckpoint``'s ``dirpath`` if passed in 570 | 2. The ``Logger``'s ``log_dir`` if the trainer has loggers 571 | 3. 
The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers 572 | 573 | The path gets extended with subdirectory "checkpoints". 574 | """ 575 | if self.dirpath is not None: 576 | # short circuit if dirpath was passed to ModelCheckpoint 577 | return self.dirpath 578 | 579 | if len(trainer.loggers) > 0: 580 | if trainer.loggers[0].save_dir is not None: 581 | save_dir = trainer.loggers[0].save_dir 582 | else: 583 | save_dir = trainer.default_root_dir 584 | name = trainer.loggers[0].name 585 | version = trainer.loggers[0].version 586 | version = version if isinstance(version, str) else f"version_{version}" 587 | ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") 588 | else: 589 | # if no loggers, use default_root_dir 590 | ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") 591 | 592 | return ckpt_path 593 | 594 | def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]: 595 | # find all checkpoints in the folder 596 | ckpt_path = self.__resolve_ckpt_dir(trainer) 597 | if self._fs.exists(ckpt_path): 598 | return { 599 | os.path.normpath(p) 600 | for p in self._fs.ls(ckpt_path, detail=False) 601 | if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1] 602 | } 603 | return set() 604 | 605 | def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: 606 | if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0: 607 | rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") 608 | 609 | def _get_metric_interpolated_filepath_name( 610 | self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None 611 | ) -> str: 612 | filepath = self.format_checkpoint_name(monitor_candidates) 613 | 614 | version_cnt = self.STARTING_VERSION 615 | while self.file_exists(filepath, trainer) and filepath != del_filepath: 616 | filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt) 617 | version_cnt += 1 618 | 619 | return filepath 620 | 621 | def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]: 622 | monitor_candidates = deepcopy(trainer.callback_metrics) 623 | # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor 624 | # or does not exist we overwrite it as it's likely an error 625 | epoch = monitor_candidates.get("epoch") 626 | monitor_candidates["epoch"] = epoch.int() if isinstance(epoch, Tensor) else torch.tensor(trainer.current_epoch) 627 | step = monitor_candidates.get("step") 628 | monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step) 629 | return monitor_candidates 630 | 631 | def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: 632 | if not self.save_last: 633 | return 634 | 635 | filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST) 636 | 637 | version_cnt = self.STARTING_VERSION 638 | while self.file_exists(filepath, trainer) and filepath != self.last_model_path: 639 | filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt) 640 | version_cnt += 1 641 | 642 | # set the last model path before saving because it will be part of the state. 
643 | previous, self.last_model_path = self.last_model_path, filepath 644 | self._save_checkpoint(trainer, filepath) 645 | if previous and previous != filepath: 646 | self._remove_checkpoint(trainer, previous) 647 | 648 | def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: 649 | assert self.monitor 650 | current = monitor_candidates.get(self.monitor) 651 | if self.check_monitor_top_k(trainer, current): 652 | assert current is not None 653 | self._update_best_and_save(current, trainer, monitor_candidates) 654 | elif self.verbose: 655 | epoch = monitor_candidates["epoch"] 656 | step = monitor_candidates["step"] 657 | rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}") 658 | 659 | def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None: 660 | filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer) 661 | # set the best model path before saving because it will be part of the state. 662 | previous, self.best_model_path = self.best_model_path, filepath 663 | self._save_checkpoint(trainer, filepath) 664 | if self.save_top_k == 1 and previous and previous != filepath: 665 | self._remove_checkpoint(trainer, previous) 666 | 667 | def _update_best_and_save( 668 | self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor] 669 | ) -> None: 670 | k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k 671 | 672 | del_filepath = None 673 | if len(self.best_k_models) == k and k > 0: 674 | del_filepath = self.kth_best_model_path 675 | self.best_k_models.pop(del_filepath) 676 | 677 | # do not save nan, replace with +/- inf 678 | if isinstance(current, Tensor) and torch.isnan(current): 679 | current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device) 680 | 681 | filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath) 682 | 683 | # save the current score 684 | self.current_score = current 685 | self.best_k_models[filepath] = current 686 | 687 | if len(self.best_k_models) == k: 688 | # monitor dict has reached k elements 689 | _op = max if self.mode == "min" else min 690 | self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type] 691 | self.kth_value = self.best_k_models[self.kth_best_model_path] 692 | 693 | _op = min if self.mode == "min" else max 694 | self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type] 695 | self.best_model_score = self.best_k_models[self.best_model_path] 696 | 697 | if self.verbose: 698 | epoch = monitor_candidates["epoch"] 699 | step = monitor_candidates["step"] 700 | rank_zero_info( 701 | f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}" 702 | f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}" 703 | ) 704 | self._save_checkpoint(trainer, filepath) 705 | 706 | if del_filepath is not None and filepath != del_filepath: 707 | self._remove_checkpoint(trainer, del_filepath) 708 | 709 | def to_yaml(self, filepath: Optional[_PATH] = None) -> None: 710 | """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML 711 | file.""" 712 | best_k = {k: v.item() for k, v in self.best_k_models.items()} 713 | if filepath is None: 714 | assert self.dirpath 715 | filepath = os.path.join(self.dirpath, 
"best_k_models.yaml") 716 | with self._fs.open(filepath, "w") as fp: 717 | yaml.dump(best_k, fp) 718 | 719 | def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool: 720 | """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal 721 | state to diverge between ranks.""" 722 | exists = self._fs.exists(filepath) 723 | return trainer.strategy.broadcast(exists) 724 | 725 | def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: 726 | """Calls the strategy to remove the checkpoint file.""" 727 | trainer.strategy.remove_checkpoint(filepath) 728 | -------------------------------------------------------------------------------- /random_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | from argparse import ArgumentParser 4 | import torchvision 5 | import torch.nn.functional as F 6 | from torch.utils.data import Dataset 7 | from torchvision import datasets 8 | from torchvision import transforms 9 | from torchvision.transforms import ToTensor 10 | from torch.utils.data.dataloader import DataLoader 11 | from torch.utils.tensorboard import SummaryWriter 12 | import matplotlib.pyplot as plt 13 | import argparse 14 | from torch import nn, optim 15 | import json 16 | import math 17 | import os 18 | import random 19 | import signal 20 | import subprocess 21 | import sys 22 | import time 23 | import ray 24 | from ray import tune 25 | from ray.air import session 26 | from ray.air.checkpoint import Checkpoint 27 | from ray.tune.schedulers import ASHAScheduler 28 | import simclr_module 29 | 30 | 31 | def random_search(): 32 | parser = ArgumentParser() 33 | parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") 34 | parser.add_argument("--max_t", default=200, type=int, help="max epoch to report") 35 | parser.add_argument("--num_samples", default=100, type=int, help="number of samples") 36 | parser.add_argument("--search_gammas", default=[0.5, 1.0], type=float, nargs='+', help="number of samples") 37 | parser.add_argument("--search_mus", default=[1.0], type=float, nargs='+', help="projection mu") 38 | parser.add_argument("--loss_type", default="origin", type=str, help="search type, origin, sum or product") 39 | parser.add_argument("--search_acos_orders", default=[0], type=int, nargs='+', help="number of samples") 40 | # specify flags to store false 41 | parser.add_argument("--first_conv", action="store_false") 42 | parser.add_argument("--maxpool1", action="store_false") 43 | parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") 44 | parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") 45 | parser.add_argument("--norm_p", default=2., type=float, help="norm p, -1 for inf") 46 | parser.add_argument("--distance_p", default=2., type=float, help="distance p, -1 for inf") 47 | parser.add_argument("--acos_order", default=0, type=int, help="order of acos, 0 for not use acos kernel") 48 | parser.add_argument("--gamma", default=2., type=float, help="gamma") 49 | parser.add_argument("--online_ft", action="store_true") 50 | parser.add_argument("--fp32", action="store_true") 51 | 52 | # transform params 53 | parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") 54 | parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") 55 | parser.add_argument("--dataset", type=str, 
default="cifar10", help="stl10, cifar10") 56 | parser.add_argument("--data_dir", type=str, default="/home/yjq/graph", help="path to download data") 57 | 58 | # training params 59 | parser.add_argument("--fast_dev_run", default=1, type=int) 60 | parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") 61 | parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") 62 | parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") 63 | parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") 64 | parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") 65 | parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") 66 | parser.add_argument("--max_steps", default=-1, type=int, help="max steps") 67 | parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") 68 | parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") 69 | 70 | parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") 71 | parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") 72 | parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") 73 | parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") 74 | parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") 75 | 76 | args = parser.parse_args() 77 | 78 | max_t = args.max_t 79 | num_samples = args.num_samples 80 | 81 | search_params = { 82 | "learning_rate": tune.loguniform(1e-2, 10), 83 | "temperature": tune.loguniform(1e-2, 1), 84 | "gamma": tune.choice(args.search_gammas), 85 | "projection_mu": tune.choice(args.search_mus), 86 | "gamma_lambd": tune.uniform(0, 1), 87 | "acos_order": tune.choice(args.search_acos_orders) 88 | } 89 | scheduler = ASHAScheduler( 90 | max_t=max_t, 91 | grace_period=20, 92 | reduction_factor=2) 93 | 94 | tuner = tune.Tuner( 95 | tune.with_resources( 96 | tune.with_parameters(simclr_module.cli_main, args=args, isTune=True), 97 | resources={"cpu": 2, "gpu": 1} 98 | ), 99 | tune_config=tune.TuneConfig( 100 | metric="online_val_acc", 101 | mode="max", 102 | scheduler=scheduler, 103 | num_samples=num_samples, 104 | ), 105 | run_config=ray.air.RunConfig( 106 | local_dir="~/ray_results" 107 | ), 108 | param_space=search_params, 109 | ) 110 | 111 | results = tuner.fit() 112 | best_result = results.get_best_result("online_val_acc", "max") 113 | print("Best trial config: {}".format(best_result.config)) 114 | print("Best trial final validation accuracy: {}".format( 115 | best_result.metrics["online_val_acc"])) 116 | 117 | 118 | if __name__ == "__main__": 119 | random_search() -------------------------------------------------------------------------------- /simclr.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pl_bolts.datamodules import CIFAR10DataModule 3 | from pl_bolts.models.self_supervised.simclr.transforms import ( 4 | SimCLREvalDataTransform, SimCLRTrainDataTransform) 5 | import math 6 | from argparse import ArgumentParser 7 | 8 | import torch 9 | from pytorch_lightning import LightningModule, Trainer 10 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 11 | from torch import Tensor, nn 12 | 
from torch.nn import functional as F 13 | 14 | from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 15 | from pl_bolts.optimizers.lars import LARS 16 | from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay 17 | from pl_bolts.transforms.dataset_normalizations import ( 18 | cifar10_normalization, 19 | imagenet_normalization, 20 | stl10_normalization, 21 | ) 22 | from pl_bolts.utils.stability import under_review 23 | 24 | @under_review() 25 | class SyncFunction(torch.autograd.Function): 26 | @staticmethod 27 | def forward(ctx, tensor): 28 | ctx.batch_size = tensor.shape[0] 29 | 30 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] 31 | 32 | torch.distributed.all_gather(gathered_tensor, tensor) 33 | gathered_tensor = torch.cat(gathered_tensor, 0) 34 | 35 | return gathered_tensor 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | grad_input = grad_output.clone() 40 | torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) 41 | 42 | idx_from = torch.distributed.get_rank() * ctx.batch_size 43 | idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size 44 | return grad_input[idx_from:idx_to] 45 | 46 | 47 | @under_review() 48 | class Projection(nn.Module): 49 | def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128): 50 | super().__init__() 51 | self.output_dim = output_dim 52 | self.input_dim = input_dim 53 | self.hidden_dim = hidden_dim 54 | 55 | self.model = nn.Sequential( 56 | nn.Linear(self.input_dim, self.hidden_dim), 57 | nn.BatchNorm1d(self.hidden_dim), 58 | nn.ReLU(), 59 | nn.Linear(self.hidden_dim, self.output_dim, bias=False), 60 | ) 61 | 62 | def forward(self, x): 63 | x = self.model(x) 64 | return F.normalize(x, dim=1) 65 | 66 | class SimCLR(LightningModule): 67 | def __init__( 68 | self, 69 | gpus: int, 70 | num_samples: int, 71 | batch_size: int, 72 | dataset: str, 73 | num_nodes: int = 1, 74 | arch: str = "resnet50", 75 | hidden_mlp: int = 2048, 76 | feat_dim: int = 128, 77 | warmup_epochs: int = 10, 78 | max_epochs: int = 100, 79 | temperature: float = 0.1, 80 | first_conv: bool = True, 81 | maxpool1: bool = True, 82 | optimizer: str = "adam", 83 | exclude_bn_bias: bool = False, 84 | start_lr: float = 0.0, 85 | learning_rate: float = 1e-3, 86 | final_lr: float = 0.0, 87 | weight_decay: float = 1e-6, 88 | **kwargs 89 | ): 90 | """ 91 | Args: 92 | batch_size: the batch size 93 | num_samples: num samples in the dataset 94 | warmup_epochs: epochs to warmup the lr for 95 | lr: the optimizer learning rate 96 | opt_weight_decay: the optimizer weight decay 97 | loss_temperature: the loss temperature 98 | """ 99 | super().__init__() 100 | self.save_hyperparameters() 101 | 102 | self.gpus = gpus 103 | self.num_nodes = num_nodes 104 | self.arch = arch 105 | self.dataset = dataset 106 | self.num_samples = num_samples 107 | self.batch_size = batch_size 108 | 109 | self.hidden_mlp = hidden_mlp 110 | self.feat_dim = feat_dim 111 | self.first_conv = first_conv 112 | self.maxpool1 = maxpool1 113 | 114 | self.optim = optimizer 115 | self.exclude_bn_bias = exclude_bn_bias 116 | self.weight_decay = weight_decay 117 | self.temperature = temperature 118 | 119 | self.start_lr = start_lr 120 | self.final_lr = final_lr 121 | self.learning_rate = learning_rate 122 | self.warmup_epochs = warmup_epochs 123 | self.max_epochs = max_epochs 124 | 125 | self.encoder = self.init_model() 126 | 127 | self.projection = Projection(input_dim=self.hidden_mlp, 
hidden_dim=self.hidden_mlp, output_dim=self.feat_dim) 128 | 129 | # compute iters per epoch 130 | global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size 131 | self.train_iters_per_epoch = self.num_samples // global_batch_size 132 | 133 | def init_model(self): 134 | if self.arch == "resnet18": 135 | backbone = resnet18 136 | elif self.arch == "resnet50": 137 | backbone = resnet50 138 | 139 | return backbone(first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False) 140 | 141 | def forward(self, x): 142 | # bolts resnet returns a list 143 | return self.encoder(x)[-1] 144 | 145 | def shared_step(self, batch): 146 | if self.dataset == "stl10": 147 | unlabeled_batch = batch[0] 148 | batch = unlabeled_batch 149 | 150 | # final image in tuple is for online eval 151 | (img1, img2, _), y = batch 152 | 153 | # get h representations, bolts resnet returns a list 154 | h1 = self(img1) 155 | h2 = self(img2) 156 | 157 | # get z representations 158 | z1 = self.projection(h1) 159 | z2 = self.projection(h2) 160 | 161 | loss = self.nt_xent_loss(z1, z2, self.temperature) 162 | 163 | return loss 164 | 165 | def training_step(self, batch, batch_idx): 166 | loss = self.shared_step(batch) 167 | 168 | self.log("train_loss", loss, on_step=True, on_epoch=False) 169 | return loss 170 | 171 | def validation_step(self, batch, batch_idx): 172 | loss = self.shared_step(batch) 173 | 174 | self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 175 | return loss 176 | 177 | def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")): 178 | params = [] 179 | excluded_params = [] 180 | 181 | for name, param in named_params: 182 | if not param.requires_grad: 183 | continue 184 | elif any(layer_name in name for layer_name in skip_list): 185 | excluded_params.append(param) 186 | else: 187 | params.append(param) 188 | 189 | return [ 190 | {"params": params, "weight_decay": weight_decay}, 191 | { 192 | "params": excluded_params, 193 | "weight_decay": 0.0, 194 | }, 195 | ] 196 | 197 | def configure_optimizers(self): 198 | if self.exclude_bn_bias: 199 | params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) 200 | else: 201 | params = self.parameters() 202 | 203 | if self.optim == "lars": 204 | optimizer = LARS( 205 | params, 206 | lr=self.learning_rate, 207 | momentum=0.9, 208 | weight_decay=self.weight_decay, 209 | trust_coefficient=0.001, 210 | ) 211 | elif self.optim == "adam": 212 | optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) 213 | 214 | warmup_steps = self.train_iters_per_epoch * self.warmup_epochs 215 | total_steps = self.train_iters_per_epoch * self.max_epochs 216 | 217 | scheduler = { 218 | "scheduler": torch.optim.lr_scheduler.LambdaLR( 219 | optimizer, 220 | linear_warmup_decay(warmup_steps, total_steps, cosine=True), 221 | ), 222 | "interval": "step", 223 | "frequency": 1, 224 | } 225 | 226 | return [optimizer], [scheduler] 227 | 228 | def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6): 229 | """ 230 | assume out_1 and out_2 are normalized 231 | out_1: [batch_size, dim] 232 | out_2: [batch_size, dim] 233 | """ 234 | # gather representations in case of distributed training 235 | # out_1_dist: [batch_size * world_size, dim] 236 | # out_2_dist: [batch_size * world_size, dim] 237 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 238 | out_1_dist = SyncFunction.apply(out_1) 239 | out_2_dist = 
SyncFunction.apply(out_2)
240 |         else:
241 |             out_1_dist = out_1
242 |             out_2_dist = out_2
243 | 
244 |         # out: [2 * batch_size, dim]
245 |         # out_dist: [2 * batch_size * world_size, dim]
246 |         out = torch.cat([out_1, out_2], dim=0)
247 |         out_dist = torch.cat([out_1_dist, out_2_dist], dim=0)
248 | 
249 |         # cov and sim: [2 * batch_size, 2 * batch_size * world_size]
250 |         # neg: [2 * batch_size]
251 |         cov = torch.mm(out, out_dist.t().contiguous())
252 |         sim = torch.exp(cov / temperature)
253 |         neg = sim.sum(dim=-1)
254 | 
255 |         # from each row, subtract e^(1/temp) to remove similarity measure for x1.x1
256 |         row_sub = Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device)
257 |         neg = torch.clamp(neg - row_sub, min=eps)  # clamp for numerical stability
258 | 
259 |         # Positive similarity, pos becomes [2 * batch_size]
260 |         pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
261 |         pos = torch.cat([pos, pos], dim=0)
262 | 
263 |         loss = -torch.log(pos / (neg + eps)).mean()
264 | 
265 |         return loss
266 | 
267 | 
268 | if __name__ == "__main__":  # guard so this demo does not run at import time
269 |     # data
270 |     dm = CIFAR10DataModule(num_workers=0)
271 |     dm.train_transforms = SimCLRTrainDataTransform(256)
272 |     dm.val_transforms = SimCLREvalDataTransform(256)
273 | 
274 |     # model
275 |     model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10', gpus=8, optimizer='lars',
276 |                    learning_rate=1.5, temperature=0.5)
277 | 
278 |     # fit
279 |     trainer = pl.Trainer()
280 |     trainer.fit(model, datamodule=dm)
--------------------------------------------------------------------------------
/simclr_finetune.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 | 
4 | from pytorch_lightning import Trainer, seed_everything
5 | 
6 | from simclr_module import SimCLR
7 | from transforms import SimCLRFinetuneTransform
8 | from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
9 | from cifar100_datamodule import CIFAR100DataModule
10 | from pl_bolts.transforms.dataset_normalizations import (
11 |     cifar10_normalization,
12 |     imagenet_normalization,
13 |     stl10_normalization,
14 | )
15 | from pl_bolts.utils.stability import under_review
16 | 
17 | 
18 | @under_review()
19 | def cli_main():  # pragma: no cover
20 |     from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
21 | 
22 |     seed_everything(1234)
23 | 
24 |     parser = ArgumentParser()
25 |     parser.add_argument("--dataset", type=str, help="cifar10, cifar100, stl10, imagenet", default="cifar10")
26 |     parser.add_argument("--ckpt_path", type=str, help="path to ckpt")
27 |     parser.add_argument("--data_dir", type=str, help="path to dataset", default=os.getcwd())
28 | 
29 |     parser.add_argument("--batch_size", default=64, type=int, help="batch size per gpu")
30 |     parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU")
31 |     parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension (256 for the product loss, 128 otherwise)")
32 |     parser.add_argument("--gpus", default=4, type=int, help="number of GPUs")
33 |     parser.add_argument("--num_epochs", default=100, type=int, help="number of epochs")
34 | 
35 |     # fine-tuner params
36 |     parser.add_argument("--in_features", type=int, default=2048)
37 |     parser.add_argument("--dropout", type=float, default=0.0)
38 |     parser.add_argument("--learning_rate", type=float, default=0.3)
39 |     parser.add_argument("--weight_decay", type=float, default=1e-6)
40 |     parser.add_argument("--nesterov", action="store_true")
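    # NOTE: argparse's `type=bool` calls bool() on the raw argument string, and
    # bool("False") is True, so any non-empty value would have enabled Nesterov.
    # A store_true flag is False by default and set only by passing --nesterov.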
41 | parser.add_argument("--scheduler_type", type=str, default="cosine") 42 | parser.add_argument("--gamma", type=float, default=0.1) 43 | parser.add_argument("--final_lr", type=float, default=0.0) 44 | 45 | args = parser.parse_args() 46 | 47 | if args.dataset == "cifar10" or args.dataset == "cifar100": 48 | dm = CIFAR10DataModule( 49 | data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers 50 | ) if args.dataset == "cifar10" else CIFAR100DataModule( 51 | data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers 52 | ) 53 | 54 | dm.train_transforms = SimCLRFinetuneTransform( 55 | normalize=cifar10_normalization(), input_height=dm.dims[-1], eval_transform=False 56 | ) 57 | dm.val_transforms = SimCLRFinetuneTransform( 58 | normalize=cifar10_normalization(), input_height=dm.dims[-1], eval_transform=True 59 | ) 60 | dm.test_transforms = SimCLRFinetuneTransform( 61 | normalize=cifar10_normalization(), input_height=dm.dims[-1], eval_transform=True 62 | ) 63 | 64 | args.maxpool1 = False 65 | args.first_conv = False 66 | args.num_samples = 1 67 | elif args.dataset == "stl10": 68 | dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) 69 | 70 | dm.train_dataloader = dm.train_dataloader_labeled 71 | dm.val_dataloader = dm.val_dataloader_labeled 72 | args.num_samples = 1 73 | 74 | dm.train_transforms = SimCLRFinetuneTransform( 75 | normalize=stl10_normalization(), input_height=dm.dims[-1], eval_transform=False 76 | ) 77 | dm.val_transforms = SimCLRFinetuneTransform( 78 | normalize=stl10_normalization(), input_height=dm.dims[-1], eval_transform=True 79 | ) 80 | dm.test_transforms = SimCLRFinetuneTransform( 81 | normalize=stl10_normalization(), input_height=dm.dims[-1], eval_transform=True 82 | ) 83 | 84 | args.maxpool1 = False 85 | args.first_conv = True 86 | elif args.dataset == "imagenet": 87 | dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) 88 | 89 | dm.train_transforms = SimCLRFinetuneTransform( 90 | normalize=imagenet_normalization(), input_height=dm.dims[-1], eval_transform=False 91 | ) 92 | dm.val_transforms = SimCLRFinetuneTransform( 93 | normalize=imagenet_normalization(), input_height=dm.dims[-1], eval_transform=True 94 | ) 95 | dm.test_transforms = SimCLRFinetuneTransform( 96 | normalize=imagenet_normalization(), input_height=dm.dims[-1], eval_transform=True 97 | ) 98 | 99 | args.num_samples = 1 100 | args.maxpool1 = True 101 | args.first_conv = True 102 | else: 103 | raise NotImplementedError("other datasets have not been implemented till now") 104 | 105 | backbone = SimCLR( 106 | gpus=args.gpus, 107 | nodes=1, 108 | num_samples=args.num_samples, 109 | batch_size=args.batch_size, 110 | maxpool1=args.maxpool1, 111 | first_conv=args.first_conv, 112 | dataset=args.dataset, 113 | feat_dim=args.feat_dim 114 | ).load_from_checkpoint(args.ckpt_path, strict=False) 115 | 116 | tuner = SSLFineTuner( 117 | backbone, 118 | in_features=args.in_features, 119 | num_classes=dm.num_classes, 120 | epochs=args.num_epochs, 121 | hidden_dim=None, 122 | dropout=args.dropout, 123 | learning_rate=args.learning_rate, 124 | weight_decay=args.weight_decay, 125 | nesterov=args.nesterov, 126 | scheduler_type=args.scheduler_type, 127 | gamma=args.gamma, 128 | final_lr=args.final_lr, 129 | ) 130 | 131 | trainer = Trainer( 132 | gpus=args.gpus, 133 | num_nodes=1, 134 | precision=16, 135 | max_epochs=args.num_epochs, 136 | accelerator="gpu", 137 | sync_batchnorm=True 
if args.gpus > 1 else False, 138 | ) 139 | 140 | trainer.fit(tuner, dm) 141 | trainer.test(datamodule=dm) 142 | 143 | 144 | if __name__ == "__main__": 145 | cli_main() -------------------------------------------------------------------------------- /simclr_module.py: -------------------------------------------------------------------------------- 1 | import math 2 | from argparse import ArgumentParser 3 | 4 | import torch 5 | import numpy 6 | from pytorch_lightning import LightningModule, Trainer 7 | from pytorch_lightning.callbacks import LearningRateMonitor 8 | from model_checkpoint import ModelCheckpoint 9 | from torch import Tensor, nn 10 | from torch.nn import functional as F 11 | 12 | from cifar100_datamodule import CIFAR100DataModule 13 | from tiny_imagenet_datamodule import TinyImagenetDataModule 14 | from pl_bolts.models.self_supervised.resnets import resnet18, resnet50 15 | from pl_bolts.optimizers.lars import LARS 16 | from pl_bolts.optimizers.lr_scheduler import linear_warmup_decay 17 | from pl_bolts.transforms.dataset_normalizations import ( 18 | cifar10_normalization, 19 | imagenet_normalization, 20 | stl10_normalization, 21 | ) 22 | from pl_bolts.utils.stability import under_review 23 | 24 | 25 | @under_review() 26 | class SyncFunction(torch.autograd.Function): 27 | @staticmethod 28 | def forward(ctx, tensor): 29 | ctx.batch_size = tensor.shape[0] 30 | 31 | gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] 32 | 33 | torch.distributed.all_gather(gathered_tensor, tensor) 34 | gathered_tensor = torch.cat(gathered_tensor, 0) 35 | 36 | return gathered_tensor 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | grad_input = grad_output.clone() 41 | torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False) 42 | 43 | idx_from = torch.distributed.get_rank() * ctx.batch_size 44 | idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size 45 | return grad_input[idx_from:idx_to] 46 | 47 | 48 | @under_review() 49 | class Projection(nn.Module): 50 | def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128, norm_p=2., mu=1.): 51 | super().__init__() 52 | self.output_dim = output_dim 53 | self.input_dim = input_dim 54 | self.hidden_dim = hidden_dim 55 | self.norm_p = norm_p 56 | self.mu = mu 57 | 58 | print(input_dim, output_dim, hidden_dim) 59 | self.model = nn.Sequential( 60 | nn.Linear(self.input_dim, self.hidden_dim), 61 | nn.BatchNorm1d(self.hidden_dim), 62 | nn.ReLU(), 63 | nn.Linear(self.hidden_dim, self.output_dim, bias=False), 64 | ) 65 | 66 | def forward(self, x): 67 | # print(x.size(), self.hidden_dim, self.input_dim) 68 | x = self.model(x) 69 | return F.normalize(x, dim=1, p=self.norm_p) * numpy.sqrt(self.mu) 70 | 71 | 72 | @under_review() 73 | class SimCLR(LightningModule): 74 | def __init__( 75 | self, 76 | gpus: int, 77 | num_samples: int, 78 | batch_size: int, 79 | dataset: str, 80 | num_nodes: int = 1, 81 | arch: str = "resnet50", 82 | hidden_mlp: int = 2048, 83 | feat_dim: int = 128, 84 | warmup_epochs: int = 10, 85 | max_epochs: int = 100, 86 | temperature: float = 0.1, 87 | first_conv: bool = True, 88 | maxpool1: bool = True, 89 | optimizer: str = "adam", 90 | exclude_bn_bias: bool = False, 91 | start_lr: float = 0.0, 92 | learning_rate: float = 1e-3, 93 | final_lr: float = 0.0, 94 | weight_decay: float = 1e-6, 95 | norm_p: float = 2.0, 96 | distance_p: float = 2.0, 97 | gamma: float = 2.0, 98 | acos_order: int = 0, 99 | gamma_lambd: float=1.0, 100 | loss_type: str = 
"origin", 101 | projection_mu: float=1.0, 102 | **kwargs 103 | ): 104 | """ 105 | Args: 106 | batch_size: the batch size 107 | num_samples: num samples in the dataset 108 | warmup_epochs: epochs to warmup the lr for 109 | lr: the optimizer learning rate 110 | opt_weight_decay: the optimizer weight decay 111 | loss_temperature: the loss temperature 112 | """ 113 | super().__init__() 114 | self.save_hyperparameters() 115 | 116 | self.gpus = gpus 117 | self.num_nodes = num_nodes 118 | self.arch = arch 119 | self.dataset = dataset 120 | self.num_samples = num_samples 121 | self.batch_size = batch_size 122 | 123 | self.loss_type = loss_type 124 | self.hidden_mlp = hidden_mlp 125 | self.feat_dim = feat_dim if self.loss_type != "product" else feat_dim * 2 126 | self.first_conv = first_conv 127 | self.maxpool1 = maxpool1 128 | self.norm_p = norm_p 129 | self.distance_p = distance_p 130 | self.gamma = gamma 131 | self.projection_mu = projection_mu 132 | self.acos_order = acos_order 133 | self.max_epochs = max_epochs 134 | 135 | 136 | self.optim = optimizer 137 | self.exclude_bn_bias = exclude_bn_bias 138 | self.weight_decay = weight_decay 139 | self.temperature = temperature 140 | 141 | self.start_lr = start_lr 142 | self.final_lr = final_lr 143 | self.learning_rate = learning_rate 144 | self.warmup_epochs = warmup_epochs 145 | self.gamma_lambd = gamma_lambd 146 | 147 | print(self.distance_p, self.norm_p, self.feat_dim) 148 | self.encoder = self.init_model() 149 | 150 | self.projection = Projection(input_dim=512 if self.arch == "resnet18" else 2048, hidden_dim=self.hidden_mlp, output_dim=self.feat_dim, norm_p=self.norm_p, mu=self.projection_mu) 151 | 152 | # compute iters per epoch 153 | global_batch_size = self.num_nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size 154 | self.train_iters_per_epoch = self.num_samples // global_batch_size 155 | 156 | def init_model(self): 157 | if self.arch == "resnet18": 158 | backbone = resnet18 159 | elif self.arch == "resnet50": 160 | backbone = resnet50 161 | 162 | return backbone(first_conv=self.first_conv, maxpool1=self.maxpool1, return_all_feature_maps=False) 163 | 164 | def forward(self, x): 165 | # bolts resnet returns a list 166 | return self.encoder(x)[-1] 167 | 168 | def shared_step(self, batch): 169 | if self.dataset == "stl10": 170 | unlabeled_batch = batch[0] 171 | batch = unlabeled_batch 172 | 173 | # final image in tuple is for online eval 174 | (img1, img2, _), y = batch 175 | 176 | # get h representations, bolts resnet returns a list 177 | h1 = self(img1) 178 | h2 = self(img2) 179 | 180 | # get z representations 181 | z1 = self.projection(h1) 182 | z2 = self.projection(h2) 183 | 184 | loss = self.nt_xent_loss(z1, z2, self.temperature) 185 | 186 | return loss 187 | 188 | def training_step(self, batch, batch_idx): 189 | loss = self.shared_step(batch) 190 | 191 | self.log("train_loss", loss, on_step=True, on_epoch=False) 192 | return loss 193 | 194 | def validation_step(self, batch, batch_idx): 195 | loss = self.shared_step(batch) 196 | 197 | self.log("val_loss", loss, on_step=False, on_epoch=True, sync_dist=True) 198 | return loss 199 | 200 | def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=("bias", "bn")): 201 | params = [] 202 | excluded_params = [] 203 | 204 | for name, param in named_params: 205 | if not param.requires_grad: 206 | continue 207 | elif any(layer_name in name for layer_name in skip_list): 208 | excluded_params.append(param) 209 | else: 210 | params.append(param) 211 | 212 | return [ 213 
| {"params": params, "weight_decay": weight_decay}, 214 | { 215 | "params": excluded_params, 216 | "weight_decay": 0.0, 217 | }, 218 | ] 219 | 220 | def configure_optimizers(self): 221 | if self.exclude_bn_bias: 222 | params = self.exclude_from_wt_decay(self.named_parameters(), weight_decay=self.weight_decay) 223 | else: 224 | params = self.parameters() 225 | 226 | if self.optim == "lars": 227 | optimizer = LARS( 228 | params, 229 | lr=self.learning_rate, 230 | momentum=0.9, 231 | weight_decay=self.weight_decay, 232 | trust_coefficient=0.001, 233 | ) 234 | elif self.optim == "adam": 235 | optimizer = torch.optim.Adam(params, lr=self.learning_rate, weight_decay=self.weight_decay) 236 | 237 | warmup_steps = self.train_iters_per_epoch * self.warmup_epochs 238 | total_steps = self.train_iters_per_epoch * self.max_epochs 239 | 240 | scheduler = { 241 | "scheduler": torch.optim.lr_scheduler.LambdaLR( 242 | optimizer, 243 | linear_warmup_decay(warmup_steps, total_steps, cosine=True), 244 | ), 245 | "interval": "step", 246 | "frequency": 1, 247 | } 248 | 249 | return [optimizer], [scheduler] 250 | 251 | def acos_kernel_distance(self, angle): 252 | if self.acos_order == 1: 253 | dis = numpy.pi - angle 254 | elif self.acos_order == 2: 255 | dis = torch.sin(angle) + (numpy.pi - angle) * torch.cos(angle) 256 | elif self.acos_order == 3: 257 | dis = torch.sin(angle) * torch.cos(angle) * 3. + (numpy.pi - angle) * ( 258 | 1 + torch.cos(angle) * torch.cos(angle) * 2.) 259 | else: 260 | raise NotImplementedError 261 | return dis 262 | 263 | 264 | def gamma_loss(self, out_1, out_2, gamma, temperature, eps=1e-6): 265 | 266 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 267 | out_1_dist = SyncFunction.apply(out_1) 268 | out_2_dist = SyncFunction.apply(out_2) 269 | else: 270 | out_1_dist = out_1 271 | out_2_dist = out_2 272 | # out: [2 * batch_size, dim] 273 | # out_dist: [2 * batch_size * world_size, dim] 274 | out = torch.cat([out_1, out_2], dim=0) 275 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) 276 | cov = torch.pow(torch.cdist(out, out_dist, p=self.distance_p), gamma) * -1. 277 | # if self.norm_p == 2.0 and self.distance_p == 2.0: 278 | # cov = 1 - (cov * 0.5) 279 | # # cov2 = torch.mm(out, out_dist.t().contiguous()) 280 | # # cov3 = cov - cov2 281 | # # print(cov3) 282 | sim = torch.exp(cov / temperature) 283 | neg = torch.clamp(sim.sum(dim=-1) - sim.diag(), min=eps) 284 | sim_adj = torch.pow(torch.norm(out_1 - out_2, dim=-1, p=self.distance_p), gamma) * -1. 285 | # if self.norm_p == 2.0 and self.distance_p == 2.0: 286 | # sim_adj = 1 - (sim_adj * 0.5) 287 | pos = torch.exp(sim_adj / temperature) 288 | pos = torch.cat([pos, pos], dim=0) 289 | loss = -torch.log(pos / (neg + eps)).mean() 290 | return loss 291 | 292 | def spectral_loss(self, out_1, out_2, eps=1e-6): 293 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 294 | out_1_dist = SyncFunction.apply(out_1) 295 | out_2_dist = SyncFunction.apply(out_2) 296 | else: 297 | out_1_dist = out_1 298 | out_2_dist = out_2 299 | # out: [2 * batch_size, dim] 300 | # out_dist: [2 * batch_size * world_size, dim] 301 | out = torch.cat([out_1, out_2], dim=0) 302 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) 303 | cov = torch.pow(torch.mm(out, out_dist.t().contiguous()), 2) 304 | pos = torch.sum(torch.clamp(cov.sum(dim=-1) - cov.diag(), min=eps) * (1. / (out_1.shape[0] * (out_1.shape[0] - 1)))) 305 | neg = torch.sum(out_1 * out_2) * (2. 
/ (out_1.shape[0])) 306 | return pos - neg 307 | 308 | def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6): 309 | """ 310 | assume out_1 and out_2 are normalized 311 | out_1: [batch_size, dim] 312 | out_2: [batch_size, dim] 313 | """ 314 | 315 | if torch.distributed.is_available() and torch.distributed.is_initialized(): 316 | out_1_dist = SyncFunction.apply(out_1) 317 | out_2_dist = SyncFunction.apply(out_2) 318 | else: 319 | out_1_dist = out_1 320 | out_2_dist = out_2 321 | out = torch.cat([out_1, out_2], dim=0) 322 | out_dist = torch.cat([out_1_dist, out_2_dist], dim=0) 323 | # gather representations in case of distributed training 324 | # out_1_dist: [batch_size * world_size, dim] 325 | # out_2_dist: [batch_size * world_size, dim] 326 | 327 | # cov and sim: [2 * batch_size, 2 * batch_size * world_size] 328 | # neg: [2 * batch_size] 329 | # if self.distance_p == 2.0: 330 | # cov = torch.mm(out, out_dist.t().contiguous()) 331 | # sim = torch.exp(cov / temperature) 332 | # neg = sim.sum(dim=-1) 333 | # 334 | # # from each row, subtract e^(1/temp) to remove similarity measure for x1.x1 335 | # row_sub = Tensor(neg.shape).fill_(math.e ** (1 / temperature)).to(neg.device) 336 | # neg = torch.clamp(neg - row_sub, min=eps) # clamp for numerical stability 337 | # 338 | # # Positive similarity, pos becomes [2 * batch_size] 339 | # pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature) 340 | # 341 | # else: 342 | if self.acos_order == 0: 343 | if self.loss_type == "sum": 344 | loss = self.gamma_loss(out_1=out_1, out_2=out_2, gamma=self.gamma, temperature=self.temperature) * self.gamma_lambd + self.gamma_loss(out_1=out_1, out_2=out_2, gamma=2.0, temperature=self.temperature) * (1. - self.gamma_lambd) 345 | elif self.loss_type == "origin": 346 | loss = self.gamma_loss(out_1=out_1, out_2=out_2, gamma=self.gamma, temperature=self.temperature) 347 | elif self.loss_type == "product": 348 | loss = self.gamma_loss(out_1=out_1[:, 0:self.feat_dim // 2], out_2=out_2[:, 0:self.feat_dim // 2], gamma=self.gamma, 349 | temperature=self.temperature) * self.gamma_lambd + self.gamma_loss(out_1=out_1[:, self.feat_dim // 2: self.feat_dim], 350 | out_2=out_2[:, self.feat_dim // 2: self.feat_dim], 351 | gamma=2.0, 352 | temperature=self.temperature) * (1. 
- self.gamma_lambd) 353 | elif self.loss_type == "spectral": 354 | loss = self.spectral_loss(out_1=out_1, out_2=out_2) 355 | else: 356 | raise NotImplementedError 357 | else: 358 | sim = self.acos_kernel_distance(torch.acos(self.temperature * torch.mm(out, out_dist.t().contiguous()) + 1 - self.temperature + eps)) 359 | neg = torch.clamp(sim.sum(dim=-1) - sim.diag(), min=eps) 360 | pos = self.acos_kernel_distance(torch.acos(self.temperature * torch.sum(out_1 * out_2, dim=-1) + 1 - self.temperature + eps)) 361 | pos = torch.cat([pos, pos], dim=0) 362 | loss = -torch.log(pos / (neg + eps)).mean() 363 | 364 | return loss 365 | 366 | @staticmethod 367 | def add_model_specific_args(parent_parser): 368 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 369 | 370 | # model params 371 | parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") 372 | # specify flags to store false 373 | parser.add_argument("--first_conv", action="store_false") 374 | parser.add_argument("--maxpool1", action="store_false") 375 | parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") 376 | parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") 377 | parser.add_argument("--norm_p", default=2., type=float, help="norm p, -1 for inf") 378 | parser.add_argument("--distance_p", default=2., type=float, help="distance p, -1 for inf") 379 | parser.add_argument("--acos_order", default=0, type=int, help="order of acos, 0 for not use acos kernel") 380 | parser.add_argument("--gamma", default=2., type=float, help="gamma") 381 | parser.add_argument("--gamma_lambd", default=1., type=float, help="gamma lambd") 382 | parser.add_argument("--online_ft", action="store_true") 383 | parser.add_argument("--fp32", action="store_true") 384 | 385 | # transform params 386 | parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") 387 | parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") 388 | parser.add_argument("--dataset", type=str, default="cifar10", help="stl10, cifar10") 389 | parser.add_argument("--data_dir", type=str, default=".", help="path to download data") 390 | 391 | # training params 392 | parser.add_argument("--fast_dev_run", default=1, type=int) 393 | parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") 394 | parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") 395 | parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") 396 | parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/lars") 397 | parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") 398 | parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") 399 | parser.add_argument("--max_steps", default=-1, type=int, help="max steps") 400 | parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") 401 | parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") 402 | 403 | parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") 404 | parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") 405 | parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") 406 | 
parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") 407 | parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") 408 | 409 | return parser 410 | 411 | 412 | @under_review() 413 | def cli_main(config, args, isTune=False): 414 | from ssl_online import SSLOnlineEvaluator 415 | from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule 416 | from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform 417 | 418 | # parser = ArgumentParser() 419 | 420 | # model args 421 | # parser = SimCLR.add_model_specific_args(parser) 422 | # args = parser.parse_args() 423 | args.__dict__.update(config) 424 | 425 | if args.norm_p == -1.: 426 | args.norm_p = numpy.inf 427 | if args.distance_p == -1.: 428 | args.distance_p = numpy.inf 429 | 430 | if args.dataset == "stl10": 431 | dm = STL10DataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) 432 | 433 | dm.train_dataloader = dm.train_dataloader_mixed 434 | dm.val_dataloader = dm.val_dataloader_mixed 435 | args.num_samples = dm.num_unlabeled_samples 436 | 437 | args.maxpool1 = False 438 | args.first_conv = True 439 | args.input_height = dm.dims[-1] 440 | 441 | normalization = stl10_normalization() 442 | 443 | args.gaussian_blur = True 444 | args.jitter_strength = 1.0 445 | elif args.dataset == "cifar10" or args.dataset == "cifar100": 446 | val_split = 5000 447 | if args.num_nodes * args.gpus * args.batch_size > val_split: 448 | val_split = args.num_nodes * args.gpus * args.batch_size 449 | 450 | dm = CIFAR10DataModule( 451 | data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, val_split=val_split 452 | ) if args.dataset == "cifar10" else CIFAR100DataModule( 453 | data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers, val_split=val_split 454 | ) 455 | 456 | args.num_samples = dm.num_samples 457 | 458 | args.maxpool1 = False 459 | args.first_conv = False 460 | args.input_height = dm.dims[-1] 461 | # args.temperature = 0.5 462 | 463 | normalization = cifar10_normalization() 464 | 465 | args.gaussian_blur = False 466 | args.jitter_strength = 0.5 467 | elif args.dataset == "imagenet" or args.dataset == "tiny_imagenet": 468 | args.maxpool1 = True 469 | args.first_conv = True 470 | normalization = imagenet_normalization() 471 | 472 | args.gaussian_blur = True 473 | args.jitter_strength = 1.0 474 | 475 | # args.batch_size = 64 476 | # args.num_nodes = 8 477 | # args.gpus = 8 # per-node 478 | args.max_epochs = 800 479 | 480 | # args.optimizer = "lars" 481 | # args.learning_rate = 4.8 482 | # args.final_lr = 0.0048 483 | # args.start_lr = 0.3 484 | args.online_ft = True 485 | 486 | dm = ImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) if args.dataset == "imagenet" else TinyImagenetDataModule(data_dir=args.data_dir, batch_size=args.batch_size, num_workers=args.num_workers) 487 | 488 | args.num_samples = dm.num_samples 489 | args.input_height = dm.dims[-1] 490 | else: 491 | raise NotImplementedError("other datasets have not been implemented till now") 492 | 493 | dm.train_transforms = SimCLRTrainDataTransform( 494 | input_height=args.input_height, 495 | gaussian_blur=args.gaussian_blur, 496 | jitter_strength=args.jitter_strength, 497 | normalize=normalization, 498 | ) 499 | 500 | dm.val_transforms = SimCLREvalDataTransform( 501 | input_height=args.input_height, 502 | 
gaussian_blur=args.gaussian_blur, 503 | jitter_strength=args.jitter_strength, 504 | normalize=normalization, 505 | ) 506 | 507 | # print(args) 508 | model = SimCLR(**args.__dict__) 509 | 510 | online_evaluator = None 511 | if args.online_ft: 512 | # online eval 513 | online_evaluator = SSLOnlineEvaluator( 514 | drop_p=0.0, 515 | hidden_dim=None, 516 | z_dim=args.hidden_mlp, 517 | num_classes=dm.num_classes, 518 | dataset=args.dataset, 519 | isTune=isTune 520 | ) 521 | 522 | lr_monitor = LearningRateMonitor(logging_interval="step") 523 | model_checkpoint = ModelCheckpoint(save_last=True, save_top_k=1, monitor="val_loss") 524 | callbacks = [] if isTune else [model_checkpoint] 525 | if args.online_ft: 526 | callbacks.append(online_evaluator) 527 | callbacks.append(lr_monitor) 528 | 529 | # print(args.max_steps) 530 | trainer = Trainer( 531 | max_epochs=args.max_epochs, 532 | max_steps=args.max_steps, 533 | gpus=args.gpus, 534 | num_nodes=args.num_nodes, 535 | accelerator="gpu", 536 | sync_batchnorm=True if args.gpus > 1 else False, 537 | precision=32 if args.fp32 else 16, 538 | callbacks=callbacks, 539 | fast_dev_run=args.fast_dev_run, 540 | ) 541 | 542 | trainer.fit(model, datamodule=dm) 543 | 544 | 545 | if __name__ == "__main__": 546 | 547 | parser = ArgumentParser() 548 | parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture") 549 | parser.add_argument("--max_t", default=200, type=int, help="max epoch to report") 550 | parser.add_argument("--num_samples", default=100, type=int, help="number of samples") 551 | # specify flags to store false 552 | parser.add_argument("--first_conv", action="store_false") 553 | parser.add_argument("--maxpool1", action="store_false") 554 | parser.add_argument("--hidden_mlp", default=2048, type=int, help="hidden layer dimension in projection head") 555 | parser.add_argument("--feat_dim", default=128, type=int, help="feature dimension") 556 | parser.add_argument("--norm_p", default=2., type=float, help="norm p, -1 for inf") 557 | parser.add_argument("--distance_p", default=2., type=float, help="distance p, -1 for inf") 558 | parser.add_argument("--acos_order", default=0, type=int, help="order of acos, 0 for not use acos kernel") 559 | parser.add_argument("--gamma", default=2., type=float, help="gamma") 560 | parser.add_argument("--gamma_lambd", default=1., type=float, help="gamma lambd") 561 | parser.add_argument("--projection_mu", default=1., type=float, help="projection mu") 562 | parser.add_argument("--loss_type", default="origin", type=str, help="search type, origin, sum , product or spectral") 563 | parser.add_argument("--online_ft", action="store_true") 564 | parser.add_argument("--fp32", action="store_true") 565 | 566 | # transform params 567 | parser.add_argument("--gaussian_blur", action="store_true", help="add gaussian blur") 568 | parser.add_argument("--jitter_strength", type=float, default=1.0, help="jitter strength") 569 | parser.add_argument("--dataset", type=str, default="cifar10", help="stl10, cifar10") 570 | parser.add_argument("--data_dir", type=str, default=".", help="path to download data") 571 | 572 | # training params 573 | parser.add_argument("--fast_dev_run", default=1, type=int) 574 | parser.add_argument("--num_nodes", default=1, type=int, help="number of nodes for training") 575 | parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on") 576 | parser.add_argument("--num_workers", default=8, type=int, help="num of workers per GPU") 577 | parser.add_argument("--optimizer", 
default="adam", type=str, help="choose between adam/lars") 578 | parser.add_argument("--exclude_bn_bias", action="store_true", help="exclude bn/bias from weight decay") 579 | parser.add_argument("--max_epochs", default=100, type=int, help="number of total epochs to run") 580 | parser.add_argument("--max_steps", default=-1, type=int, help="max steps") 581 | parser.add_argument("--warmup_epochs", default=10, type=int, help="number of warmup epochs") 582 | parser.add_argument("--batch_size", default=128, type=int, help="batch size per gpu") 583 | 584 | parser.add_argument("--temperature", default=0.1, type=float, help="temperature parameter in training loss") 585 | parser.add_argument("--weight_decay", default=1e-6, type=float, help="weight decay") 586 | parser.add_argument("--learning_rate", default=1e-3, type=float, help="base learning rate") 587 | parser.add_argument("--start_lr", default=0, type=float, help="initial warmup learning rate") 588 | parser.add_argument("--final_lr", type=float, default=1e-6, help="final learning rate") 589 | 590 | args = parser.parse_args() 591 | cli_main({}, args, isTune=False) 592 | -------------------------------------------------------------------------------- /ssl_online.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | from typing import Any, Dict, Optional, Sequence, Tuple, Union 3 | 4 | import torch 5 | from pytorch_lightning import Callback, LightningModule, Trainer 6 | from pytorch_lightning.utilities import rank_zero_warn 7 | from torch import Tensor, nn 8 | from torch.nn import functional as F 9 | from torch.optim import Optimizer 10 | from torchmetrics.functional import accuracy 11 | 12 | from pl_bolts.models.self_supervised.evaluator import SSLEvaluator 13 | from pl_bolts.utils.stability import under_review 14 | 15 | import ray 16 | from ray import tune 17 | from ray.air import session 18 | from ray.air.checkpoint import Checkpoint 19 | from ray.tune.schedulers import ASHAScheduler 20 | 21 | 22 | @under_review() 23 | class SSLOnlineEvaluator(Callback): # pragma: no cover 24 | """Attaches a MLP for fine-tuning using the standard self-supervised protocol. 25 | Example:: 26 | # your datamodule must have 2 attributes 27 | dm = DataModule() 28 | dm.num_classes = ... # the num of classes in the datamodule 29 | dm.name = ... # name of the datamodule (e.g. ImageNet, STL10, CIFAR10) 30 | # your model must have 1 attribute 31 | model = Model() 32 | model.z_dim = ... 
# the representation dim
33 |         online_eval = SSLOnlineEvaluator(
34 |             z_dim=model.z_dim
35 |         )
36 |     """
37 | 
38 |     def __init__(
39 |         self,
40 |         z_dim: int,
41 |         drop_p: float = 0.2,
42 |         hidden_dim: Optional[int] = None,
43 |         num_classes: Optional[int] = None,
44 |         dataset: Optional[str] = None,
45 |         isTune=False
46 |     ):
47 |         """
48 |         Args:
49 |             z_dim: Representation dimension
50 |             drop_p: Dropout probability
51 |             hidden_dim: Hidden dimension for the fine-tune MLP
52 |         """
53 |         super().__init__()
54 | 
55 |         self.z_dim = z_dim
56 |         self.hidden_dim = hidden_dim
57 |         self.drop_p = drop_p
58 | 
59 |         self.optimizer: Optional[Optimizer] = None
60 |         self.online_evaluator: Optional[SSLEvaluator] = None
61 |         # if left as None, num_classes / dataset are resolved from the datamodule in setup()
62 |         self.num_classes: Optional[int] = num_classes
63 |         self.dataset: Optional[str] = dataset
64 | 
65 |         self.isTune = isTune
66 | 
67 |         self._recovered_callback_state: Optional[Dict[str, Any]] = None
68 | 
69 |     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
70 |         if self.num_classes is None:
71 |             self.num_classes = trainer.datamodule.num_classes
72 |         if self.dataset is None:
73 |             self.dataset = trainer.datamodule.name
74 | 
75 |     def on_fit_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
76 |         # must move to device after setup, as during setup, pl_module is still on cpu
77 |         self.online_evaluator = SSLEvaluator(
78 |             n_input=self.z_dim,
79 |             n_classes=self.num_classes,
80 |             p=self.drop_p,
81 |             n_hidden=self.hidden_dim,
82 |         ).to(pl_module.device)
83 | 
84 |         # switch for PL compatibility reasons
85 |         accel = (
86 |             trainer.accelerator_connector
87 |             if hasattr(trainer, "accelerator_connector")
88 |             else trainer._accelerator_connector
89 |         )
90 |         if accel.is_distributed:
91 |             if accel.use_ddp:
92 |                 from torch.nn.parallel import DistributedDataParallel as DDP
93 | 
94 |                 self.online_evaluator = DDP(self.online_evaluator, device_ids=[pl_module.device])
95 |             elif accel.use_dp:
96 |                 from torch.nn.parallel import DataParallel as DP
97 | 
98 |                 self.online_evaluator = DP(self.online_evaluator, device_ids=[pl_module.device])
99 |             else:
100 |                 rank_zero_warn(
101 |                     "Does not support this type of distributed accelerator. The online evaluator will not sync."
102 |                 )
103 | 
104 |         self.optimizer = torch.optim.Adam(self.online_evaluator.parameters(), lr=1e-4)
105 | 
106 |         if self._recovered_callback_state is not None:
107 |             self.online_evaluator.load_state_dict(self._recovered_callback_state["state_dict"])
108 |             self.optimizer.load_state_dict(self._recovered_callback_state["optimizer_state"])
109 | 
110 |     def to_device(self, batch: Sequence, device: Union[str, torch.device]) -> Tuple[Tensor, Tensor]:
111 |         # get the labeled batch
112 |         if self.dataset == "stl10":
113 |             labeled_batch = batch[1]
114 |             batch = labeled_batch
115 | 
116 |         inputs, y = batch
117 | 
118 |         # last input is for online eval
119 |         x = inputs[-1]
120 |         x = x.to(device)
121 |         y = y.to(device)
122 | 
123 |         return x, y
124 | 
125 |     def shared_step(
126 |         self,
127 |         pl_module: LightningModule,
128 |         batch: Sequence,
129 |     ):
130 |         with torch.no_grad():
131 |             with set_training(pl_module, False):
132 |                 x, y = self.to_device(batch, pl_module.device)
133 |                 representations = pl_module(x).flatten(start_dim=1)
134 | 
135 |         # forward pass
136 |         mlp_logits = self.online_evaluator(representations)  # type: ignore[operator]
137 |         mlp_loss = F.cross_entropy(mlp_logits, y)
138 | 
139 |         acc = accuracy(mlp_logits.softmax(-1), y)
140 | 
141 |         return acc, mlp_loss
142 | 
143 |     def on_train_batch_end(
144 |         self,
145 |         trainer: Trainer,
146 |         pl_module: LightningModule,
147 |         outputs: Sequence,
148 |         batch: Sequence,
149 |         batch_idx: int,
150 |     ) -> None:
151 |         train_acc, mlp_loss = self.shared_step(pl_module, batch)
152 | 
153 |         # update finetune weights
154 |         mlp_loss.backward()
155 |         self.optimizer.step()
156 |         self.optimizer.zero_grad()
157 | 
158 |         pl_module.log("online_train_acc", train_acc, on_step=True, on_epoch=False)
159 |         pl_module.log("online_train_loss", mlp_loss, on_step=True, on_epoch=False)
160 | 
161 |     def on_validation_batch_end(
162 |         self,
163 |         trainer: Trainer,
164 |         pl_module: LightningModule,
165 |         outputs: Sequence,
166 |         batch: Sequence,
167 |         batch_idx: int,
168 |         dataloader_idx: int,
169 |     ) -> None:
170 |         val_acc, mlp_loss = self.shared_step(pl_module, batch)
171 |         if self.isTune:
172 |             session.report({"online_val_acc": val_acc.item(), "online_val_loss": mlp_loss.item()})
173 |         pl_module.log("online_val_acc", val_acc, on_step=False, on_epoch=True, sync_dist=True)
174 |         pl_module.log("online_val_loss", mlp_loss, on_step=False, on_epoch=True, sync_dist=True)
175 | 
176 |     def state_dict(self) -> dict:
177 |         return {"state_dict": self.online_evaluator.state_dict(), "optimizer_state": self.optimizer.state_dict()}
178 | 
179 |     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
180 |         self._recovered_callback_state = state_dict
181 | 
182 | 
183 | @under_review()
184 | @contextmanager
185 | def set_training(module: nn.Module, mode: bool):
186 |     """Context manager to set training mode.
187 |     On exit, the original training mode is restored.
188 |     Args:
189 |         module: module to set training mode
190 |         mode: whether to set training mode (True) or evaluation mode (False).
191 | """ 192 | original_mode = module.training 193 | 194 | try: 195 | module.train(mode) 196 | yield module 197 | finally: 198 | module.train(original_mode) -------------------------------------------------------------------------------- /tiny_imagenet_datamodule.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from typing import Any, Callable, Optional, Sequence, Union 3 | 4 | from pl_bolts.datamodules.vision_datamodule import VisionDataModule 5 | from pl_bolts.datasets import TrialCIFAR10 6 | from pl_bolts.transforms.dataset_normalizations import cifar10_normalization 7 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 8 | from pl_bolts.utils.stability import under_review 9 | from pl_bolts.utils.warnings import warn_missing_pkg 10 | 11 | if _TORCHVISION_AVAILABLE: 12 | from torchvision import transforms as transform_lib 13 | from torchvision.datasets import CIFAR100 14 | from torchvision.datasets import ImageFolder 15 | 16 | else: # pragma: no cover 17 | warn_missing_pkg("torchvision") 18 | CIFAR100 = None 19 | 20 | 21 | class TinyImagenetDataModule(VisionDataModule): 22 | 23 | name = "tiny_imagenet" 24 | dims = (3, 64, 64) 25 | 26 | def __init__( 27 | self, 28 | data_dir: Optional[str] = None, 29 | val_split: Union[int, float] = 0.1, 30 | num_workers: int = 0, 31 | normalize: bool = False, 32 | batch_size: int = 32, 33 | seed: int = 42, 34 | shuffle: bool = True, 35 | pin_memory: bool = True, 36 | drop_last: bool = False, 37 | *args: Any, 38 | **kwargs: Any, 39 | ) -> None: 40 | """ 41 | Args: 42 | data_dir: Where to save/load the data 43 | val_split: Percent (float) or number (int) of samples to use for the validation split 44 | num_workers: How many workers to use for loading data 45 | normalize: If true applies image normalize 46 | batch_size: How many samples per batch to load 47 | seed: Random seed to be used for train/val/test splits 48 | shuffle: If true shuffles the train data every epoch 49 | pin_memory: If true, the data loader will copy Tensors into CUDA pinned memory before 50 | returning them 51 | drop_last: If true drops the last incomplete batch 52 | """ 53 | super().__init__( # type: ignore[misc] 54 | data_dir=data_dir, 55 | val_split=val_split, 56 | num_workers=num_workers, 57 | normalize=normalize, 58 | batch_size=batch_size, 59 | seed=seed, 60 | shuffle=shuffle, 61 | pin_memory=pin_memory, 62 | drop_last=drop_last, 63 | *args, 64 | **kwargs, 65 | ) 66 | 67 | @property 68 | def num_samples(self) -> int: 69 | train_len, _ = self._get_splits(len_dataset=100000) 70 | return train_len 71 | 72 | @property 73 | def num_classes(self) -> int: 74 | return 200 75 | 76 | def default_transforms(self) -> Callable: 77 | if self.normalize: 78 | cf10_transforms = transform_lib.Compose([transform_lib.ToTensor(), cifar10_normalization()]) 79 | else: 80 | cf10_transforms = transform_lib.Compose([transform_lib.ToTensor()]) 81 | 82 | return cf10_transforms 83 | 84 | def prepare_data(self, *args: Any, **kwargs: Any) -> None: 85 | pass 86 | 87 | def setup(self, stage: Optional[str] = None) -> None: 88 | """Creates train, val, and test dataset.""" 89 | if stage == "fit" or stage is None: 90 | train_transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms 91 | val_transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms 92 | 93 | dataset_train = ImageFolder(self.data_dir + "/tiny-imagenet-200/train", transform=train_transforms, 
**self.EXTRA_ARGS) 94 | dataset_val = ImageFolder(self.data_dir + "/tiny-imagenet-200/train", transform=val_transforms, **self.EXTRA_ARGS) 95 | 96 | # Split 97 | self.dataset_train = self._split_dataset(dataset_train) 98 | self.dataset_val = self._split_dataset(dataset_val, train=False) 99 | 100 | if stage == "test" or stage is None: 101 | test_transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms 102 | self.dataset_test = ImageFolder( 103 | self.data_dir + "/tiny-imagenet-200/val", transform=test_transforms, **self.EXTRA_ARGS 104 | ) 105 | 106 | @staticmethod 107 | def add_dataset_specific_args(parent_parser: ArgumentParser) -> ArgumentParser: 108 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 109 | 110 | parser.add_argument("--data_dir", type=str, default=".") 111 | parser.add_argument("--num_workers", type=int, default=0) 112 | parser.add_argument("--batch_size", type=int, default=32) 113 | 114 | return parser -------------------------------------------------------------------------------- /transforms.py: -------------------------------------------------------------------------------- 1 | from pl_bolts.utils import _TORCHVISION_AVAILABLE 2 | from pl_bolts.utils.warnings import warn_missing_pkg 3 | 4 | if _TORCHVISION_AVAILABLE: 5 | from torchvision import transforms 6 | else: # pragma: no cover 7 | warn_missing_pkg("torchvision") 8 | 9 | 10 | class SimCLRTrainDataTransform: 11 | """Transforms for SimCLR during training step of the pre-training stage. 12 | Transform:: 13 | RandomResizedCrop(size=self.input_height) 14 | RandomHorizontalFlip() 15 | RandomApply([color_jitter], p=0.8) 16 | RandomGrayscale(p=0.2) 17 | RandomApply([GaussianBlur(kernel_size=int(0.1 * self.input_height))], p=0.5) 18 | transforms.ToTensor() 19 | Example:: 20 | from pl_bolts.models.self_supervised.simclr.transforms import SimCLRTrainDataTransform 21 | transform = SimCLRTrainDataTransform(input_height=32) 22 | x = sample() 23 | (xi, xj, xk) = transform(x) # xk is only for the online evaluator if used 24 | """ 25 | 26 | def __init__( 27 | self, input_height: int = 224, gaussian_blur: bool = True, jitter_strength: float = 1.0, normalize=None 28 | ) -> None: 29 | 30 | if not _TORCHVISION_AVAILABLE: # pragma: no cover 31 | raise ModuleNotFoundError("You want to use `transforms` from `torchvision` which is not installed yet.") 32 | 33 | self.jitter_strength = jitter_strength 34 | self.input_height = input_height 35 | self.gaussian_blur = gaussian_blur 36 | self.normalize = normalize 37 | 38 | self.color_jitter = transforms.ColorJitter( 39 | 0.8 * self.jitter_strength, 40 | 0.8 * self.jitter_strength, 41 | 0.8 * self.jitter_strength, 42 | 0.2 * self.jitter_strength, 43 | ) 44 | 45 | data_transforms = [ 46 | transforms.RandomResizedCrop(size=self.input_height), 47 | transforms.RandomHorizontalFlip(p=0.5), 48 | transforms.RandomApply([self.color_jitter], p=0.8), 49 | transforms.RandomGrayscale(p=0.2), 50 | ] 51 | 52 | if self.gaussian_blur: 53 | kernel_size = int(0.1 * self.input_height) 54 | if kernel_size % 2 == 0: 55 | kernel_size += 1 56 | 57 | data_transforms.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5)) 58 | 59 | self.data_transforms = transforms.Compose(data_transforms) 60 | 61 | if normalize is None: 62 | self.final_transform = transforms.ToTensor() 63 | else: 64 | self.final_transform = transforms.Compose([transforms.ToTensor(), normalize]) 65 | 66 | self.train_transform = 
transforms.Compose([self.data_transforms, self.final_transform]) 67 | 68 | # add online train transform of the size of global view 69 | self.online_transform = transforms.Compose( 70 | [transforms.RandomResizedCrop(self.input_height), transforms.RandomHorizontalFlip(), self.final_transform] 71 | ) 72 | 73 | def __call__(self, sample): 74 | transform = self.train_transform 75 | 76 | xi = transform(sample) 77 | xj = transform(sample) 78 | 79 | return xi, xj, self.online_transform(sample) 80 | 81 | 82 | class SimCLREvalDataTransform(SimCLRTrainDataTransform): 83 | """Transforms for SimCLR during the validation step of the pre-training stage. 84 | Transform:: 85 | Resize(input_height + 10, interpolation=3) 86 | transforms.CenterCrop(input_height), 87 | transforms.ToTensor() 88 | Example:: 89 | from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform 90 | transform = SimCLREvalDataTransform(input_height=32) 91 | x = sample() 92 | (xi, xj, xk) = transform(x) # xk is only for the online evaluator if used 93 | """ 94 | 95 | def __init__( 96 | self, input_height: int = 224, gaussian_blur: bool = True, jitter_strength: float = 1.0, normalize=None 97 | ): 98 | super().__init__( 99 | normalize=normalize, input_height=input_height, gaussian_blur=gaussian_blur, jitter_strength=jitter_strength 100 | ) 101 | 102 | # replace online transform with eval time transform 103 | self.online_transform = transforms.Compose( 104 | [ 105 | transforms.Resize(int(self.input_height + 0.1 * self.input_height)), 106 | transforms.CenterCrop(self.input_height), 107 | self.final_transform, 108 | ] 109 | ) 110 | 111 | 112 | class SimCLRFinetuneTransform(SimCLRTrainDataTransform): 113 | """Transforms for SimCLR during the fine-tuning stage. 114 | Transform:: 115 | Resize(input_height + 10, interpolation=3) 116 | transforms.CenterCrop(input_height), 117 | transforms.ToTensor() 118 | Example:: 119 | from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform 120 | transform = SimCLREvalDataTransform(input_height=32) 121 | x = sample() 122 | xk = transform(x) 123 | """ 124 | 125 | def __init__( 126 | self, input_height: int = 224, jitter_strength: float = 1.0, normalize=None, eval_transform: bool = False 127 | ) -> None: 128 | 129 | super().__init__( 130 | normalize=normalize, input_height=input_height, gaussian_blur=None, jitter_strength=jitter_strength 131 | ) 132 | 133 | if eval_transform: 134 | self.data_transforms = transforms.Compose([ 135 | transforms.Resize(int(self.input_height + 0.1 * self.input_height)), 136 | transforms.CenterCrop(self.input_height), 137 | ]) 138 | 139 | self.transform = transforms.Compose([self.data_transforms, self.final_transform]) 140 | 141 | def __call__(self, sample): 142 | return self.transform(sample) --------------------------------------------------------------------------------
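
For reference, the single-process path of `SimCLR.gamma_loss` in `simclr_module.py` (i.e., when `SyncFunction` all-gather is skipped) reduces to the minimal, self-contained sketch below. The random `z1`/`z2` are only stand-ins for the normalized outputs of `Projection.forward`; the function itself mirrors the repo's tensor shapes and `eps` clamping.

```python
import torch
import torch.nn.functional as F


def gamma_loss(out_1, out_2, gamma=2.0, temperature=0.1, distance_p=2.0, eps=1e-6):
    # out_1, out_2: [batch_size, dim] projections of the two augmented views
    out = torch.cat([out_1, out_2], dim=0)                        # [2B, dim]
    # kernel similarity: exp(-||z_i - z_j||_p^gamma / temperature)
    cov = -torch.pow(torch.cdist(out, out, p=distance_p), gamma)  # [2B, 2B]
    sim = torch.exp(cov / temperature)
    # denominator: all pairs except self-similarity (the diagonal)
    neg = torch.clamp(sim.sum(dim=-1) - sim.diag(), min=eps)
    # numerator: the positive pair (view 1 vs. view 2 of the same image)
    pos = torch.exp(-torch.pow(torch.norm(out_1 - out_2, dim=-1, p=distance_p), gamma) / temperature)
    pos = torch.cat([pos, pos], dim=0)                            # [2B]
    return -torch.log(pos / (neg + eps)).mean()


z1 = F.normalize(torch.randn(8, 128), dim=1)  # stand-ins for Projection outputs
z2 = F.normalize(torch.randn(8, 128), dim=1)
print(gamma_loss(z1, z2))  # scalar loss
```

For the default `gamma=2` with L2-normalized features, `-||z_i - z_j||^2 = 2 * z_i . z_j - 2`, and the constant `exp(-2 / temperature)` cancels between numerator and denominator, so `loss_type == "origin"` recovers the standard NT-Xent loss with the temperature effectively halved; other `gamma` values give the more general kernels that the random search explores.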