├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── byol_pytorch
│   ├── __init__.py
│   ├── byol_pytorch.py
│   └── trainer.py
├── diagram.png
├── examples
│   └── lightning
│       ├── README.md
│       └── train.py
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Bootstrap Your Own Latent (BYOL), in Pytorch

[![PyPI version](https://badge.fury.io/py/byol-pytorch.svg)](https://badge.fury.io/py/byol-pytorch)

Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning or having to designate negative pairs.

This repository offers a module with which one can wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.

Update 1: There is now new evidence that batch normalization is key to making this technique work well.

Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting that batch statistics are needed for BYOL to work.

Update 3: Finally, we have some analysis for why this works.

Yannic Kilcher's excellent explanation

Now go save your organization from having to pay for labels :)

## Install

```bash
$ pip install byol-pytorch
```

## Usage

Simply plug in your neural network, specifying (1) the image dimensions as well as (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.

```python
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
```

That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
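
If you want a quick sanity check that the representations have actually improved, a common recipe is linear evaluation: freeze the trained backbone and fit a single linear layer on a small labelled set. Below is a minimal sketch of that idea; the `labelled_loader` and the ten-class output are hypothetical placeholders, not part of this library.

```python
import torch
from torch import nn
import torch.nn.functional as F

# freeze the BYOL-trained backbone and strip its classification head
resnet.fc = nn.Identity()
resnet.eval()
for p in resnet.parameters():
    p.requires_grad = False

probe = nn.Linear(2048, 10) # resnet50 features are 2048-dim; 10 classes is just an example
opt = torch.optim.Adam(probe.parameters(), lr=3e-4)

for images, labels in labelled_loader: # your own labelled DataLoader
    with torch.no_grad():
        features = resnet(images)
    loss = F.cross_entropy(probe(features), labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
```
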
## BYOL → SimSiam

A new paper from Xinlei Chen and Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the `use_momentum` flag to `False`. If you go this route, you will no longer need to invoke `update_moving_average`, as shown in the example below.

```python
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    use_momentum = False # turn off momentum in the target encoder
)

opt = torch.optim.Adam(learner.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(resnet.state_dict(), './improved-net.pt')
```

## Advanced

While the hyperparameters have already been set to what the paper found optimal, you can change them with extra keyword arguments to the base wrapper class.

```python
learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool',
    projection_size = 256,         # the projection size
    projection_hidden_size = 4096, # the hidden dimension of the MLP for both the projection and prediction
    moving_average_decay = 0.99    # the moving average decay factor for the target encoder, already set to what the paper recommends
)
```

By default, this library will use the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the `augment_fn` keyword.

```python
import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn
)
```

In the paper, they ensure that one of the augmentations has a higher gaussian blur probability than the other. You can adjust this to your heart's delight.

```python
augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = -2,
    augment_fn = augment_fn,
    augment_fn2 = augment_fn2,
)
```

To fetch the embeddings or the projections, you simply have to pass in a `return_embedding = True` flag to the `BYOL` learner instance

```python
import torch
from byol_pytorch import BYOL
from torchvision import models

resnet = models.resnet50(pretrained=True)

learner = BYOL(
    resnet,
    image_size = 256,
    hidden_layer = 'avgpool'
)

imgs = torch.randn(2, 3, 256, 256)
projection, embedding = learner(imgs, return_embedding = True)
```
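
These pooled embeddings can be put to work directly, for example for a quick nearest-neighbour lookup. A minimal sketch, assuming `images` is the batch you want to index and `query` is a single image tensor of shape `(1, 3, 256, 256)`:

```python
import torch
import torch.nn.functional as F

learner.eval() # use batchnorm running statistics

with torch.no_grad():
    _, embeddings = learner(images, return_embedding = True)     # (batch, 2048) for resnet50
    _, query_embedding = learner(query, return_embedding = True) # (1, 2048)

# cosine similarity between the query and every indexed image
sims = F.cosine_similarity(query_embedding, embeddings, dim = -1)
nearest = sims.argmax().item()
```
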
## Distributed Training

The repository now offers distributed training with 🤗 Huggingface Accelerate. You just have to pass your own `Dataset` into the imported `BYOLTrainer`.

First, set up the configuration for distributed training by invoking the accelerate CLI

```bash
$ accelerate config
```

Then craft your training script as shown below, say in `./train.py`

```python
from torchvision import models

from byol_pytorch import (
    BYOL,
    BYOLTrainer,
    MockDataset
)

resnet = models.resnet50(pretrained = True)

dataset = MockDataset(256, 10000)

trainer = BYOLTrainer(
    resnet,
    dataset = dataset,
    image_size = 256,
    hidden_layer = 'avgpool',
    learning_rate = 3e-4,
    num_train_steps = 100_000,
    batch_size = 16,
    checkpoint_every = 1000 # improved model will be saved periodically to ./checkpoints folder
)

trainer()
```

Then use the accelerate CLI again to launch the script

```bash
$ accelerate launch ./train.py
```
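
`MockDataset` above just emits random image tensors so that the script runs end to end. For real training, swap in your own `Dataset` of images. A minimal sketch of an image-folder dataset, mirroring the one in `examples/lightning/train.py` (the folder path and extension filter here are illustrative):

```python
from pathlib import Path
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms as T

class ImagesDataset(Dataset):
    def __init__(self, folder, image_size):
        super().__init__()
        # collect every jpeg/png under the folder, recursively
        self.paths = [p for p in Path(folder).glob('**/*') if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]

        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert('RGB')
        return self.transform(img)

dataset = ImagesDataset('/path/to/your/images', 256)
```
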
## Alternatives

If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.

https://github.com/lucidrains/pixel-level-contrastive-learning

## Citation

```bibtex
@misc{grill2020bootstrap,
    title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
    author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
    year = {2020},
    eprint = {2006.07733},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

```bibtex
@misc{chen2020exploring,
    title = {Exploring Simple Siamese Representation Learning},
    author = {Xinlei Chen and Kaiming He},
    year = {2020},
    eprint = {2011.10566},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

--------------------------------------------------------------------------------
/byol_pytorch/__init__.py:
--------------------------------------------------------------------------------
from byol_pytorch.byol_pytorch import BYOL
from byol_pytorch.trainer import BYOLTrainer, MockDataset

--------------------------------------------------------------------------------
/byol_pytorch/byol_pytorch.py:
--------------------------------------------------------------------------------
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from torchvision import transforms as T

# helper functions

def default(val, def_val):
    return def_val if val is None else val

def flatten(t):
    return t.reshape(t.shape[0], -1)

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def get_module_device(module):
    return next(module.parameters()).device

def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

def MaybeSyncBatchnorm(is_distributed = None):
    is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
    return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d

# loss fn

def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
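
# note: since x and y are L2-normalized above, ||x - y||^2 = 2 - 2 * <x, y>,
# so this objective is the squared distance between the normalized vectors,
# i.e. 2 minus twice their cosine similarity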
# augmentation utils

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p
    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# exponential moving average

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)
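
# in other words, every target parameter is updated as
#   target = beta * target + (1 - beta) * online
# with beta = moving_average_decay (0.99 by default)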
# MLP class for projector and predictor

def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    return nn.Sequential(
        nn.Linear(dim, hidden_size),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size)
    )

def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
    return nn.Sequential(
        nn.Linear(dim, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, hidden_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size, bias=False),
        MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
    )

# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projector and predictor nets

class NetWrapper(nn.Module):
    def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
        super().__init__()
        self.net = net
        self.layer = layer

        self.projector = None
        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.use_simsiam_mlp = use_simsiam_mlp
        self.sync_batchnorm = sync_batchnorm

        self.hidden = {}
        self.hook_registered = False

    def _find_layer(self):
        if type(self.layer) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(self.layer, None)
        elif type(self.layer) == int:
            children = [*self.net.children()]
            return children[self.layer]
        return None

    def _hook(self, _, input, output):
        device = input[0].device
        self.hidden[device] = flatten(output)

    def _register_hook(self):
        layer = self._find_layer()
        assert layer is not None, f'hidden layer ({self.layer}) not found'
        handle = layer.register_forward_hook(self._hook)
        self.hook_registered = True

    @singleton('projector')
    def _get_projector(self, hidden):
        _, dim = hidden.shape
        create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
        projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
        return projector.to(hidden)

    def get_representation(self, x):
        if self.layer == -1:
            return self.net(x)

        if not self.hook_registered:
            self._register_hook()

        self.hidden.clear()
        _ = self.net(x)
        hidden = self.hidden[x.device]
        self.hidden.clear()

        assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
        return hidden

    def forward(self, x, return_projection = True):
        representation = self.get_representation(x)

        if not return_projection:
            return representation

        projector = self._get_projector(representation)
        projection = projector(representation)
        return projection, representation

# main class

class BYOL(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        hidden_layer = -2,
        projection_size = 256,
        projection_hidden_size = 4096,
        augment_fn = None,
        augment_fn2 = None,
        moving_average_decay = 0.99,
        use_momentum = True,
        sync_batchnorm = None
    ):
        super().__init__()
        self.net = net

        # default SimCLR augmentation

        DEFAULT_AUG = torch.nn.Sequential(
            RandomApply(
                T.ColorJitter(0.8, 0.8, 0.8, 0.2),
                p = 0.3
            ),
            T.RandomGrayscale(p=0.2),
            T.RandomHorizontalFlip(),
            RandomApply(
                T.GaussianBlur((3, 3), (1.0, 2.0)),
                p = 0.2
            ),
            T.RandomResizedCrop((image_size, image_size)),
            T.Normalize(
                mean=torch.tensor([0.485, 0.456, 0.406]),
                std=torch.tensor([0.229, 0.224, 0.225])),
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)

        self.online_encoder = NetWrapper(
            net,
            projection_size,
            projection_hidden_size,
            layer = hidden_layer,
            use_simsiam_mlp = not use_momentum,
            sync_batchnorm = sync_batchnorm
        )

        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(
        self,
        x,
        return_embedding = False,
        return_projection = True
    ):
        assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'

        if return_embedding:
            return self.online_encoder(x, return_projection = return_projection)

        image_one, image_two = self.augment1(x), self.augment2(x)

        images = torch.cat((image_one, image_two), dim = 0)

        online_projections, _ = self.online_encoder(images)
        online_predictions = self.online_predictor(online_projections)

        online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)

        with torch.no_grad():
            target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder

            target_projections, _ = target_encoder(images)
            target_projections = target_projections.detach()

        target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)
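
        # symmetric loss: the online prediction of each view regresses
        # the target projection of the *other* view, and vice versa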
        loss_one = loss_fn(online_pred_one, target_proj_two.detach())
        loss_two = loss_fn(online_pred_two, target_proj_one.detach())

        loss = loss_one + loss_two
        return loss.mean()

--------------------------------------------------------------------------------
/byol_pytorch/trainer.py:
--------------------------------------------------------------------------------
from pathlib import Path

import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn import SyncBatchNorm

from torch.optim import Optimizer, Adam
from torch.utils.data import Dataset, DataLoader

from byol_pytorch.byol_pytorch import BYOL

from beartype import beartype
from beartype.typing import Optional

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# constants

DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
    find_unused_parameters = True
)

# functions

def exists(v):
    return v is not None

def cycle(dl):
    while True:
        for batch in dl:
            yield batch

# class
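
# mock dataset of random image tensors, as used in the README example;
# substitute your own Dataset of real images for actual training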
class MockDataset(Dataset):
    def __init__(self, image_size, length):
        self.length = length
        self.image_size = image_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return torch.randn(3, self.image_size, self.image_size)

# main trainer

class BYOLTrainer(Module):
    @beartype
    def __init__(
        self,
        net: Module,
        *,
        image_size: int,
        hidden_layer: str,
        learning_rate: float,
        dataset: Dataset,
        num_train_steps: int,
        batch_size: int = 16,
        optimizer_klass = Adam,
        checkpoint_every: int = 1000,
        checkpoint_folder: str = './checkpoints',
        byol_kwargs: dict = dict(),
        optimizer_kwargs: dict = dict(),
        accelerator_kwargs: dict = dict(),
    ):
        super().__init__()

        if 'kwargs_handlers' not in accelerator_kwargs:
            accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]

        self.accelerator = Accelerator(**accelerator_kwargs)

        if dist.is_initialized() and dist.get_world_size() > 1:
            net = SyncBatchNorm.convert_sync_batchnorm(net)

        self.net = net

        self.byol = BYOL(net, image_size = image_size, hidden_layer = hidden_layer, **byol_kwargs)

        self.optimizer = optimizer_klass(self.byol.parameters(), lr = learning_rate, **optimizer_kwargs)

        self.dataloader = DataLoader(dataset, shuffle = True, batch_size = batch_size)

        self.num_train_steps = num_train_steps

        self.checkpoint_every = checkpoint_every
        self.checkpoint_folder = Path(checkpoint_folder)
        self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
        assert self.checkpoint_folder.is_dir()

        # prepare with accelerate

        (
            self.byol,
            self.optimizer,
            self.dataloader
        ) = self.accelerator.prepare(
            self.byol,
            self.optimizer,
            self.dataloader
        )

        self.register_buffer('step', torch.tensor(0))

    def wait(self):
        return self.accelerator.wait_for_everyone()

    def print(self, msg):
        return self.accelerator.print(msg)

    def forward(self):
        step = self.step.item()
        data_it = cycle(self.dataloader)

        for _ in range(self.num_train_steps):
            images = next(data_it)

            with self.accelerator.autocast():
                loss = self.byol(images)
                self.accelerator.backward(loss)

            self.print(f'loss {loss.item():.3f}')

            self.optimizer.step()
            self.optimizer.zero_grad()

            self.wait()

            self.byol.update_moving_average()

            self.wait()

            if not (step % self.checkpoint_every) and self.accelerator.is_main_process:
                checkpoint_num = step // self.checkpoint_every
                checkpoint_path = self.checkpoint_folder / f'checkpoint.{checkpoint_num}.pt'
                torch.save(self.net.state_dict(), str(checkpoint_path))

            self.wait()

            step += 1

        self.print('training complete')

--------------------------------------------------------------------------------
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/byol-pytorch/0c3ab5409181852f8495ef924dce9186f94d9126/diagram.png

--------------------------------------------------------------------------------
/examples/lightning/README.md:
--------------------------------------------------------------------------------
## Pytorch-lightning example script

### Requirements

```bash
$ pip install pytorch-lightning
$ pip install pillow
```

### Run

```bash
$ python train.py --image_folder /path/to/your/images
```

--------------------------------------------------------------------------------
/examples/lightning/train.py:
--------------------------------------------------------------------------------
import os
import argparse
import multiprocessing
from pathlib import Path
from PIL import Image

import torch
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from byol_pytorch import BYOL
import pytorch_lightning as pl

# test model, a resnet 50

resnet = models.resnet50(pretrained=True)

# arguments

parser = argparse.ArgumentParser(description='byol-lightning-test')

parser.add_argument('--image_folder', type=str, required = True,
                    help='path to your folder of images for self-supervised learning')

args = parser.parse_args()

# constants

BATCH_SIZE = 32
EPOCHS = 1000
LR = 3e-4
NUM_GPUS = 2
IMAGE_SIZE = 256
IMAGE_EXTS = ['.jpg', '.png', '.jpeg']
NUM_WORKERS = multiprocessing.cpu_count()

# pytorch lightning module

class SelfSupervisedLearner(pl.LightningModule):
    def __init__(self, net, **kwargs):
        super().__init__()
        self.learner = BYOL(net, **kwargs)

    def forward(self, images):
        return self.learner(images)

    def training_step(self, images, _):
        loss = self.forward(images)
        return {'loss': loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=LR)

    def on_before_zero_grad(self, _):
        if self.learner.use_momentum:
            self.learner.update_moving_average()

# images dataset
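
# torchvision's ToTensor yields a 1-channel tensor for greyscale input;
# expanding to 3 channels lets the resnet accept it (a no-op for RGB)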
def expand_greyscale(t):
    return t.expand(3, -1, -1)

class ImagesDataset(Dataset):
    def __init__(self, folder, image_size):
        super().__init__()
        self.folder = folder
        self.paths = []

        for path in Path(f'{folder}').glob('**/*'):
            _, ext = os.path.splitext(path)
            if ext.lower() in IMAGE_EXTS:
                self.paths.append(path)

        print(f'{len(self.paths)} images found')

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Lambda(expand_greyscale)
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(path)
        img = img.convert('RGB')
        return self.transform(img)

# main

if __name__ == '__main__':
    ds = ImagesDataset(args.image_folder, IMAGE_SIZE)
    train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)

    model = SelfSupervisedLearner(
        resnet,
        image_size = IMAGE_SIZE,
        hidden_layer = 'avgpool',
        projection_size = 256,
        projection_hidden_size = 4096,
        moving_average_decay = 0.99
    )

    trainer = pl.Trainer(
        gpus = NUM_GPUS,
        max_epochs = EPOCHS,
        accumulate_grad_batches = 1,
        sync_batchnorm = True
    )

    trainer.fit(model, train_loader)

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'byol-pytorch',
  packages = find_packages(exclude=['examples']),
  version = '0.8.2',
  license='MIT',
  description = 'Self-supervised contrastive learning made simple',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/byol-pytorch',
  long_description_content_type = 'text/markdown',
  keywords = [
    'self-supervised learning',
    'artificial intelligence'
  ],
  install_requires=[
    'accelerate',
    'beartype',
    'torch>=1.6',
    'torchvision>=0.8'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------