├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── batchrenorm.py ├── lars.py ├── linear_classifier.py ├── moco.py ├── model_params.py ├── requirements.txt ├── train_blog.py ├── utils.py └── ws_resnet.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # PyCharm 132 | .idea/ 133 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Untiled AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch-Lightning Implementation of Self-Supervised Learning Methods 2 | 3 | This is a [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning) implementation of the following self-supervised representation learning methods: 4 | - [MoCo](https://arxiv.org/abs/1911.05722) 5 | - [MoCo v2](https://arxiv.org/abs/2003.04297) 6 | - [SimCLR](https://arxiv.org/abs/2002.05709) 7 | - [BYOL](https://arxiv.org/abs/2006.07733) 8 | - [EqCo](https://arxiv.org/abs/2010.01929) 9 | - [VICReg](https://arxiv.org/abs/2105.04906) 10 | 11 | Supported datasets: ImageNet, STL-10, and CIFAR-10. 12 | 13 | During training, the top1/top5 accuracies (out of 1+K examples) are reported where possible. During validation, an `sklearn` linear classifier is trained on half the test set and validated on the other half. The top1 accuracy is logged as `train_class_acc` / `valid_class_acc`. 14 | 15 | 16 | ## Installing 17 | 18 | Make sure you're in a fresh `conda` or `venv` environment, then run: 19 | 20 | ```bash 21 | git clone https://github.com/untitled-ai/self_supervised 22 | cd self_supervised 23 | pip install -r requirements.txt 24 | ``` 25 | 26 | ## Replicating our BYOL blog post 27 | 28 | We found some surprising results about the role of batch norm in BYOL. See the blog post [Understanding self-supervised and contrastive learning with "Bootstrap Your Own Latent" (BYOL)](https://untitled-ai.github.io/understanding-self-supervised-contrastive-learning.html) for more details about our experiments. 
29 | 30 | You can replicate the results of our blog post by running `python train_blog.py`. The cosine similarity between z and z' is reported as `step_neg_cos` (for negative examples) and `step_pos_cos` (for positive examples). Classification accuracy is reported as `valid_class_acc`. 31 | 32 | ## Getting started with MoCo v2 33 | 34 | To get started with training a ResNet-18 with MoCo v2 on STL-10 (the default configuration): 35 | 36 | ```python 37 | import os 38 | import pytorch_lightning as pl 39 | from moco import SelfSupervisedMethod 40 | from model_params import ModelParams 41 | 42 | os.environ["DATA_PATH"] = "~/data" 43 | 44 | params = ModelParams() 45 | model = SelfSupervisedMethod(params) 46 | trainer = pl.Trainer(gpus=1, max_epochs=320) 47 | trainer.fit(model) 48 | trainer.save_checkpoint("example.ckpt") 49 | ``` 50 | 51 | For convenience, you can instead pass these parameters as keyword args, for example with `model = SelfSupervisedMethod(batch_size=128)`. 52 | 53 | ## VICReg 54 | 55 | To train VICReg rather than MoCo v2, use the following parameters: 56 | 57 | ```python 58 | import os 59 | import pytorch_lightning as pl 60 | from moco import SelfSupervisedMethod 61 | from model_params import VICRegParams 62 | 63 | os.environ["DATA_PATH"] = "~/data" 64 | 65 | params = VICRegParams() 66 | model = SelfSupervisedMethod(params) 67 | trainer = pl.Trainer(gpus=1, max_epochs=320) 68 | trainer.fit(model) 69 | trainer.save_checkpoint("example.ckpt") 70 | ``` 71 | 72 | Note that we have not tuned these parameters for STL-10, and the parameters used for ImageNet are slightly different. See the comment on VICRegParams for details. 
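For orientation, the three terms that VICReg balances can be sketched in plain Python. This is a minimal, dependency-free illustration that follows the term definitions in `_get_vicreg_loss` in `moco.py` (mean squared error, a hinge on the per-dimension standard deviation, and squared off-diagonal covariance divided by the dimension). The helper names are hypothetical, and the final weighted sum is an assumption based on the `invariance_loss_weight`, `variance_loss_weight`, `covariance_loss_weight`, and `variance_loss_epsilon` defaults listed under "All training options"; it is not the repository's implementation.

```python
import math


def _column_means(z):
    """Per-dimension mean of a batch given as a list of N rows of D floats."""
    n = len(z)
    return [sum(row[j] for row in z) / n for j in range(len(z[0]))]


def invariance_term(z_a, z_b):
    """Mean squared error between the two batches of embeddings."""
    n, d = len(z_a), len(z_a[0])
    total = sum((z_a[i][j] - z_b[i][j]) ** 2 for i in range(n) for j in range(d))
    return total / (n * d)


def variance_term(z, eps=1e-4):
    """Hinge on the per-dimension standard deviation: mean(max(0, 1 - std))."""
    n, d = len(z), len(z[0])
    means = _column_means(z)
    total = 0.0
    for j in range(d):
        var = sum((row[j] - means[j]) ** 2 for row in z) / (n - 1)  # unbiased variance
        total += max(0.0, 1.0 - math.sqrt(var + eps))
    return total / d


def covariance_term(z):
    """Sum of squared off-diagonal covariance entries, divided by D."""
    n, d = len(z), len(z[0])
    means = _column_means(z)
    centered = [[row[j] - means[j] for j in range(d)] for row in z]
    total = 0.0
    for a in range(d):
        for b in range(d):
            if a != b:
                cov = sum(r[a] * r[b] for r in centered) / (n - 1)
                total += cov ** 2
    return total / d


def vicreg_loss(z_a, z_b, w_inv=25.0, w_var=25.0, w_cov=1.0, eps=1e-4):
    """Assumed combination: weighted sum of the three terms over both views."""
    inv = invariance_term(z_a, z_b)
    var = variance_term(z_a, eps) + variance_term(z_b, eps)
    cov = covariance_term(z_a) + covariance_term(z_b)
    return w_inv * inv + w_var * var + w_cov * cov
```

In this sketch, the variance term is what penalizes collapsed embeddings (every row identical), while the covariance term penalizes correlated dimensions.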
73 | 74 | ## BYOL 75 | 76 | To train BYOL rather than MoCo v2, use the following parameters: 77 | 78 | ```python 79 | import os 80 | import pytorch_lightning as pl 81 | from moco import SelfSupervisedMethod 82 | from model_params import BYOLParams 83 | 84 | os.environ["DATA_PATH"] = "~/data" 85 | 86 | params = BYOLParams() 87 | model = SelfSupervisedMethod(params) 88 | trainer = pl.Trainer(gpus=1, max_epochs=320) 89 | trainer.fit(model) 90 | trainer.save_checkpoint("example.ckpt") 91 | ``` 92 | 93 | ## SimCLR 94 | 95 | To train SimCLR rather than MoCo v2, use the following parameters: 96 | 97 | ```python 98 | import os 99 | import pytorch_lightning as pl 100 | from moco import SelfSupervisedMethod 101 | from model_params import SimCLRParams 102 | 103 | os.environ["DATA_PATH"] = "~/data" 104 | 105 | params = SimCLRParams() 106 | model = SelfSupervisedMethod(params) 107 | trainer = pl.Trainer(gpus=1, max_epochs=320) 108 | trainer.fit(model) 109 | trainer.save_checkpoint("example.ckpt") 110 | ``` 111 | 112 | **Note for multi-GPU setups**: this currently only uses negatives on the same GPU, and will not sync negatives across multiple GPUs. 113 | 114 | 115 | # Evaluating a trained model 116 | 117 | To train a linear classifier on the result: 118 | 119 | ```python 120 | import pytorch_lightning as pl 121 | from linear_classifier import LinearClassifierMethod 122 | linear_model = LinearClassifierMethod.from_moco_checkpoint("example.ckpt") 123 | trainer = pl.Trainer(gpus=1, max_epochs=100) 124 | 125 | trainer.fit(linear_model) 126 | ``` 127 | 128 | # Results on STL-10 and ImageNet 129 | 130 | Training a ResNet-18 for 320 epochs on STL-10 achieved 85% linear classification accuracy on the test set (1 fold of 5000). This used all default parameters. 131 | 132 | Training a ResNet-50 for 200 epochs on ImageNet achieves 65.6% linear classification accuracy on the test set. 
133 | This used 8 GPUs with `ddp` and the following parameters: 134 | 135 | ```python 136 | hparams = ModelParams( 137 | encoder_arch="resnet50", 138 | shuffle_batch_norm=True, 139 | embedding_dim=2048, 140 | mlp_hidden_dim=2048, 141 | dataset_name="imagenet", 142 | batch_size=32, 143 | lr=0.03, 144 | max_epochs=200, 145 | transform_crop_size=224, 146 | num_data_workers=32, 147 | gather_keys_for_queue=True, 148 | ) 149 | ``` 150 | 151 | (The `batch_size` differs from the MoCo documentation because PyTorch-Lightning applies it per GPU in `ddp` 152 | training; with 8 GPUs the effective batch size is 8 × 32 = 256.) **Note that for ImageNet we suggest using 153 | `val_percent_check=0.1` when calling `pl.Trainer`** to reduce the time spent fitting the sklearn model. 154 | 155 | 156 | # All training options 157 | 158 | All possible `hparams` for SelfSupervisedMethod, along with defaults: 159 | 160 | ```python 161 | class ModelParams: 162 | # encoder model selection 163 | encoder_arch: str = "resnet18" 164 | shuffle_batch_norm: bool = False 165 | embedding_dim: int = 512 # must match embedding dim of encoder 166 | 167 | # data-related parameters 168 | dataset_name: str = "stl10" 169 | batch_size: int = 256 170 | 171 | # MoCo parameters 172 | K: int = 65536 # number of examples in queue 173 | dim: int = 128 174 | m: float = 0.996 175 | T: float = 0.2 176 | 177 | # eqco parameters 178 | eqco_alpha: int = 65536 179 | use_eqco_margin: bool = False 180 | use_negative_examples_from_batch: bool = False 181 | 182 | # optimization parameters 183 | lr: float = 0.5 184 | momentum: float = 0.9 185 | weight_decay: float = 1e-4 186 | max_epochs: int = 320 187 | final_lr_schedule_value: float = 0.0 188 | 189 | # transform parameters 190 | transform_s: float = 0.5 191 | transform_apply_blur: bool = True 192 | 193 | # Change these to make more like BYOL 194 | use_momentum_schedule: bool = False 195 | loss_type: str = "ce" 196 | use_negative_examples_from_queue: bool = True 197 | use_both_augmentations_as_queries: 
bool = False 198 | optimizer_name: str = "sgd" 199 | lars_warmup_epochs: int = 1 200 | lars_eta: float = 1e-3 201 | exclude_matching_parameters_from_lars: List[str] = [] # set to [".bias", ".bn"] to match paper 202 | loss_constant_factor: float = 1 203 | 204 | # Change these to make more like VICReg 205 | use_vicreg_loss: bool = False 206 | use_lagging_model: bool = True 207 | use_unit_sphere_projection: bool = True 208 | invariance_loss_weight: float = 25.0 209 | variance_loss_weight: float = 25.0 210 | covariance_loss_weight: float = 1.0 211 | variance_loss_epsilon: float = 1e-04 212 | 213 | # MLP parameters 214 | projection_mlp_layers: int = 2 215 | prediction_mlp_layers: int = 0 216 | mlp_hidden_dim: int = 512 217 | 218 | mlp_normalization: Optional[str] = None 219 | prediction_mlp_normalization: Optional[str] = "same" # if same will use mlp_normalization 220 | use_mlp_weight_standardization: bool = False 221 | 222 | # data loader parameters 223 | num_data_workers: int = 4 224 | drop_last_batch: bool = True 225 | pin_data_memory: bool = True 226 | gather_keys_for_queue: bool = False 227 | ``` 228 | 229 | A few options require more explanation: 230 | 231 | - **encoder_arch** can be any torchvision model, or one of the ResNet models with weight standardization defined in 232 | `ws_resnet.py`. 233 | 234 | - **dataset_name** can be `imagenet`, `stl10`, or `cifar10`. `os.environ["DATA_PATH"]` will be used as the path to the data. STL-10 and CIFAR-10 will 235 | be downloaded if they do not already exist. 236 | 237 | - **loss_type** can be `ce` (cross entropy), which corresponds to MoCo when one of the `use_negative_examples_from_*` options is enabled, or `ip` (inner product), 238 | which corresponds to BYOL when both `use_negative_examples_from_*` options are `False`. It can also be `bce`, which is similar to `ip` but applies the 239 | binary cross entropy loss function to the result, or `vic` for the VICReg loss. 240 | 241 | - **optimizer_name** is currently either `sgd` or `lars`. 
242 | 243 | - **exclude_matching_parameters_from_lars** excludes parameters whose names match from both weight decay and the LARS learning-rate scaling. Set 244 | to `[".bias", ".bn"]` to match the BYOL paper's implementation. 245 | 246 | - **mlp_normalization** can be `None` for no normalization, `bn` for batch normalization, `ln` for layer norm, `gn` for group 247 | norm, or `br` for [batch renormalization](https://github.com/ludvb/batchrenorm). 248 | 249 | - **prediction_mlp_normalization** defaults to `same` to use the same normalization as above, but can be given any of the 250 | above values to use a different normalization. 251 | 252 | - **shuffle_batch_norm** and **gather_keys_for_queue** are both related to multi-GPU training. **shuffle_batch_norm** 253 | will shuffle the *key* images among GPUs, which is needed for training if batch norm is used. **gather_keys_for_queue** 254 | will gather key projections (z' in the blog post) from all GPUs to add to the MoCo queue. 255 | 256 | # Training with custom options 257 | 258 | You can train using any settings of the above parameters. 
This configuration represents the settings from BYOL: 259 | 260 | ```python 261 | hparams = ModelParams( 262 | prediction_mlp_layers=2, 263 | mlp_normalization="bn", 264 | loss_type="ip", 265 | use_negative_examples_from_queue=False, 266 | use_both_augmentations_as_queries=True, 267 | use_momentum_schedule=True, 268 | optimizer_name="lars", 269 | exclude_matching_parameters_from_lars=[".bias", ".bn"], 270 | loss_constant_factor=2 271 | ) 272 | 273 | ``` 274 | Or here is our recommended way to modify VICReg for CIFAR-10: 275 | ```python 276 | from model_params import VICRegParams 277 | 278 | hparams = VICRegParams( 279 | dataset_name="cifar10", 280 | transform_apply_blur=False, 281 | mlp_hidden_dim=2048, 282 | dim=2048, 283 | batch_size=256, 284 | lr=0.3, 285 | final_lr_schedule_value=0, 286 | weight_decay=1e-4, 287 | lars_warmup_epochs=10, 288 | lars_eta=0.02 289 | ) 290 | ``` 291 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import utils 2 | 3 | from .linear_classifier import LinearClassifierMethod 4 | from .linear_classifier import LinearClassifierMethodParams 5 | from .moco import SelfSupervisedMethod 6 | from .model_params import ModelParams 7 | 8 | __all__ = [ 9 | "SelfSupervisedMethod", 10 | "ModelParams", 11 | "LinearClassifierMethod", 12 | "LinearClassifierMethodParams", 13 | "utils", 14 | ] 15 | -------------------------------------------------------------------------------- /batchrenorm.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/ludvb/batchrenorm 3 | @article{batchrenomalization, 4 | author = {Sergey Ioffe}, 5 | title = {Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models}, 6 | journal = {arXiv preprint arXiv:1702.03275}, 7 | year = {2017}, 8 | } 9 | """ 10 | 11 | import torch 12 | 13 | __all__ = 
["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"] 14 | 15 | 16 | class BatchRenorm(torch.nn.Module): 17 | def __init__( 18 | self, 19 | num_features: int, 20 | eps: float = 1e-3, 21 | momentum: float = 0.01, 22 | affine: bool = True, 23 | ): 24 | super().__init__() 25 | self.register_buffer("running_mean", torch.zeros(num_features, dtype=torch.float)) 26 | self.register_buffer("running_std", torch.ones(num_features, dtype=torch.float)) 27 | self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long)) 28 | self.weight = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float)) 29 | self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float)) 30 | self.affine = affine 31 | self.eps = eps 32 | self.step = 0 33 | self.momentum = momentum 34 | 35 | def _check_input_dim(self, x: torch.Tensor) -> None: 36 | raise NotImplementedError() # pragma: no cover 37 | 38 | @property 39 | def rmax(self) -> torch.Tensor: 40 | return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(1.0, 3.0) 41 | 42 | @property 43 | def dmax(self) -> torch.Tensor: 44 | return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(0.0, 5.0) 45 | 46 | def forward(self, x: torch.Tensor) -> torch.Tensor: 47 | self._check_input_dim(x) 48 | if x.dim() > 2: 49 | x = x.transpose(1, -1) 50 | if self.training: 51 | dims = [i for i in range(x.dim() - 1)] 52 | batch_mean = x.mean(dims) 53 | batch_std = x.std(dims, unbiased=False) + self.eps 54 | r = (batch_std.detach() / self.running_std.view_as(batch_std)).clamp_( 55 | 1 / self.rmax.item(), self.rmax.item() 56 | ) 57 | d = ( 58 | (batch_mean.detach() - self.running_mean.view_as(batch_mean)) / self.running_std.view_as(batch_std) 59 | ).clamp_(-self.dmax.item(), self.dmax.item()) 60 | x = (x - batch_mean) / batch_std * r + d 61 | self.running_mean += self.momentum * (batch_mean.detach() - self.running_mean) 62 | self.running_std += self.momentum * (batch_std.detach() - self.running_std) 63 | 
self.num_batches_tracked += 1 64 | else: 65 | x = (x - self.running_mean) / self.running_std 66 | if self.affine: 67 | x = self.weight * x + self.bias 68 | if x.dim() > 2: 69 | x = x.transpose(1, -1) 70 | return x 71 | 72 | 73 | class BatchRenorm1d(BatchRenorm): 74 | def _check_input_dim(self, x: torch.Tensor) -> None: 75 | if x.dim() not in [2, 3]: 76 | raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)") 77 | 78 | 79 | class BatchRenorm2d(BatchRenorm): 80 | def _check_input_dim(self, x: torch.Tensor) -> None: 81 | if x.dim() != 4: 82 | raise ValueError(f"expected 4D input (got {x.dim()}D input)") 83 | 84 | 85 | class BatchRenorm3d(BatchRenorm): 86 | def _check_input_dim(self, x: torch.Tensor) -> None: 87 | if x.dim() != 5: 88 | raise ValueError(f"expected 5D input (got {x.dim()}D input)") 89 | -------------------------------------------------------------------------------- /lars.py: -------------------------------------------------------------------------------- 1 | """ 2 | Layer-wise adaptive rate scaling for SGD in PyTorch! 3 | Based on https://github.com/noahgolmant/pytorch-lars 4 | """ 5 | import torch 6 | from torch.optim.optimizer import Optimizer 7 | 8 | 9 | class LARS(Optimizer): 10 | r"""Implements layer-wise adaptive rate scaling for SGD. 11 | 12 | Args: 13 | params (iterable): iterable of parameters to optimize or dicts defining 14 | parameter groups 15 | lr (float): base learning rate (\gamma_0) 16 | momentum (float, optional): momentum factor (default: 0.9) ("m") 17 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0005) 18 | ("\beta") 19 | eta (float, optional): LARS coefficient (default: 0.001) 20 | max_epoch: maximum training epoch to determine polynomial LR decay. 21 | 22 | Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg. 
23 | Large Batch Training of Convolutional Networks: 24 | https://arxiv.org/abs/1708.03888 25 | 26 | Example: 27 | >>> optimizer = LARS(model.parameters(), lr=0.1, eta=1e-3) 28 | >>> optimizer.zero_grad() 29 | >>> loss_fn(model(input), target).backward() 30 | >>> optimizer.step() 31 | """ 32 | 33 | def __init__(self, params, lr=1.0, momentum=0.9, weight_decay=0.0005, eta=0.001, max_epoch=200, warmup_epochs=1): 34 | if lr < 0.0: 35 | raise ValueError("Invalid learning rate: {}".format(lr)) 36 | if momentum < 0.0: 37 | raise ValueError("Invalid momentum value: {}".format(momentum)) 38 | if weight_decay < 0.0: 39 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 40 | if eta < 0.0: 41 | raise ValueError("Invalid LARS coefficient value: {}".format(eta)) 42 | 43 | self.epoch = 0 44 | defaults = dict( 45 | lr=lr, 46 | momentum=momentum, 47 | weight_decay=weight_decay, 48 | eta=eta, 49 | max_epoch=max_epoch, 50 | warmup_epochs=warmup_epochs, 51 | use_lars=True, 52 | ) 53 | super().__init__(params, defaults) 54 | 55 | def step(self, epoch=None, closure=None): 56 | """Performs a single optimization step. 57 | 58 | Arguments: 59 | closure (callable, optional): A closure that reevaluates the model 60 | and returns the loss. 61 | epoch: current epoch to calculate polynomial LR decay schedule. 62 | if None, uses self.epoch and increments it. 
63 | """ 64 | loss = None 65 | if closure is not None: 66 | loss = closure() 67 | 68 | if epoch is None: 69 | epoch = self.epoch 70 | self.epoch += 1 71 | 72 | for group in self.param_groups: 73 | weight_decay = group["weight_decay"] 74 | momentum = group["momentum"] 75 | eta = group["eta"] 76 | lr = group["lr"] 77 | warmup_epochs = group["warmup_epochs"] 78 | use_lars = group["use_lars"] 79 | group["lars_lrs"] = [] 80 | 81 | for p in group["params"]: 82 | if p.grad is None: 83 | continue 84 | 85 | param_state = self.state[p] 86 | d_p = p.grad.data 87 | 88 | weight_norm = torch.norm(p.data) 89 | grad_norm = torch.norm(d_p) 90 | 91 | # Global LR scaled by a linear warmup over the first warmup_epochs 92 | warmup = min((1 + float(epoch)) / warmup_epochs, 1) 93 | global_lr = lr * warmup 94 | 95 | # Choose the learning rate actually applied to this parameter 96 | if use_lars: 97 | # Compute local learning rate for this layer 98 | local_lr = eta * weight_norm / (grad_norm + weight_decay * weight_norm) 99 | actual_lr = local_lr * global_lr 100 | group["lars_lrs"].append(actual_lr.item()) 101 | else: 102 | actual_lr = global_lr 103 | group["lars_lrs"].append(global_lr) 104 | # Update the momentum buffer, then apply it to the parameter 105 | if "momentum_buffer" not in param_state: 106 | buf = param_state["momentum_buffer"] = torch.zeros_like(p.data) 107 | else: 108 | buf = param_state["momentum_buffer"] 109 | 110 | buf.mul_(momentum).add_(d_p + weight_decay * p.data, alpha=actual_lr) 111 | p.data.add_(-buf) 112 | 113 | return loss 114 | -------------------------------------------------------------------------------- /linear_classifier.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import attr 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | from pytorch_lightning.utilities import AttributeDict 8 | from torch.utils.data import DataLoader 9 | 10 | import utils 11 | 12 | 13 | @attr.s(auto_attribs=True) 14 | class LinearClassifierMethodParams: 15 | # encoder model selection 16 | 
encoder_arch: str = "resnet18" 17 | embedding_dim: int = 512 18 | 19 | # data-related parameters 20 | dataset_name: str = "stl10" 21 | batch_size: int = 256 22 | 23 | # optimization parameters 24 | lr: float = 30.0 25 | momentum: float = 0.9 26 | weight_decay: float = 0.0 27 | max_epochs: int = 100 28 | 29 | # data loader parameters 30 | num_data_workers: int = 4 31 | drop_last_batch: bool = True 32 | pin_data_memory: bool = True 33 | multi_gpu_training: bool = False 34 | 35 | 36 | class LinearClassifierMethod(pl.LightningModule): 37 | model: torch.nn.Module 38 | dataset: utils.DatasetBase 39 | hparams: AttributeDict 40 | 41 | def __init__( 42 | self, 43 | hparams: LinearClassifierMethodParams = None, 44 | **kwargs, 45 | ): 46 | super().__init__() 47 | 48 | if hparams is None: 49 | hparams = self.params(**kwargs) 50 | elif isinstance(hparams, dict): 51 | hparams = self.params(**hparams, **kwargs) 52 | 53 | self.hparams = AttributeDict(attr.asdict(hparams)) 54 | 55 | # actually do a load that is a little more flexible 56 | self.model = utils.get_encoder(hparams.encoder_arch) 57 | 58 | self.dataset = utils.get_class_dataset(hparams.dataset_name) 59 | 60 | self.classifier = torch.nn.Linear(hparams.embedding_dim, self.dataset.num_classes) 61 | 62 | def load_model_from_checkpoint(self, checkpoint_path: str): 63 | checkpoint = torch.load(checkpoint_path) 64 | state_dict = checkpoint["state_dict"] 65 | for k in list(state_dict.keys()): 66 | if not k.startswith("model."): 67 | del state_dict[k] 68 | self.load_state_dict(state_dict, strict=False) 69 | 70 | def forward(self, x): 71 | with torch.no_grad(): 72 | embedding = self.model(x) 73 | return self.classifier(embedding) 74 | 75 | def training_step(self, batch, batch_idx, **kwargs): 76 | x, y = batch 77 | y_hat = self.forward(x) 78 | loss = F.cross_entropy(y_hat, y) 79 | acc1, acc5 = utils.calculate_accuracy(y_hat, y, topk=(1, 5)) 80 | 81 | log_data = {"step_train_loss": loss, "step_train_acc1": acc1, "step_train_acc5": 
acc5} 82 | return {"loss": loss, "log": log_data} 83 | 84 | def validation_step(self, batch, batch_idx, **kwargs): 85 | x, y = batch 86 | y_hat = self.forward(x) 87 | acc1, acc5 = utils.calculate_accuracy(y_hat, y, topk=(1, 5)) 88 | return { 89 | "valid_loss": F.cross_entropy(y_hat, y), 90 | "valid_acc1": acc1, 91 | "valid_acc5": acc5, 92 | } 93 | 94 | def validation_epoch_end(self, outputs): 95 | avg_loss = torch.stack([x["valid_loss"] for x in outputs]).mean() 96 | avg_acc1 = torch.stack([x["valid_acc1"] for x in outputs]).mean() 97 | avg_acc5 = torch.stack([x["valid_acc5"] for x in outputs]).mean() 98 | 99 | log_data = {"valid_loss": avg_loss, "valid_acc1": avg_acc1, "valid_acc5": avg_acc5} 100 | print(log_data) 101 | return { 102 | "val_loss": avg_loss, 103 | "log": log_data, 104 | } 105 | 106 | def configure_optimizers(self): 107 | optimizer = torch.optim.SGD( 108 | self.parameters(), 109 | lr=self.hparams.lr, 110 | momentum=self.hparams.momentum, 111 | weight_decay=self.hparams.weight_decay, 112 | ) 113 | milestones = [math.floor(self.hparams.max_epochs * 0.6), math.floor(self.hparams.max_epochs * 0.8)] 114 | self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones) 115 | return [optimizer], [self.lr_scheduler] 116 | 117 | def train_dataloader(self): 118 | return DataLoader( 119 | self.dataset.get_train(), 120 | batch_size=self.hparams.batch_size, 121 | num_workers=self.hparams.num_data_workers, 122 | pin_memory=self.hparams.pin_data_memory, 123 | drop_last=self.hparams.drop_last_batch, 124 | shuffle=True, 125 | ) 126 | 127 | def val_dataloader(self): 128 | return DataLoader( 129 | self.dataset.get_validation(), 130 | batch_size=self.hparams.batch_size, 131 | num_workers=self.hparams.num_data_workers, 132 | pin_memory=self.hparams.pin_data_memory, 133 | drop_last=self.hparams.drop_last_batch, 134 | ) 135 | 136 | @classmethod 137 | def params(cls, **kwargs) -> LinearClassifierMethodParams: 138 | return 
LinearClassifierMethodParams(**kwargs) 139 | 140 | @classmethod 141 | def from_moco_checkpoint(cls, checkpoint_path, **kwargs): 142 | """ Loads hyperparameters and model from moco checkpoint """ 143 | checkpoint = torch.load(checkpoint_path) 144 | moco_hparams = checkpoint["hyper_parameters"] 145 | params = cls.params( 146 | encoder_arch=moco_hparams["encoder_arch"], 147 | embedding_dim=moco_hparams["embedding_dim"], 148 | dataset_name=moco_hparams["dataset_name"], 149 | **kwargs, 150 | ) 151 | model = cls(params) 152 | model.load_model_from_checkpoint(checkpoint_path) 153 | return model 154 | -------------------------------------------------------------------------------- /moco.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | import warnings 4 | from functools import partial 5 | from typing import Optional 6 | from typing import Union 7 | 8 | import attr 9 | import pytorch_lightning as pl 10 | import torch 11 | import torch.nn.functional as F 12 | from pytorch_lightning.utilities import AttributeDict 13 | from torch.utils.data import DataLoader 14 | 15 | import utils 16 | from batchrenorm import BatchRenorm1d 17 | from lars import LARS 18 | from model_params import ModelParams 19 | from sklearn.linear_model import LogisticRegression 20 | 21 | 22 | def get_mlp_normalization(hparams: ModelParams, prediction=False): 23 | normalization_str = hparams.mlp_normalization 24 | if prediction and hparams.prediction_mlp_normalization != "same": 25 | normalization_str = hparams.prediction_mlp_normalization 26 | 27 | if normalization_str is None: 28 | return None 29 | elif normalization_str == "bn": 30 | return partial(torch.nn.BatchNorm1d, num_features=hparams.mlp_hidden_dim) 31 | elif normalization_str == "br": 32 | return partial(BatchRenorm1d, num_features=hparams.mlp_hidden_dim) 33 | elif normalization_str == "ln": 34 | return partial(torch.nn.LayerNorm, normalized_shape=[hparams.mlp_hidden_dim]) 35 | 
elif normalization_str == "gn": 36 | return partial(torch.nn.GroupNorm, num_channels=hparams.mlp_hidden_dim, num_groups=32) 37 | else: 38 | raise NotImplementedError(f"mlp normalization {normalization_str} not implemented") 39 | 40 | 41 | class SelfSupervisedMethod(pl.LightningModule): 42 | model: torch.nn.Module 43 | dataset: utils.DatasetBase 44 | hparams: AttributeDict 45 | embedding_dim: Optional[int] 46 | 47 | def __init__( 48 | self, 49 | hparams: Union[ModelParams, dict, None] = None, 50 | **kwargs, 51 | ): 52 | super().__init__() 53 | 54 | if hparams is None: 55 | hparams = self.params(**kwargs) 56 | elif isinstance(hparams, dict): 57 | hparams = self.params(**hparams, **kwargs) 58 | 59 | if isinstance(self.hparams, AttributeDict): 60 | self.hparams.update(AttributeDict(attr.asdict(hparams))) 61 | else: 62 | self.hparams = AttributeDict(attr.asdict(hparams)) 63 | 64 | # Check for configuration issues 65 | if ( 66 | hparams.gather_keys_for_queue 67 | and not hparams.shuffle_batch_norm 68 | and not hparams.encoder_arch.startswith("ws_") 69 | ): 70 | warnings.warn( 71 | "Configuration suspicious: gather_keys_for_queue without shuffle_batch_norm or weight standardization" 72 | ) 73 | 74 | some_negative_examples = hparams.use_negative_examples_from_batch or hparams.use_negative_examples_from_queue 75 | if hparams.loss_type == "ce" and not some_negative_examples: 76 | warnings.warn("Configuration suspicious: cross entropy loss without negative examples") 77 | 78 | # Create encoder model 79 | self.model = utils.get_encoder(hparams.encoder_arch, hparams.dataset_name) 80 | 81 | # Create dataset 82 | self.dataset = utils.get_moco_dataset(hparams) 83 | 84 | if hparams.use_lagging_model: 85 | # "key" function (no grad) 86 | self.lagging_model = copy.deepcopy(self.model) 87 | for param in self.lagging_model.parameters(): 88 | param.requires_grad = False 89 | else: 90 | self.lagging_model = None 91 | 92 | self.projection_model = utils.MLP( 93 | hparams.embedding_dim, 94 
| hparams.dim, 95 | hparams.mlp_hidden_dim, 96 | num_layers=hparams.projection_mlp_layers, 97 | normalization=get_mlp_normalization(hparams), 98 | weight_standardization=hparams.use_mlp_weight_standardization, 99 | ) 100 | 101 | self.prediction_model = utils.MLP( 102 | hparams.dim, 103 | hparams.dim, 104 | hparams.mlp_hidden_dim, 105 | num_layers=hparams.prediction_mlp_layers, 106 | normalization=get_mlp_normalization(hparams, prediction=True), 107 | weight_standardization=hparams.use_mlp_weight_standardization, 108 | ) 109 | 110 | if hparams.use_lagging_model: 111 | # "key" function (no grad) 112 | self.lagging_projection_model = copy.deepcopy(self.projection_model) 113 | for param in self.lagging_projection_model.parameters(): 114 | param.requires_grad = False 115 | else: 116 | self.lagging_projection_model = None 117 | 118 | # this classifier is used to compute representation quality each epoch 119 | self.sklearn_classifier = LogisticRegression(max_iter=100, solver="liblinear") 120 | 121 | if hparams.use_negative_examples_from_queue: 122 | # create the queue 123 | self.register_buffer("queue", torch.randn(hparams.dim, hparams.K)) 124 | self.queue = torch.nn.functional.normalize(self.queue, dim=0) 125 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 126 | else: 127 | self.queue = None 128 | 129 | def _get_embeddings(self, x): 130 | """ 131 | Input: 132 | x: a batch of paired views with shape N2CHW 133 | (x[:, 0] holds the query images, x[:, 1] the key images) 134 | Output: 135 | emb_q (query encoder embeddings), q (query features), k (key features) 136 | """ 137 | bsz, nd, nc, nh, nw = x.shape 138 | assert nd == 2, "second dimension should be the split image -- dims should be N2CHW" 139 | im_q = x[:, 0].contiguous() 140 | im_k = x[:, 1].contiguous() 141 | 142 | # compute query features 143 | emb_q = self.model(im_q) 144 | q_projection = self.projection_model(emb_q) 145 | q = self.prediction_model(q_projection) # queries: NxC 146 | if self.hparams.use_lagging_model: 147 | # compute key features 148 | with torch.no_grad(): # no gradient to 
keys 149 | if self.hparams.shuffle_batch_norm: 150 | im_k, idx_unshuffle = utils.BatchShuffleDDP.shuffle(im_k) 151 | k = self.lagging_projection_model(self.lagging_model(im_k)) # keys: NxC 152 | if self.hparams.shuffle_batch_norm: 153 | k = utils.BatchShuffleDDP.unshuffle(k, idx_unshuffle) 154 | else: 155 | emb_k = self.model(im_k) 156 | k_projection = self.projection_model(emb_k) 157 | k = self.prediction_model(k_projection) # keys: NxC 158 | 159 | if self.hparams.use_unit_sphere_projection: 160 | q = torch.nn.functional.normalize(q, dim=1) 161 | k = torch.nn.functional.normalize(k, dim=1) 162 | 163 | return emb_q, q, k 164 | 165 | def _get_contrastive_predictions(self, q, k): 166 | if self.hparams.use_negative_examples_from_batch: 167 | logits = torch.mm(q, k.T) 168 | labels = torch.arange(0, q.shape[0], dtype=torch.long).to(logits.device) 169 | return logits, labels 170 | 171 | # compute logits 172 | # Einstein sum is more intuitive 173 | # positive logits: Nx1 174 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) 175 | 176 | if self.hparams.use_negative_examples_from_queue: 177 | # negative logits: NxK 178 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) 179 | logits = torch.cat([l_pos, l_neg], dim=1) 180 | else: 181 | logits = l_pos 182 | 183 | # labels: positive key indicators 184 | labels = torch.zeros(logits.shape[0], dtype=torch.long).to(logits.device) 185 | 186 | return logits, labels 187 | 188 | def _get_pos_neg_ip(self, emb_q, k): 189 | with torch.no_grad(): 190 | z = self.projection_model(emb_q) 191 | z = torch.nn.functional.normalize(z, dim=1) 192 | ip = torch.mm(z, k.T) 193 | eye = torch.eye(z.shape[0]).to(z.device) 194 | pos_ip = (ip * eye).sum() / z.shape[0] 195 | neg_ip = (ip * (1 - eye)).sum() / (z.shape[0] * (z.shape[0] - 1)) 196 | 197 | return pos_ip, neg_ip 198 | 199 | def _get_contrastive_loss(self, logits, labels): 200 | if self.hparams.loss_type == "ce": 201 | if self.hparams.use_eqco_margin: 202 | if
self.hparams.use_negative_examples_from_batch: 203 | neg_factor = self.hparams.eqco_alpha / self.hparams.batch_size 204 | elif self.hparams.use_negative_examples_from_queue: 205 | neg_factor = self.hparams.eqco_alpha / self.hparams.K 206 | else: 207 | raise Exception("Must have negative examples for ce loss") 208 | 209 | predictions = utils.log_softmax_with_factors(logits / self.hparams.T, neg_factor=neg_factor) 210 | return F.nll_loss(predictions, labels) 211 | 212 | return F.cross_entropy(logits / self.hparams.T, labels) 213 | 214 | new_labels = torch.zeros_like(logits) 215 | new_labels.scatter_(1, labels.unsqueeze(1), 1) 216 | if self.hparams.loss_type == "bce": 217 | return F.binary_cross_entropy_with_logits(logits / self.hparams.T, new_labels) * logits.shape[1] 218 | 219 | if self.hparams.loss_type == "ip": 220 | # inner product 221 | # negative sign for label=1 (maximize ip), positive sign for label=0 (minimize ip) 222 | inner_product = (1 - new_labels * 2) * logits 223 | return torch.mean((inner_product + 1).sum(dim=-1)) 224 | 225 | raise NotImplementedError(f"Loss function {self.hparams.loss_type} not implemented") 226 | 227 | def _get_vicreg_loss(self, z_a, z_b, batch_idx): 228 | assert z_a.shape == z_b.shape and len(z_a.shape) == 2 229 | 230 | # invariance loss 231 | loss_inv = F.mse_loss(z_a, z_b) 232 | 233 | # variance loss 234 | std_z_a = torch.sqrt(z_a.var(dim=0) + self.hparams.variance_loss_epsilon) 235 | std_z_b = torch.sqrt(z_b.var(dim=0) + self.hparams.variance_loss_epsilon) 236 | loss_v_a = torch.mean(F.relu(1 - std_z_a)) 237 | loss_v_b = torch.mean(F.relu(1 - std_z_b)) 238 | loss_var = loss_v_a + loss_v_b 239 | 240 | # covariance loss 241 | N, D = z_a.shape 242 | z_a = z_a - z_a.mean(dim=0) 243 | z_b = z_b - z_b.mean(dim=0) 244 | cov_z_a = ((z_a.T @ z_a) / (N - 1)).square() # DxD 245 | cov_z_b = ((z_b.T @ z_b) / (N - 1)).square() # DxD 246 | loss_c_a = (cov_z_a.sum() - cov_z_a.diagonal().sum()) / D 247 | loss_c_b = (cov_z_b.sum() - 
cov_z_b.diagonal().sum()) / D 248 | loss_cov = loss_c_a + loss_c_b 249 | 250 | weighted_inv = loss_inv * self.hparams.invariance_loss_weight 251 | weighted_var = loss_var * self.hparams.variance_loss_weight 252 | weighted_cov = loss_cov * self.hparams.covariance_loss_weight 253 | 254 | loss = weighted_inv + weighted_var + weighted_cov 255 | 256 | return { 257 | "loss": loss, 258 | "loss_invariance": weighted_inv, 259 | "loss_variance": weighted_var, 260 | "loss_covariance": weighted_cov, 261 | } 262 | 263 | def forward(self, x): 264 | return self.model(x) 265 | 266 | def training_step(self, batch, batch_idx, optimizer_idx=None): 267 | all_params = list(self.model.parameters()) 268 | x, class_labels = batch # batch is a tuple, we just want the image 269 | 270 | emb_q, q, k = self._get_embeddings(x) 271 | pos_ip, neg_ip = self._get_pos_neg_ip(emb_q, k) 272 | 273 | logits, labels = self._get_contrastive_predictions(q, k) 274 | if self.hparams.use_vicreg_loss: 275 | losses = self._get_vicreg_loss(q, k, batch_idx) 276 | contrastive_loss = losses["loss"] 277 | else: 278 | losses = {} 279 | contrastive_loss = self._get_contrastive_loss(logits, labels) 280 | 281 | if self.hparams.use_both_augmentations_as_queries: 282 | x_flip = torch.flip(x, dims=[1]) 283 | emb_q2, q2, k2 = self._get_embeddings(x_flip) 284 | logits2, labels2 = self._get_contrastive_predictions(q2, k2) 285 | 286 | pos_ip2, neg_ip2 = self._get_pos_neg_ip(emb_q2, k2) 287 | pos_ip = (pos_ip + pos_ip2) / 2 288 | neg_ip = (neg_ip + neg_ip2) / 2 289 | contrastive_loss += self._get_contrastive_loss(logits2, labels2) 290 | 291 | contrastive_loss = contrastive_loss.mean() * self.hparams.loss_constant_factor 292 | 293 | log_data = { 294 | "step_train_loss": contrastive_loss, 295 | "step_pos_cos": pos_ip, 296 | "step_neg_cos": neg_ip, 297 | **losses, 298 | } 299 | 300 | with torch.no_grad(): 301 | self._momentum_update_key_encoder() 302 | 303 | some_negative_examples = ( 304 | 
self.hparams.use_negative_examples_from_batch or self.hparams.use_negative_examples_from_queue 305 | ) 306 | if some_negative_examples: 307 | acc1, acc5 = utils.calculate_accuracy(logits, labels, topk=(1, 5)) 308 | log_data.update({"step_train_acc1": acc1, "step_train_acc5": acc5}) 309 | 310 | # dequeue and enqueue 311 | if self.hparams.use_negative_examples_from_queue: 312 | self._dequeue_and_enqueue(k) 313 | 314 | self.log_dict(log_data) 315 | return {"loss": contrastive_loss} 316 | 317 | def validation_step(self, batch, batch_idx): 318 | x, class_labels = batch 319 | with torch.no_grad(): 320 | emb = self.model(x) 321 | 322 | return {"emb": emb, "labels": class_labels} 323 | 324 | def validation_epoch_end(self, outputs): 325 | embeddings = torch.cat([x["emb"] for x in outputs]).cpu().detach().numpy() 326 | labels = torch.cat([x["labels"] for x in outputs]).cpu().detach().numpy() 327 | num_split_linear = embeddings.shape[0] // 2 328 | self.sklearn_classifier.fit(embeddings[:num_split_linear], labels[:num_split_linear]) 329 | train_accuracy = self.sklearn_classifier.score(embeddings[:num_split_linear], labels[:num_split_linear]) * 100 330 | valid_accuracy = self.sklearn_classifier.score(embeddings[num_split_linear:], labels[num_split_linear:]) * 100 331 | 332 | log_data = { 333 | "epoch": self.current_epoch, 334 | "train_class_acc": train_accuracy, 335 | "valid_class_acc": valid_accuracy, 336 | "T": self._get_temp(), 337 | "m": self._get_m(), 338 | } 339 | print(f"Epoch {self.current_epoch} accuracy: train: {train_accuracy:.1f}%, validation: {valid_accuracy:.1f}%") 340 | self.log_dict(log_data) 341 | 342 | def configure_optimizers(self): 343 | # exclude bias and batch norm from LARS and weight decay 344 | regular_parameters = [] 345 | regular_parameter_names = [] 346 | excluded_parameters = [] 347 | excluded_parameter_names = [] 348 | for name, parameter in self.named_parameters(): 349 | if parameter.requires_grad is False: 350 | continue 351 | if any(x in name 
for x in self.hparams.exclude_matching_parameters_from_lars): 352 | excluded_parameters.append(parameter) 353 | excluded_parameter_names.append(name) 354 | else: 355 | regular_parameters.append(parameter) 356 | regular_parameter_names.append(name) 357 | 358 | param_groups = [ 359 | {"params": regular_parameters, "names": regular_parameter_names, "use_lars": True}, 360 | { 361 | "params": excluded_parameters, 362 | "names": excluded_parameter_names, 363 | "use_lars": False, 364 | "weight_decay": 0, 365 | }, 366 | ] 367 | if self.hparams.optimizer_name == "sgd": 368 | optimizer = torch.optim.SGD 369 | elif self.hparams.optimizer_name == "lars": 370 | optimizer = partial(LARS, warmup_epochs=self.hparams.lars_warmup_epochs, eta=self.hparams.lars_eta) 371 | else: 372 | raise NotImplementedError(f"No such optimizer {self.hparams.optimizer_name}") 373 | 374 | encoding_optimizer = optimizer( 375 | param_groups, 376 | lr=self.hparams.lr, 377 | momentum=self.hparams.momentum, 378 | weight_decay=self.hparams.weight_decay, 379 | ) 380 | self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 381 | encoding_optimizer, 382 | self.hparams.max_epochs, 383 | eta_min=self.hparams.final_lr_schedule_value, 384 | ) 385 | return [encoding_optimizer], [self.lr_scheduler] 386 | 387 | def _get_m(self): 388 | if self.hparams.use_momentum_schedule is False: 389 | return self.hparams.m 390 | return 1 - (1 - self.hparams.m) * (math.cos(math.pi * self.current_epoch / self.hparams.max_epochs) + 1) / 2 391 | 392 | def _get_temp(self): 393 | return self.hparams.T 394 | 395 | @torch.no_grad() 396 | def _momentum_update_key_encoder(self): 397 | """ 398 | Momentum update of the key encoder 399 | """ 400 | if not self.hparams.use_lagging_model: 401 | return 402 | m = self._get_m() 403 | for param_q, param_k in zip(self.model.parameters(), self.lagging_model.parameters()): 404 | param_k.data = param_k.data * m + param_q.data * (1.0 - m) 405 | for param_q, param_k in 
zip(self.projection_model.parameters(), self.lagging_projection_model.parameters()): 406 | param_k.data = param_k.data * m + param_q.data * (1.0 - m) 407 | 408 | @torch.no_grad() 409 | def _dequeue_and_enqueue(self, keys): 410 | # gather keys before updating queue 411 | if self.hparams.gather_keys_for_queue: 412 | keys = utils.concat_all_gather(keys) 413 | 414 | batch_size = keys.shape[0] 415 | 416 | ptr = int(self.queue_ptr) 417 | assert self.hparams.K % batch_size == 0 # for simplicity 418 | 419 | # replace the keys at ptr (dequeue and enqueue) 420 | self.queue[:, ptr : ptr + batch_size] = keys.T 421 | ptr = (ptr + batch_size) % self.hparams.K # move pointer 422 | 423 | self.queue_ptr[0] = ptr 424 | 425 | def prepare_data(self) -> None: 426 | self.dataset.get_train() 427 | self.dataset.get_validation() 428 | 429 | def train_dataloader(self): 430 | return DataLoader( 431 | self.dataset.get_train(), 432 | batch_size=self.hparams.batch_size, 433 | num_workers=self.hparams.num_data_workers, 434 | pin_memory=self.hparams.pin_data_memory, 435 | drop_last=self.hparams.drop_last_batch, 436 | shuffle=True, 437 | ) 438 | 439 | def val_dataloader(self): 440 | return DataLoader( 441 | self.dataset.get_validation(), 442 | batch_size=self.hparams.batch_size, 443 | num_workers=self.hparams.num_data_workers, 444 | pin_memory=self.hparams.pin_data_memory, 445 | drop_last=self.hparams.drop_last_batch, 446 | ) 447 | 448 | @classmethod 449 | def params(cls, **kwargs) -> ModelParams: 450 | return ModelParams(**kwargs) 451 | -------------------------------------------------------------------------------- /model_params.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import List 3 | from typing import Optional 4 | 5 | import attr 6 | 7 | 8 | @attr.s(auto_attribs=True) 9 | class ModelParams: 10 | # encoder model selection 11 | encoder_arch: str = "resnet18" 12 | shuffle_batch_norm: bool = False 13 | 
embedding_dim: int = 512 # must match embedding dim of encoder 14 | 15 | # data-related parameters 16 | dataset_name: str = "stl10" 17 | batch_size: int = 256 18 | 19 | # MoCo parameters 20 | K: int = 65536 # number of examples in queue 21 | dim: int = 128 22 | m: float = 0.996 23 | T: float = 0.2 24 | 25 | # eqco parameters 26 | eqco_alpha: int = 65536 27 | use_eqco_margin: bool = False 28 | use_negative_examples_from_batch: bool = False 29 | 30 | # optimization parameters 31 | lr: float = 0.5 32 | momentum: float = 0.9 33 | weight_decay: float = 1e-4 34 | max_epochs: int = 320 35 | final_lr_schedule_value: float = 0.0 36 | 37 | # transform parameters 38 | transform_s: float = 0.5 39 | transform_apply_blur: bool = True 40 | 41 | # Change these to make more like BYOL 42 | use_momentum_schedule: bool = False 43 | loss_type: str = "ce" 44 | use_negative_examples_from_queue: bool = True 45 | use_both_augmentations_as_queries: bool = False 46 | optimizer_name: str = "sgd" 47 | lars_warmup_epochs: int = 1 48 | lars_eta: float = 1e-3 49 | exclude_matching_parameters_from_lars: List[str] = [] # set to [".bias", ".bn"] to match paper 50 | loss_constant_factor: float = 1 51 | 52 | # Change these to make more like VICReg 53 | use_vicreg_loss: bool = False 54 | use_lagging_model: bool = True 55 | use_unit_sphere_projection: bool = True 56 | invariance_loss_weight: float = 25.0 57 | variance_loss_weight: float = 25.0 58 | covariance_loss_weight: float = 1.0 59 | variance_loss_epsilon: float = 1e-04 60 | 61 | # MLP parameters 62 | projection_mlp_layers: int = 2 63 | prediction_mlp_layers: int = 0 64 | mlp_hidden_dim: int = 512 65 | 66 | mlp_normalization: Optional[str] = None 67 | prediction_mlp_normalization: Optional[str] = "same" # if same will use mlp_normalization 68 | use_mlp_weight_standardization: bool = False 69 | 70 | # data loader parameters 71 | num_data_workers: int = 4 72 | drop_last_batch: bool = True 73 | pin_data_memory: bool = True 74 | gather_keys_for_queue: 
bool = False 75 | 76 | 77 | # Differences between these parameters and those used in the paper (on image net): 78 | # max_epochs=1000, 79 | # lr=1.6, 80 | # batch_size=2048, 81 | # weight_decay=1e-6, 82 | # mlp_hidden_dim=8192, 83 | # dim=8192, 84 | VICRegParams = partial( 85 | ModelParams, 86 | use_vicreg_loss=True, 87 | loss_type="vic", 88 | use_lagging_model=False, 89 | use_unit_sphere_projection=False, 90 | use_negative_examples_from_queue=False, 91 | optimizer_name="lars", 92 | exclude_matching_parameters_from_lars=[".bias", ".bn"], 93 | projection_mlp_layers=3, 94 | final_lr_schedule_value=0.002, 95 | mlp_normalization="bn", 96 | lars_warmup_epochs=10, 97 | ) 98 | 99 | BYOLParams = partial( 100 | ModelParams, 101 | prediction_mlp_layers=2, 102 | mlp_normalization="bn", 103 | loss_type="ip", 104 | use_negative_examples_from_queue=False, 105 | use_both_augmentations_as_queries=True, 106 | use_momentum_schedule=True, 107 | optimizer_name="lars", 108 | exclude_matching_parameters_from_lars=[".bias", ".bn"], 109 | loss_constant_factor=2, 110 | ) 111 | 112 | SimCLRParams = partial( 113 | ModelParams, 114 | use_negative_examples_from_batch=True, 115 | use_negative_examples_from_queue=False, 116 | use_lagging_model=False, 117 | K=0, 118 | m=0.0, 119 | use_both_augmentations_as_queries=True, 120 | ) 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision>=0.5.0 2 | pytorch_lightning>=1.0.1 3 | torch>=1.6.0 4 | scikit_learn>=0.22.1 5 | Pillow>=7.0.0 6 | attrs>=19.3.0 -------------------------------------------------------------------------------- /train_blog.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from attr import evolve 3 | from pytorch_lightning.loggers import TensorBoardLogger 4 | 5 | from model_params import ModelParams 6 | from moco 
import SelfSupervisedMethod 7 | 8 | 9 | def main(): 10 | base_config = ModelParams( 11 | lr=0.8, 12 | batch_size=256, 13 | gather_keys_for_queue=False, 14 | loss_type="ip", 15 | use_negative_examples_from_queue=False, 16 | use_both_augmentations_as_queries=True, 17 | mlp_normalization="bn", 18 | prediction_mlp_layers=2, 19 | projection_mlp_layers=2, 20 | m=0.996, 21 | use_momentum_schedule=True, 22 | ) 23 | configs = { 24 | "base": base_config, 25 | "pred_only": evolve(base_config, mlp_normalization=None, prediction_mlp_normalization="bn"), 26 | "proj_only": evolve(base_config, mlp_normalization="bn", prediction_mlp_normalization=None), 27 | "no_norm": evolve(base_config, mlp_normalization=None), 28 | "layer_norm": evolve(base_config, mlp_normalization="ln"), 29 | "xent": evolve( 30 | base_config, use_negative_examples_from_queue=True, loss_type="ce", mlp_normalization=None, lr=0.02 31 | ), 32 | } 33 | for seed in range(3): 34 | for name, config in configs.items(): 35 | method = SelfSupervisedMethod(config) 36 | logger = TensorBoardLogger("tb_logs", name=f"{name}_{seed}") 37 | 38 | trainer = pl.Trainer(gpus=1, max_epochs=10, logger=logger) 39 | 40 | trainer.fit(method) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import Any 4 | from typing import Callable 5 | from typing import Optional 6 | 7 | import attr 8 | import torch 9 | import torchvision 10 | from PIL import ImageFilter 11 | from torchvision import transforms 12 | from torchvision.datasets import CIFAR10 13 | from torchvision.datasets import STL10 14 | from torchvision.datasets import ImageFolder 15 | 16 | import ws_resnet 17 | from model_params import ModelParams 18 | 19 | ################### 20 | # Transform utils # 21 | ################### 22 | 23 | 24 | class 
GaussianBlur(object): 25 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 26 | 27 | def __init__(self, sigma=[0.1, 2.0]): 28 | self.sigma = sigma 29 | 30 | def __call__(self, x): 31 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 32 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 33 | return x 34 | 35 | 36 | @attr.s(auto_attribs=True) 37 | class MoCoTransforms: 38 | crop_size: int = 224 39 | resize: int = 256 40 | normalize_means: list = [0.4914, 0.4822, 0.4465] 41 | normalize_stds: list = [0.2023, 0.1994, 0.2010] 42 | s: float = 0.5 43 | apply_blur: bool = True 44 | 45 | def split_transform(self, img) -> torch.Tensor: 46 | transform = self.single_transform() 47 | return torch.stack((transform(img), transform(img))) 48 | 49 | def single_transform(self): 50 | transform_list = [ 51 | transforms.RandomResizedCrop(self.crop_size, scale=(0.2, 1.0)), 52 | transforms.RandomApply( 53 | [transforms.ColorJitter(0.8 * self.s, 0.8 * self.s, 0.8 * self.s, 0.2 * self.s)], p=0.8 54 | ), 55 | transforms.RandomGrayscale(p=0.2), 56 | transforms.RandomHorizontalFlip(), 57 | ] 58 | if self.apply_blur: 59 | transform_list.append(transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5)) 60 | transform_list.append(transforms.ToTensor()) 61 | transform_list.append(transforms.Normalize(mean=self.normalize_means, std=self.normalize_stds)) 62 | return transforms.Compose(transform_list) 63 | 64 | def get_test_transform(self): 65 | return transforms.Compose( 66 | [ 67 | transforms.Resize(self.resize), 68 | transforms.CenterCrop(self.crop_size), 69 | transforms.ToTensor(), 70 | transforms.Normalize(mean=self.normalize_means, std=self.normalize_stds), 71 | ] 72 | ) 73 | 74 | 75 | ################# 76 | # Dataset utils # 77 | ################# 78 | 79 | 80 | @attr.s(auto_attribs=True, slots=True) 81 | class DatasetBase: 82 | _train_ds: Optional[torch.utils.data.Dataset] = None 83 | _validation_ds: Optional[torch.utils.data.Dataset] = None 84 | _test_ds: 
Optional[torch.utils.data.Dataset] = None 85 | transform_train: Optional[Callable] = None 86 | transform_test: Optional[Callable] = None 87 | 88 | def get_train(self) -> torch.utils.data.Dataset: 89 | if self._train_ds is None: 90 | self._train_ds = self.configure_train() 91 | return self._train_ds 92 | 93 | def configure_train(self) -> torch.utils.data.Dataset: 94 | raise NotImplementedError 95 | 96 | def get_validation(self) -> torch.utils.data.Dataset: 97 | if self._validation_ds is None: 98 | self._validation_ds = self.configure_validation() 99 | return self._validation_ds 100 | 101 | def configure_validation(self) -> torch.utils.data.Dataset: 102 | raise NotImplementedError 103 | 104 | @property 105 | def data_path(self): 106 | pathstr = os.environ.get("DATA_PATH", os.getcwd()) 107 | os.makedirs(pathstr, exist_ok=True) 108 | return pathstr 109 | 110 | @property 111 | def instance_shape(self): 112 | img = next(iter(self.get_train()))[0] 113 | return img.shape 114 | 115 | @property 116 | def num_classes(self): 117 | train_ds = self.get_train() 118 | if hasattr(train_ds, "classes"): 119 | return len(train_ds.classes) 120 | return None 121 | 122 | 123 | stl10_default_transform = transforms.Compose( 124 | [ 125 | transforms.ToTensor(), 126 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 127 | ] 128 | ) 129 | 130 | 131 | @attr.s(auto_attribs=True, slots=True) 132 | class STL10UnlabeledDataset(DatasetBase): 133 | transform_train: Callable[[Any], torch.Tensor] = stl10_default_transform 134 | transform_test: Callable[[Any], torch.Tensor] = stl10_default_transform 135 | 136 | def configure_train(self): 137 | return STL10(self.data_path, split="train+unlabeled", download=True, transform=self.transform_train) 138 | 139 | def configure_validation(self): 140 | return STL10(self.data_path, split="test", download=True, transform=self.transform_test) 141 | 142 | 143 | @attr.s(auto_attribs=True, slots=True) 144 | class 
STL10LabeledDataset(DatasetBase): 145 | transform_train: Callable[[Any], torch.Tensor] = stl10_default_transform 146 | transform_test: Callable[[Any], torch.Tensor] = stl10_default_transform 147 | 148 | def configure_train(self): 149 | return STL10(self.data_path, split="train", download=True, transform=self.transform_train) 150 | 151 | def configure_validation(self): 152 | return STL10(self.data_path, split="test", download=True, transform=self.transform_test) 153 | 154 | 155 | imagenet_default_transform = transforms.Compose( 156 | [ 157 | transforms.ToTensor(), 158 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 159 | ] 160 | ) 161 | 162 | 163 | @attr.s(auto_attribs=True, slots=True) 164 | class ImagenetDataset(DatasetBase): 165 | transform_train: Callable[[Any], torch.Tensor] = imagenet_default_transform 166 | transform_test: Callable[[Any], torch.Tensor] = imagenet_default_transform 167 | 168 | def configure_train(self): 169 | assert os.path.exists(self.data_path + "/imagenet/train") 170 | return ImageFolder(self.data_path + "/imagenet/train", transform=self.transform_train) 171 | 172 | def configure_validation(self): 173 | assert os.path.exists(self.data_path + "/imagenet/val") 174 | return ImageFolder(self.data_path + "/imagenet/val", transform=self.transform_test) 175 | 176 | 177 | cifar10_default_transform = transforms.Compose( 178 | [ 179 | transforms.ToTensor(), 180 | transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]), 181 | ] 182 | ) 183 | 184 | 185 | @attr.s(auto_attribs=True, slots=True) 186 | class CIFAR10Dataset(DatasetBase): 187 | transform_train: Callable[[Any], torch.Tensor] = cifar10_default_transform 188 | transform_test: Callable[[Any], torch.Tensor] = cifar10_default_transform 189 | 190 | def configure_train(self): 191 | return CIFAR10(self.data_path, train=True, download=True, transform=self.transform_train) 192 | 193 | def configure_validation(self): 194 | return 
CIFAR10(self.data_path, train=False, download=True, transform=self.transform_test) 195 | 196 | 197 | def get_moco_dataset(hparams: ModelParams) -> DatasetBase: 198 | if hparams.dataset_name == "stl10": 199 | crop_size = 96 200 | resize = 124 201 | normalize_means = [0.4914, 0.4823, 0.4466] 202 | normalize_stds = [0.247, 0.243, 0.261] 203 | transforms = MoCoTransforms( 204 | crop_size, resize, normalize_means, normalize_stds, hparams.transform_s, hparams.transform_apply_blur 205 | ) 206 | return STL10UnlabeledDataset( 207 | transform_train=transforms.split_transform, transform_test=transforms.get_test_transform() 208 | ) 209 | elif hparams.dataset_name == "imagenet": 210 | crop_size = 224 211 | resize = 256 212 | normalize_means = [0.485, 0.456, 0.406] 213 | normalize_stds = [0.228, 0.224, 0.225] 214 | transforms = MoCoTransforms( 215 | crop_size, resize, normalize_means, normalize_stds, hparams.transform_s, hparams.transform_apply_blur 216 | ) 217 | return ImagenetDataset( 218 | transform_train=transforms.split_transform, transform_test=transforms.get_test_transform() 219 | ) 220 | elif hparams.dataset_name == "cifar10": 221 | crop_size = 32 222 | resize = 36 223 | normalize_means = [0.4914, 0.4822, 0.4465] 224 | normalize_stds = [0.2023, 0.1994, 0.2010] 225 | transforms = MoCoTransforms( 226 | crop_size, resize, normalize_means, normalize_stds, hparams.transform_s, hparams.transform_apply_blur 227 | ) 228 | return CIFAR10Dataset( 229 | transform_train=transforms.split_transform, transform_test=transforms.get_test_transform() 230 | ) 231 | else: 232 | raise NotImplementedError(f"Dataset {hparams.dataset_name} not defined") 233 | 234 | 235 | def get_class_transforms(crop_size, resize): 236 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 237 | transform_train = transforms.Compose( 238 | [transforms.RandomResizedCrop(crop_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize] 239 | ) 240 | transform_test =
transforms.Compose( 241 | [transforms.Resize(resize), transforms.CenterCrop(crop_size), transforms.ToTensor(), normalize] 242 | ) 243 | return transform_train, transform_test 244 | 245 | 246 | def get_class_dataset(name: str) -> DatasetBase: 247 | if name == "stl10": 248 | transform_train, transform_test = get_class_transforms(96, 128) 249 | return STL10LabeledDataset(transform_train=transform_train, transform_test=transform_test) 250 | elif name == "imagenet": 251 | transform_train, transform_test = get_class_transforms(224, 256) 252 | return ImagenetDataset(transform_train=transform_train, transform_test=transform_test) 253 | elif name == "cifar10": 254 | transform_train, transform_test = get_class_transforms(32, 36) 255 | return CIFAR10Dataset(transform_train=transform_train, transform_test=transform_test) 256 | raise NotImplementedError(f"Dataset {name} not defined") 257 | 258 | 259 | ##################### 260 | # Parallelism utils # 261 | ##################### 262 | 263 | 264 | @torch.no_grad() 265 | def concat_all_gather(tensor): 266 | """ 267 | Performs all_gather operation on the provided tensors. 268 | *** Warning ***: torch.distributed.all_gather has no gradient. 269 | """ 270 | tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 271 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 272 | 273 | output = torch.cat(tensors_gather, dim=0) 274 | return output 275 | 276 | 277 | class BatchShuffleDDP: 278 | @staticmethod 279 | @torch.no_grad() 280 | def shuffle(x): 281 | """ 282 | Batch shuffle, for making use of BatchNorm. 283 | *** Only support DistributedDataParallel (DDP) model. 
*** 284 | """ 285 | # gather from all gpus 286 | batch_size_this = x.shape[0] 287 | x_gather = concat_all_gather(x) 288 | batch_size_all = x_gather.shape[0] 289 | 290 | num_gpus = batch_size_all // batch_size_this 291 | 292 | # random shuffle index 293 | idx_shuffle = torch.randperm(batch_size_all).to(x.device) 294 | 295 | # broadcast to all gpus 296 | torch.distributed.broadcast(idx_shuffle, src=0) 297 | 298 | # index for restoring 299 | idx_unshuffle = torch.argsort(idx_shuffle) 300 | 301 | # shuffled index for this gpu 302 | gpu_idx = torch.distributed.get_rank() 303 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 304 | 305 | return x_gather[idx_this], idx_unshuffle 306 | 307 | @staticmethod 308 | @torch.no_grad() 309 | def unshuffle(x, idx_unshuffle): 310 | """ 311 | Undo batch shuffle. 312 | *** Only support DistributedDataParallel (DDP) model. *** 313 | """ 314 | # gather from all gpus 315 | batch_size_this = x.shape[0] 316 | x_gather = concat_all_gather(x) 317 | batch_size_all = x_gather.shape[0] 318 | 319 | num_gpus = batch_size_all // batch_size_this 320 | 321 | # restored index for this gpu 322 | gpu_idx = torch.distributed.get_rank() 323 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 324 | 325 | return x_gather[idx_this] 326 | 327 | 328 | ############### 329 | # Model utils # 330 | ############### 331 | 332 | 333 | class MLP(torch.nn.Module): 334 | def __init__( 335 | self, input_dim, output_dim, hidden_dim, num_layers, weight_standardization=False, normalization=None 336 | ): 337 | super().__init__() 338 | assert num_layers >= 0, "negative layers?!?" 
339 | if normalization is not None: 340 | assert callable(normalization), "normalization must be callable" 341 | 342 | if num_layers == 0: 343 | self.net = torch.nn.Identity() 344 | return 345 | 346 | if num_layers == 1: 347 | self.net = torch.nn.Linear(input_dim, output_dim) 348 | return 349 | 350 | linear_net = ws_resnet.Linear if weight_standardization else torch.nn.Linear 351 | 352 | layers = [] 353 | prev_dim = input_dim 354 | for _ in range(num_layers - 1): 355 | layers.append(linear_net(prev_dim, hidden_dim)) 356 | if normalization is not None: 357 | layers.append(normalization()) 358 | layers.append(torch.nn.ReLU()) 359 | prev_dim = hidden_dim 360 | 361 | layers.append(torch.nn.Linear(hidden_dim, output_dim)) 362 | 363 | self.net = torch.nn.Sequential(*layers) 364 | 365 | def forward(self, x): 366 | return self.net(x) 367 | 368 | 369 | def get_encoder(name: str, dataset: str, **kwargs) -> torch.nn.Module: 370 | """ 371 | Gets just the encoder portion of a torchvision model (replaces final layer with identity) 372 | :param name: (str) name of the model 373 | :param dataset: (str) name of the dataset 374 | :param kwargs: kwargs to send to the model 375 | :return: the model with its final classification layer replaced by identity 376 | """ 377 | 378 | if name in ws_resnet.__dict__: 379 | model_creator = ws_resnet.__dict__.get(name) 380 | elif name in torchvision.models.__dict__: 381 | model_creator = torchvision.models.__dict__.get(name) 382 | else: 383 | raise AttributeError(f"Unknown architecture {name}") 384 | 385 | assert model_creator is not None, f"no torchvision model named {name}" 386 | model = model_creator(**kwargs) 387 | if hasattr(model, "fc"): 388 | model.fc = torch.nn.Identity() 389 | if dataset == "cifar10": 390 | model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=2, bias=False) 391 | model.maxpool = torch.nn.Identity() 392 | elif hasattr(model, "classifier"): 393 | model.classifier = torch.nn.Identity() 394 | else: 395 | raise NotImplementedError(f"Unknown class {model.__class__}") 396 | 397 |
    return model


####################
# Evaluation utils #
####################


def calculate_accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk largest logits per sample, sorted: (batch, maxk) -> (maxk, batch)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        # correct[i, j] is True when the (i + 1)-th ranked prediction for sample j equals its target
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # A sample counts as a top-k hit if its target appears in the first k rows
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def log_softmax_with_factors(logits: torch.Tensor, log_factor: float = 1, neg_factor: float = 1) -> torch.Tensor:
    """Log-softmax over the last dim with a scaled log term and re-weighted negatives.

    With log_factor=1 and neg_factor=1 this reduces to the standard log-softmax.
    The logits are exponentiated without max-shifting, so very large values can overflow.
    """
    # For each position: sum of exp(logits) over all *other* positions in the last dim
    exp_sum_neg_logits = torch.exp(logits).sum(dim=-1, keepdim=True) - torch.exp(logits)
    softmax_result = logits - log_factor * torch.log(torch.exp(logits) + neg_factor * exp_sum_neg_logits)
    return softmax_result

--------------------------------------------------------------------------------
/ws_resnet.py:
--------------------------------------------------------------------------------
"""
From https://github.com/joe-siyuan-qiao/pytorch-classification
@article{weightstandardization,
  author    = {Siyuan Qiao and Huiyu Wang and Chenxi Liu and Wei Shen and Alan Yuille},
  title     = {Weight Standardization},
  journal   = {arXiv preprint arXiv:1903.10520},
  year      = {2019},
}
"""

import torch.nn as nn
from torch.nn import functional as F


class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
        super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

    def forward(self, x):
        # return super(Conv2d, self).forward(x)
        weight = self.weight
        # Per-output-channel mean over (in_channels, kH, kW)
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # Per-output-channel std of the centered filter; 1e-5 guards against division by zero
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class Linear(nn.Linear):
    def forward(self, x):
        weight = self.weight
        # Standardize each output row of the weight matrix
        weight_mean = weight.mean(dim=1, keepdim=True)
        weight = weight - weight_mean
        std = weight.std(dim=1, keepdim=True) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.linear(x, weight, self.bias)


def BatchNorm2d(num_features):
    # Despite the name, this returns GroupNorm with 32 groups
    return nn.GroupNorm(num_channels=num_features, num_groups=32)


__all__ = ["ws_resnet18", "ws_resnet34", "ws_resnet50", "ws_resnet101", "ws_resnet152"]


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False):
        super(ResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            # BatchNorm2d() in this file returns nn.GroupNorm, so match both norm types
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last norm layer in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        # Project the shortcut when the spatial size or channel count changes
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def ws_resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.
    Args:
        pretrained (bool): Unused; accepted for signature compatibility only. No pre-trained weights are loaded.
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def ws_resnet34(pretrained=False, **kwargs):
    """Constructs a ResNet-34 model.
    Args:
        pretrained (bool): Unused; accepted for signature compatibility only. No pre-trained weights are loaded.
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def ws_resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.
    Args:
        pretrained (bool): Unused; accepted for signature compatibility only. No pre-trained weights are loaded.
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def ws_resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.
    Args:
        pretrained (bool): Unused; accepted for signature compatibility only. No pre-trained weights are loaded.
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def ws_resnet152(pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.
    Args:
        pretrained (bool): Unused; accepted for signature compatibility only. No pre-trained weights are loaded.
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model

--------------------------------------------------------------------------------
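
The weight standardization applied in the `Conv2d` and `Linear` forward passes of ws_resnet.py can be illustrated on a single flattened filter without PyTorch. The sketch below is a minimal pure-Python illustration, not part of the repository: `standardize_filter` and the sample values are hypothetical names and data chosen for the example. It assumes the same `1e-5` stabiliser and the sample (n-1) standard deviation, which is the default behaviour of `torch.Tensor.std`.

```python
import statistics

EPS = 1e-5  # same stabiliser that ws_resnet.py adds to the std


def standardize_filter(weights):
    """Standardize one flattened filter: zero mean, near-unit sample std.

    Hypothetical helper for illustration; mirrors the per-output-channel
    arithmetic in Conv2d.forward / Linear.forward of ws_resnet.py.
    """
    mean = statistics.mean(weights)
    centered = [w - mean for w in weights]
    # statistics.stdev uses the n-1 denominator, like torch.Tensor.std by default
    std = statistics.stdev(centered) + EPS
    return [c / std for c in centered]


# Made-up filter values, purely for demonstration
filt = [0.5, -1.0, 2.0, 0.25, 1.25]
ws = standardize_filter(filt)
print("mean:", statistics.mean(ws))
print("std: ", statistics.stdev(ws))
```

Because the epsilon is added to the standard deviation before dividing, the resulting std is slightly below 1 rather than exactly 1; for a constant filter (std of 0) the output is all zeros instead of a division error.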