├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── byol_pytorch
│   ├── __init__.py
│   ├── byol_pytorch.py
│   └── trainer.py
├── diagram.png
├── examples
│   └── lightning
│       ├── README.md
│       └── train.py
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | <img src="./diagram.png"></img>
2 |
3 | ## Bootstrap Your Own Latent (BYOL), in Pytorch
4 |
5 | [![PyPI version](https://badge.fury.io/py/byol-pytorch.svg)](https://badge.fury.io/py/byol-pytorch)
6 |
7 | Practical implementation of an astoundingly simple method for self-supervised learning that achieves a new state of the art (surpassing SimCLR) without contrastive learning or having to designate negative pairs.
8 |
9 | This repository offers a module with which one can wrap any image-based neural network (residual network, discriminator, policy network) to immediately start benefiting from unlabelled image data.
10 |
11 | Update 1: There is now new evidence that batch normalization is key to making this technique work well.
12 |
13 | Update 2: A new paper has successfully replaced batch norm with group norm + weight standardization, refuting the claim that batch statistics are needed for BYOL to work.
14 |
15 | Update 3: Finally, we have some analysis of why this works.
16 |
17 | For a video walkthrough, see Yannic Kilcher's excellent explanation.
18 |
19 | Now go save your organization from having to pay for labels :)
20 |
21 | ## Install
22 |
23 | ```bash
24 | $ pip install byol-pytorch
25 | ```
26 |
27 | ## Usage
28 |
29 | Simply plug in your neural network, specifying (1) the image dimensions and (2) the name (or index) of the hidden layer whose output is used as the latent representation for self-supervised training.
30 |
31 | ```python
32 | import torch
33 | from byol_pytorch import BYOL
34 | from torchvision import models
35 |
36 | resnet = models.resnet50(pretrained=True)
37 |
38 | learner = BYOL(
39 | resnet,
40 | image_size = 256,
41 | hidden_layer = 'avgpool'
42 | )
43 |
44 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
45 |
46 | def sample_unlabelled_images():
47 | return torch.randn(20, 3, 256, 256)
48 |
49 | for _ in range(100):
50 | images = sample_unlabelled_images()
51 | loss = learner(images)
52 | opt.zero_grad()
53 | loss.backward()
54 | opt.step()
55 | learner.update_moving_average() # update moving average of target encoder
56 |
57 | # save your improved network
58 | torch.save(resnet.state_dict(), './improved-net.pt')
59 | ```
60 |
61 | That's pretty much it. After much training, the residual network should now perform better on its downstream supervised tasks.
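
For instance, a minimal linear-probe sketch for downstream use might look like the following (`NUM_CLASSES` and the classification task are hypothetical; only the saved `./improved-net.pt` comes from the example above):

```python
import torch
from torch import nn
from torchvision import models

NUM_CLASSES = 10  # hypothetical downstream classification task

# load the self-supervised weights saved above
resnet = models.resnet50()
resnet.load_state_dict(torch.load('./improved-net.pt'))

# freeze the backbone and train only a fresh linear head on labelled data
for p in resnet.parameters():
    p.requires_grad = False

resnet.fc = nn.Linear(resnet.fc.in_features, NUM_CLASSES)
```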
62 |
63 | ## BYOL → SimSiam
64 |
65 | A new paper from Xinlei Chen and Kaiming He suggests that BYOL does not even need the target encoder to be an exponential moving average of the online encoder. I've decided to build in this option so that you can easily use that variant for training, simply by setting the `use_momentum` flag to `False`. You will no longer need to invoke `update_moving_average` if you go this route, as shown in the example below.
66 |
67 | ```python
68 | import torch
69 | from byol_pytorch import BYOL
70 | from torchvision import models
71 |
72 | resnet = models.resnet50(pretrained=True)
73 |
74 | learner = BYOL(
75 | resnet,
76 | image_size = 256,
77 | hidden_layer = 'avgpool',
78 | use_momentum = False # turn off momentum in the target encoder
79 | )
80 |
81 | opt = torch.optim.Adam(learner.parameters(), lr=3e-4)
82 |
83 | def sample_unlabelled_images():
84 | return torch.randn(20, 3, 256, 256)
85 |
86 | for _ in range(100):
87 | images = sample_unlabelled_images()
88 | loss = learner(images)
89 | opt.zero_grad()
90 | loss.backward()
91 | opt.step()
92 |
93 | # save your improved network
94 | torch.save(resnet.state_dict(), './improved-net.pt')
95 | ```
96 |
97 | ## Advanced
98 |
99 | While the hyperparameters are already set to the values the paper found optimal, you can change them with extra keyword arguments to the base wrapper class.
100 |
101 | ```python
102 | learner = BYOL(
103 | resnet,
104 | image_size = 256,
105 | hidden_layer = 'avgpool',
106 | projection_size = 256, # the projection size
107 | projection_hidden_size = 4096, # the hidden dimension of the MLP for both the projection and prediction
108 |     moving_average_decay = 0.99       # the moving average decay factor for the target encoder, already set to what the paper recommends
109 | )
110 | ```
111 |
112 | By default, this library will use the augmentations from the SimCLR paper (which are also used in the BYOL paper). However, if you would like to specify your own augmentation pipeline, you can simply pass in your own custom augmentation function with the `augment_fn` keyword.
113 |
114 | ```python
import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)
118 |
119 | learner = BYOL(
120 | resnet,
121 | image_size = 256,
122 | hidden_layer = -2,
123 | augment_fn = augment_fn
124 | )
125 | ```
126 |
127 | In the paper, they ensure that one of the augmentations has a higher Gaussian blur probability than the other. You can adjust this to your heart's delight.
128 |
129 | ```python
import kornia
from torch import nn

augment_fn = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip()
)

augment_fn2 = nn.Sequential(
    kornia.augmentation.RandomHorizontalFlip(),
    kornia.filters.GaussianBlur2d((3, 3), (1.5, 1.5))
)
138 |
139 | learner = BYOL(
140 | resnet,
141 | image_size = 256,
142 | hidden_layer = -2,
143 | augment_fn = augment_fn,
144 | augment_fn2 = augment_fn2,
145 | )
146 | ```
147 |
148 | To fetch the embeddings or the projections, you simply have to pass a `return_embedding = True` flag to the `BYOL` learner instance.
149 |
150 | ```python
151 | import torch
152 | from byol_pytorch import BYOL
153 | from torchvision import models
154 |
155 | resnet = models.resnet50(pretrained=True)
156 |
157 | learner = BYOL(
158 | resnet,
159 | image_size = 256,
160 | hidden_layer = 'avgpool'
161 | )
162 |
163 | imgs = torch.randn(2, 3, 256, 256)
164 | projection, embedding = learner(imgs, return_embedding = True)
165 | ```
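
If you want only the embedding without the projection, additionally pass `return_projection = False`:

```python
embedding = learner(imgs, return_embedding = True, return_projection = False)
```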
166 |
167 | ## Distributed Training
168 |
169 | The repository now offers distributed training with 🤗 Huggingface Accelerate. You just have to pass your own `Dataset` to the imported `BYOLTrainer`.
170 |
171 | First, set up the configuration for distributed training by invoking the accelerate CLI.
172 |
173 | ```bash
174 | $ accelerate config
175 | ```
176 |
177 | Then craft your training script as shown below (say, in `./train.py`).
178 |
179 | ```python
180 | from torchvision import models
181 |
182 | from byol_pytorch import (
183 | BYOL,
184 | BYOLTrainer,
185 | MockDataset
186 | )
187 |
188 | resnet = models.resnet50(pretrained = True)
189 |
190 | dataset = MockDataset(256, 10000)
191 |
192 | trainer = BYOLTrainer(
193 | resnet,
194 | dataset = dataset,
195 | image_size = 256,
196 | hidden_layer = 'avgpool',
197 | learning_rate = 3e-4,
198 | num_train_steps = 100_000,
199 | batch_size = 16,
200 | checkpoint_every = 1000 # improved model will be saved periodically to ./checkpoints folder
201 | )
202 |
203 | trainer()
204 | ```
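
`MockDataset` simply generates random images of the given size and count. In practice, you would pass in your own image `Dataset`; a minimal folder-based sketch (along the lines of `ImagesDataset` in `examples/lightning/train.py`) could look like this:

```python
from pathlib import Path
from PIL import Image

from torch.utils.data import Dataset
from torchvision import transforms as T

class ImagesDataset(Dataset):
    # gathers every jpg / jpeg / png under a folder,
    # resized and center-cropped to image_size
    def __init__(self, folder, image_size):
        self.paths = [p for p in Path(folder).glob('**/*') if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]

        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert('RGB')
        return self.transform(img)
```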
205 |
206 | Then use the accelerate CLI again to launch the script.
207 |
208 | ```bash
209 | $ accelerate launch ./train.py
210 | ```
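
To override parts of the saved configuration for a one-off run, `accelerate launch` also accepts flags such as `--num_processes`:

```bash
$ accelerate launch --num_processes 2 ./train.py
```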
211 |
212 | ## Alternatives
213 |
214 | If your downstream task involves segmentation, please look at the following repository, which extends BYOL to 'pixel'-level learning.
215 |
216 | https://github.com/lucidrains/pixel-level-contrastive-learning
217 |
218 | ## Citation
219 |
220 | ```bibtex
221 | @misc{grill2020bootstrap,
222 | title = {Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning},
223 | author = {Jean-Bastien Grill and Florian Strub and Florent Altché and Corentin Tallec and Pierre H. Richemond and Elena Buchatskaya and Carl Doersch and Bernardo Avila Pires and Zhaohan Daniel Guo and Mohammad Gheshlaghi Azar and Bilal Piot and Koray Kavukcuoglu and Rémi Munos and Michal Valko},
224 | year = {2020},
225 | eprint = {2006.07733},
226 | archivePrefix = {arXiv},
227 | primaryClass = {cs.LG}
228 | }
229 | ```
230 |
231 | ```bibtex
232 | @misc{chen2020exploring,
233 | title={Exploring Simple Siamese Representation Learning},
234 | author={Xinlei Chen and Kaiming He},
235 | year={2020},
236 | eprint={2011.10566},
237 | archivePrefix={arXiv},
238 | primaryClass={cs.CV}
239 | }
240 | ```
241 |
--------------------------------------------------------------------------------
/byol_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from byol_pytorch.byol_pytorch import BYOL
2 | from byol_pytorch.trainer import BYOLTrainer, MockDataset
3 |
--------------------------------------------------------------------------------
/byol_pytorch/byol_pytorch.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | from functools import wraps
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | import torch.distributed as dist
9 |
10 | from torchvision import transforms as T
11 |
12 | # helper functions
13 |
14 | def default(val, def_val):
15 | return def_val if val is None else val
16 |
17 | def flatten(t):
18 | return t.reshape(t.shape[0], -1)
19 |
20 | def singleton(cache_key):
21 | def inner_fn(fn):
22 | @wraps(fn)
23 | def wrapper(self, *args, **kwargs):
24 | instance = getattr(self, cache_key)
25 | if instance is not None:
26 | return instance
27 |
28 | instance = fn(self, *args, **kwargs)
29 | setattr(self, cache_key, instance)
30 | return instance
31 | return wrapper
32 | return inner_fn
33 |
34 | def get_module_device(module):
35 | return next(module.parameters()).device
36 |
37 | def set_requires_grad(model, val):
38 | for p in model.parameters():
39 | p.requires_grad = val
40 |
41 | def MaybeSyncBatchnorm(is_distributed = None):
42 | is_distributed = default(is_distributed, dist.is_initialized() and dist.get_world_size() > 1)
43 | return nn.SyncBatchNorm if is_distributed else nn.BatchNorm1d
44 |
45 | # loss fn
46 |
def loss_fn(x, y):
    # squared euclidean distance between l2-normalized vectors,
    # equivalent to 2 - 2 * cosine_similarity(x, y)
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
51 |
52 | # augmentation utils
53 |
54 | class RandomApply(nn.Module):
55 | def __init__(self, fn, p):
56 | super().__init__()
57 | self.fn = fn
58 | self.p = p
59 | def forward(self, x):
60 | if random.random() > self.p:
61 | return x
62 | return self.fn(x)
63 |
64 | # exponential moving average
65 |
66 | class EMA():
67 | def __init__(self, beta):
68 | super().__init__()
69 | self.beta = beta
70 |
71 | def update_average(self, old, new):
72 | if old is None:
73 | return new
74 | return old * self.beta + (1 - self.beta) * new
75 |
76 | def update_moving_average(ema_updater, ma_model, current_model):
77 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
78 | old_weight, up_weight = ma_params.data, current_params.data
79 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
80 |
81 | # MLP class for projector and predictor
82 |
83 | def MLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
84 | return nn.Sequential(
85 | nn.Linear(dim, hidden_size),
86 | MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
87 | nn.ReLU(inplace=True),
88 | nn.Linear(hidden_size, projection_size)
89 | )
90 |
91 | def SimSiamMLP(dim, projection_size, hidden_size=4096, sync_batchnorm=None):
92 | return nn.Sequential(
93 | nn.Linear(dim, hidden_size, bias=False),
94 | MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
95 | nn.ReLU(inplace=True),
96 | nn.Linear(hidden_size, hidden_size, bias=False),
97 | MaybeSyncBatchnorm(sync_batchnorm)(hidden_size),
98 | nn.ReLU(inplace=True),
99 | nn.Linear(hidden_size, projection_size, bias=False),
100 | MaybeSyncBatchnorm(sync_batchnorm)(projection_size, affine=False)
101 | )
102 |
103 | # a wrapper class for the base neural network
104 | # will manage the interception of the hidden layer output
105 | # and pipe it into the projector and predictor nets
106 |
107 | class NetWrapper(nn.Module):
108 | def __init__(self, net, projection_size, projection_hidden_size, layer = -2, use_simsiam_mlp = False, sync_batchnorm = None):
109 | super().__init__()
110 | self.net = net
111 | self.layer = layer
112 |
113 | self.projector = None
114 | self.projection_size = projection_size
115 | self.projection_hidden_size = projection_hidden_size
116 |
117 | self.use_simsiam_mlp = use_simsiam_mlp
118 | self.sync_batchnorm = sync_batchnorm
119 |
120 | self.hidden = {}
121 | self.hook_registered = False
122 |
123 | def _find_layer(self):
124 | if type(self.layer) == str:
125 | modules = dict([*self.net.named_modules()])
126 | return modules.get(self.layer, None)
127 | elif type(self.layer) == int:
128 | children = [*self.net.children()]
129 | return children[self.layer]
130 | return None
131 |
    def _hook(self, _, input, output):
        # cache the flattened hidden layer output, keyed by device,
        # so data-parallel replicas do not overwrite each other
        device = input[0].device
        self.hidden[device] = flatten(output)
135 |
136 | def _register_hook(self):
137 | layer = self._find_layer()
138 | assert layer is not None, f'hidden layer ({self.layer}) not found'
139 |         layer.register_forward_hook(self._hook)
140 | self.hook_registered = True
141 |
142 | @singleton('projector')
143 | def _get_projector(self, hidden):
144 | _, dim = hidden.shape
145 | create_mlp_fn = MLP if not self.use_simsiam_mlp else SimSiamMLP
146 | projector = create_mlp_fn(dim, self.projection_size, self.projection_hidden_size, sync_batchnorm = self.sync_batchnorm)
147 | return projector.to(hidden)
148 |
149 | def get_representation(self, x):
150 | if self.layer == -1:
151 | return self.net(x)
152 |
153 | if not self.hook_registered:
154 | self._register_hook()
155 |
156 | self.hidden.clear()
157 | _ = self.net(x)
158 |         hidden = self.hidden.get(x.device)
159 | self.hidden.clear()
160 |
161 | assert hidden is not None, f'hidden layer {self.layer} never emitted an output'
162 | return hidden
163 |
164 | def forward(self, x, return_projection = True):
165 | representation = self.get_representation(x)
166 |
167 | if not return_projection:
168 | return representation
169 |
170 | projector = self._get_projector(representation)
171 | projection = projector(representation)
172 | return projection, representation
173 |
174 | # main class
175 |
176 | class BYOL(nn.Module):
177 | def __init__(
178 | self,
179 | net,
180 | image_size,
181 | hidden_layer = -2,
182 | projection_size = 256,
183 | projection_hidden_size = 4096,
184 | augment_fn = None,
185 | augment_fn2 = None,
186 | moving_average_decay = 0.99,
187 | use_momentum = True,
188 | sync_batchnorm = None
189 | ):
190 | super().__init__()
191 | self.net = net
192 |
193 | # default SimCLR augmentation
194 |
195 | DEFAULT_AUG = torch.nn.Sequential(
196 | RandomApply(
197 | T.ColorJitter(0.8, 0.8, 0.8, 0.2),
198 | p = 0.3
199 | ),
200 | T.RandomGrayscale(p=0.2),
201 | T.RandomHorizontalFlip(),
202 | RandomApply(
203 | T.GaussianBlur((3, 3), (1.0, 2.0)),
204 | p = 0.2
205 | ),
206 | T.RandomResizedCrop((image_size, image_size)),
207 | T.Normalize(
208 | mean=torch.tensor([0.485, 0.456, 0.406]),
209 | std=torch.tensor([0.229, 0.224, 0.225])),
210 | )
211 |
212 | self.augment1 = default(augment_fn, DEFAULT_AUG)
213 | self.augment2 = default(augment_fn2, self.augment1)
214 |
215 | self.online_encoder = NetWrapper(
216 | net,
217 | projection_size,
218 | projection_hidden_size,
219 | layer = hidden_layer,
220 | use_simsiam_mlp = not use_momentum,
221 | sync_batchnorm = sync_batchnorm
222 | )
223 |
224 | self.use_momentum = use_momentum
225 | self.target_encoder = None
226 | self.target_ema_updater = EMA(moving_average_decay)
227 |
228 | self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
229 |
230 | # get device of network and make wrapper same device
231 | device = get_module_device(net)
232 | self.to(device)
233 |
234 | # send a mock image tensor to instantiate singleton parameters
235 | self.forward(torch.randn(2, 3, image_size, image_size, device=device))
236 |
237 | @singleton('target_encoder')
238 | def _get_target_encoder(self):
239 | target_encoder = copy.deepcopy(self.online_encoder)
240 | set_requires_grad(target_encoder, False)
241 | return target_encoder
242 |
243 | def reset_moving_average(self):
244 | del self.target_encoder
245 | self.target_encoder = None
246 |
247 | def update_moving_average(self):
248 | assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
249 | assert self.target_encoder is not None, 'target encoder has not been created yet'
250 | update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
251 |
252 | def forward(
253 | self,
254 | x,
255 | return_embedding = False,
256 | return_projection = True
257 | ):
258 | assert not (self.training and x.shape[0] == 1), 'you must have greater than 1 sample when training, due to the batchnorm in the projection layer'
259 |
260 | if return_embedding:
261 | return self.online_encoder(x, return_projection = return_projection)
262 |
263 | image_one, image_two = self.augment1(x), self.augment2(x)
264 |
265 | images = torch.cat((image_one, image_two), dim = 0)
266 |
267 | online_projections, _ = self.online_encoder(images)
268 | online_predictions = self.online_predictor(online_projections)
269 |
270 | online_pred_one, online_pred_two = online_predictions.chunk(2, dim = 0)
271 |
272 | with torch.no_grad():
273 | target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
274 |
275 | target_projections, _ = target_encoder(images)
276 | target_projections = target_projections.detach()
277 |
278 | target_proj_one, target_proj_two = target_projections.chunk(2, dim = 0)
279 |
280 |         loss_one = loss_fn(online_pred_one, target_proj_two)
281 |         loss_two = loss_fn(online_pred_two, target_proj_one)
282 |
283 | loss = loss_one + loss_two
284 | return loss.mean()
285 |
--------------------------------------------------------------------------------
/byol_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import torch
4 | import torch.distributed as dist
5 | from torch.nn import Module
6 | from torch.nn import SyncBatchNorm
7 |
8 | from torch.optim import Optimizer, Adam
9 | from torch.utils.data import Dataset, DataLoader
10 |
11 | from byol_pytorch.byol_pytorch import BYOL
12 |
13 | from beartype import beartype
14 | from beartype.typing import Optional
15 |
16 | from accelerate import Accelerator
17 | from accelerate.utils import DistributedDataParallelKwargs
18 |
19 | # constants
20 |
21 | DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
22 | find_unused_parameters = True
23 | )
24 |
25 | # functions
26 |
27 | def exists(v):
28 | return v is not None
29 |
30 | def cycle(dl):
31 | while True:
32 | for batch in dl:
33 | yield batch
34 |
35 | # class
36 |
37 | class MockDataset(Dataset):
38 | def __init__(self, image_size, length):
39 | self.length = length
40 | self.image_size = image_size
41 |
42 | def __len__(self):
43 | return self.length
44 |
45 | def __getitem__(self, idx):
46 | return torch.randn(3, self.image_size, self.image_size)
47 |
48 | # main trainer
49 |
50 | class BYOLTrainer(Module):
51 | @beartype
52 | def __init__(
53 | self,
54 | net: Module,
55 | *,
56 | image_size: int,
57 | hidden_layer: str,
58 | learning_rate: float,
59 | dataset: Dataset,
60 | num_train_steps: int,
61 | batch_size: int = 16,
62 | optimizer_klass = Adam,
63 | checkpoint_every: int = 1000,
64 | checkpoint_folder: str = './checkpoints',
65 | byol_kwargs: dict = dict(),
66 | optimizer_kwargs: dict = dict(),
67 | accelerator_kwargs: dict = dict(),
68 | ):
69 | super().__init__()
70 |
71 | if 'kwargs_handlers' not in accelerator_kwargs:
72 | accelerator_kwargs['kwargs_handlers'] = [DEFAULT_DDP_KWARGS]
73 |
74 | self.accelerator = Accelerator(**accelerator_kwargs)
75 |
76 | if dist.is_initialized() and dist.get_world_size() > 1:
77 | net = SyncBatchNorm.convert_sync_batchnorm(net)
78 |
79 | self.net = net
80 |
81 | self.byol = BYOL(net, image_size = image_size, hidden_layer = hidden_layer, **byol_kwargs)
82 |
83 | self.optimizer = optimizer_klass(self.byol.parameters(), lr = learning_rate, **optimizer_kwargs)
84 |
85 | self.dataloader = DataLoader(dataset, shuffle = True, batch_size = batch_size)
86 |
87 | self.num_train_steps = num_train_steps
88 |
89 | self.checkpoint_every = checkpoint_every
90 | self.checkpoint_folder = Path(checkpoint_folder)
91 | self.checkpoint_folder.mkdir(exist_ok = True, parents = True)
92 | assert self.checkpoint_folder.is_dir()
93 |
94 | # prepare with accelerate
95 |
96 | (
97 | self.byol,
98 | self.optimizer,
99 | self.dataloader
100 | ) = self.accelerator.prepare(
101 | self.byol,
102 | self.optimizer,
103 | self.dataloader
104 | )
105 |
106 | self.register_buffer('step', torch.tensor(0))
107 |
108 | def wait(self):
109 | return self.accelerator.wait_for_everyone()
110 |
111 | def print(self, msg):
112 | return self.accelerator.print(msg)
113 |
114 | def forward(self):
115 | step = self.step.item()
116 | data_it = cycle(self.dataloader)
117 |
118 | for _ in range(self.num_train_steps):
119 | images = next(data_it)
120 |
121 | with self.accelerator.autocast():
122 | loss = self.byol(images)
123 | self.accelerator.backward(loss)
124 |
125 | self.print(f'loss {loss.item():.3f}')
126 |
127 | self.optimizer.step()
128 | self.optimizer.zero_grad()
129 |
130 | self.wait()
131 |
132 | self.byol.update_moving_average()
133 |
134 | self.wait()
135 |
136 | if not (step % self.checkpoint_every) and self.accelerator.is_main_process:
137 | checkpoint_num = step // self.checkpoint_every
138 | checkpoint_path = self.checkpoint_folder / f'checkpoint.{checkpoint_num}.pt'
139 | torch.save(self.net.state_dict(), str(checkpoint_path))
140 |
141 | self.wait()
142 |
143 | step += 1
144 |
145 | self.print('training complete')
146 |
--------------------------------------------------------------------------------
/diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/byol-pytorch/0c3ab5409181852f8495ef924dce9186f94d9126/diagram.png
--------------------------------------------------------------------------------
/examples/lightning/README.md:
--------------------------------------------------------------------------------
1 | ## Pytorch-lightning example script
2 |
3 | ### Requirements
4 |
5 | ```bash
6 | $ pip install pytorch-lightning
7 | $ pip install pillow
8 | ```
9 |
10 | ### Run
11 |
12 | ```bash
13 | $ python train.py --image_folder /path/to/your/images
14 | ```
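
Hyperparameters such as `BATCH_SIZE`, `EPOCHS`, `LR`, `NUM_GPUS`, and `IMAGE_SIZE` are constants defined near the top of `train.py`; edit them there to match your hardware.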
15 |
--------------------------------------------------------------------------------
/examples/lightning/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import multiprocessing
4 | from pathlib import Path
5 | from PIL import Image
6 |
7 | import torch
8 | from torchvision import models, transforms
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from byol_pytorch import BYOL
12 | import pytorch_lightning as pl
13 |
14 | # test model, a resnet 50
15 |
16 | resnet = models.resnet50(pretrained=True)
17 |
18 | # arguments
19 |
20 | parser = argparse.ArgumentParser(description='byol-lightning-test')
21 |
22 | parser.add_argument('--image_folder', type=str, required = True,
23 | help='path to your folder of images for self-supervised learning')
24 |
25 | args = parser.parse_args()
26 |
27 | # constants
28 |
29 | BATCH_SIZE = 32
30 | EPOCHS = 1000
31 | LR = 3e-4
32 | NUM_GPUS = 2
33 | IMAGE_SIZE = 256
34 | IMAGE_EXTS = ['.jpg', '.png', '.jpeg']
35 | NUM_WORKERS = multiprocessing.cpu_count()
36 |
37 | # pytorch lightning module
38 |
39 | class SelfSupervisedLearner(pl.LightningModule):
40 | def __init__(self, net, **kwargs):
41 | super().__init__()
42 | self.learner = BYOL(net, **kwargs)
43 |
44 | def forward(self, images):
45 | return self.learner(images)
46 |
47 | def training_step(self, images, _):
48 | loss = self.forward(images)
49 | return {'loss': loss}
50 |
51 | def configure_optimizers(self):
52 | return torch.optim.Adam(self.parameters(), lr=LR)
53 |
54 | def on_before_zero_grad(self, _):
55 | if self.learner.use_momentum:
56 | self.learner.update_moving_average()
57 |
58 | # images dataset
59 |
60 | def expand_greyscale(t):
61 | return t.expand(3, -1, -1)
62 |
63 | class ImagesDataset(Dataset):
64 | def __init__(self, folder, image_size):
65 | super().__init__()
66 | self.folder = folder
67 | self.paths = []
68 |
69 | for path in Path(f'{folder}').glob('**/*'):
70 | _, ext = os.path.splitext(path)
71 | if ext.lower() in IMAGE_EXTS:
72 | self.paths.append(path)
73 |
74 | print(f'{len(self.paths)} images found')
75 |
76 | self.transform = transforms.Compose([
77 | transforms.Resize(image_size),
78 | transforms.CenterCrop(image_size),
79 | transforms.ToTensor(),
80 | transforms.Lambda(expand_greyscale)
81 | ])
82 |
83 | def __len__(self):
84 | return len(self.paths)
85 |
86 | def __getitem__(self, index):
87 | path = self.paths[index]
88 | img = Image.open(path)
89 | img = img.convert('RGB')
90 | return self.transform(img)
91 |
92 | # main
93 |
94 | if __name__ == '__main__':
95 | ds = ImagesDataset(args.image_folder, IMAGE_SIZE)
96 | train_loader = DataLoader(ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=True)
97 |
98 | model = SelfSupervisedLearner(
99 | resnet,
100 | image_size = IMAGE_SIZE,
101 | hidden_layer = 'avgpool',
102 | projection_size = 256,
103 | projection_hidden_size = 4096,
104 | moving_average_decay = 0.99
105 | )
106 |
107 | trainer = pl.Trainer(
108 | gpus = NUM_GPUS,
109 | max_epochs = EPOCHS,
110 | accumulate_grad_batches = 1,
111 | sync_batchnorm = True
112 | )
113 |
114 | trainer.fit(model, train_loader)
115 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'byol-pytorch',
5 | packages = find_packages(exclude=['examples']),
6 | version = '0.8.2',
7 | license='MIT',
8 |   description = 'Self-supervised learning made simple',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/byol-pytorch',
12 | long_description_content_type = 'text/markdown',
13 | keywords = [
14 | 'self-supervised learning',
15 | 'artificial intelligence'
16 | ],
17 | install_requires=[
18 | 'accelerate',
19 | 'beartype',
20 | 'torch>=1.6',
21 | 'torchvision>=0.8'
22 | ],
23 | classifiers=[
24 | 'Development Status :: 4 - Beta',
25 | 'Intended Audience :: Developers',
26 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
27 | 'License :: OSI Approved :: MIT License',
28 | 'Programming Language :: Python :: 3.6',
29 | ],
30 | )
31 |
--------------------------------------------------------------------------------