├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── pixel_level_contrastive_learning
│   ├── __init__.py
│   └── pixel_level_contrastive_learning.py
├── propagate.png
└── setup.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | name: Upload Python Package
5 |
6 | on:
7 | release:
8 | types: [created]
9 |
10 | jobs:
11 | deploy:
12 |
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.x'
21 | - name: Install dependencies
22 | run: |
23 | python -m pip install --upgrade pip
24 | pip install setuptools wheel twine
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
29 | run: |
30 | python setup.py sdist bdist_wheel
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Pixel-level Contrastive Learning
4 |
5 | Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in PyTorch. In addition to carrying out contrastive learning at the pixel level, the online network passes the pixel-level representations through a Pixel Propagation Module and enforces a similarity loss against the target network. The authors report surpassing all previous unsupervised and supervised pre-training methods on downstream segmentation tasks.
6 |
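For intuition, below is a minimal sketch of the computation inside the Pixel Propagation Module, mirroring the `PPM` class in `pixel_level_contrastive_learning.py` (the free-standing function and its argument names are illustrative only):

```python
import torch
import torch.nn.functional as F
from torch import einsum

def pixel_propagation(x, transform, gamma = 2):
    # x - pixel-level projections of shape (batch, dim, height, width)
    xi = x[:, :, :, :, None, None]   # (b, c, h, w, 1, 1)
    xj = x[:, :, None, None, :, :]   # (b, c, 1, 1, h, w)

    # pairwise similarity between every pair of spatial positions, sharpened by gamma
    sim = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** gamma

    # each output pixel is a similarity-weighted sum of the transformed features
    return einsum('b x y h w, b c h w -> b c x y', sim, transform(x))

# e.g. an 8 x 8 feature map, with a 1x1 conv as the transform function
out = pixel_propagation(torch.randn(1, 256, 8, 8), torch.nn.Conv2d(256, 256, 1))
```
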
7 | ## Install
8 |
9 | ```bash
10 | $ pip install pixel-level-contrastive-learning
11 | ```
12 |
13 | ## Usage
14 |
15 | Below is an example of how you would use the framework to self-supervise the training of a ResNet, taking the output of layer 4 (an 8 x 8 feature map of 'pixels').
16 |
17 |
18 | ```python
19 | import torch
20 | from pixel_level_contrastive_learning import PixelCL
21 | from torchvision import models
22 | from tqdm import tqdm
23 |
24 | resnet = models.resnet50(pretrained=True)
25 |
26 | learner = PixelCL(
27 | resnet,
28 | image_size = 256,
29 | hidden_layer_pixel = 'layer4', # leads to output of 8x8 feature map for pixel-level learning
30 | hidden_layer_instance = -2, # leads to output for instance-level learning
31 | projection_size = 256, # size of projection output, 256 was used in the paper
32 | projection_hidden_size = 2048, # size of projection hidden dimension, paper used 2048
33 | moving_average_decay = 0.99, # exponential moving average decay of target encoder
34 | ppm_num_layers = 1, # number of layers for transform function in the pixel propagation module, 1 was optimal
35 | ppm_gamma = 2, # sharpness of the similarity in the pixel propagation module, already at optimal value of 2
36 | distance_thres = 0.7, # ideal value is 0.7, as indicated in the paper, which assumes the diagonal distance between adjacent feature map pixels to be 1 (see the note below)
37 | similarity_temperature = 0.3, # temperature for the cosine similarity for the pixel contrastive loss
38 | alpha = 1., # weight of the pixel-level loss (contrastive or pixpro) relative to the instance-level loss
39 | use_pixpro = True, # use the pixel propagation (pixpro) loss instead of the pixel contrastive loss; defaults to True, as it performed best in the paper
40 | cutout_ratio_range = (0.6, 0.8) # a random ratio is selected from this range for the random cutout
41 | ).cuda()
42 |
43 | opt = torch.optim.Adam(learner.parameters(), lr=1e-4)
44 |
45 | def sample_batch_images():
46 | return torch.randn(10, 3, 256, 256).cuda()
47 |
48 | for _ in tqdm(range(100000)):
49 | images = sample_batch_images()
50 | loss = learner(images) # if the number of positive pixel pairs is zero, the loss reduces to the instance-level loss
51 |
52 | opt.zero_grad()
53 | loss.backward()
54 | print(loss.item())
55 | opt.step()
56 | learner.update_moving_average() # update moving average of target encoder
57 |
58 | # after much training, save the improved model for testing on downstream tasks
59 | torch.save(resnet, 'improved-resnet.pt')
60 | ```
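A note on `distance_thres` (an interpretation of the code): pixel coordinates are normalized by the image diagonal and then scaled by the feature map size, so for a square input, horizontally or vertically adjacent feature-map cells sit 1/√2 ≈ 0.707 apart, while diagonally adjacent cells sit exactly 1 apart. The default threshold of 0.7 therefore counts two pixels from the two augmented views as a positive pair only when they land within roughly one feature-map cell of each other.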
61 |
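If you would rather not pickle the entire module, saving only the weights is a standard PyTorch alternative (not specific to this library):

```python
torch.save(resnet.state_dict(), 'improved-resnet.pt')

# later, restore the weights into a freshly constructed network
resnet = models.resnet50(pretrained=False)
resnet.load_state_dict(torch.load('improved-resnet.pt'))
```
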
62 | You can also have `forward` return the number of positive pixel pairs, for logging or other purposes
63 |
64 | ```python
65 | loss, positive_pairs = learner(images, return_positive_pairs = True)
66 | ```
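
For example, a training loop that surfaces the pair count in the progress bar might look like the sketch below, reusing the names from the training example above (`set_postfix` is standard `tqdm`):

```python
pbar = tqdm(range(100000))
for _ in pbar:
    images = sample_batch_images()
    loss, positive_pairs = learner(images, return_positive_pairs = True)

    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()

    # positive_pairs is a tensor, unless no pairs were found, in which case it is the int 0
    pairs = positive_pairs.item() if torch.is_tensor(positive_pairs) else positive_pairs
    pbar.set_postfix(loss = loss.item(), positive_pairs = pairs)
```
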
67 | ## Citations
68 |
69 | ```bibtex
70 | @misc{xie2020propagate,
71 | title={Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning},
72 | author={Zhenda Xie and Yutong Lin and Zheng Zhang and Yue Cao and Stephen Lin and Han Hu},
73 | year={2020},
74 | eprint={2011.10043},
75 | archivePrefix={arXiv},
76 | primaryClass={cs.CV}
77 | }
78 | ```
79 |
--------------------------------------------------------------------------------
/pixel_level_contrastive_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from pixel_level_contrastive_learning.pixel_level_contrastive_learning import PPM, PixelCL
2 |
--------------------------------------------------------------------------------
/pixel_level_contrastive_learning/pixel_level_contrastive_learning.py:
--------------------------------------------------------------------------------
1 | import math
2 | import copy
3 | import random
4 | from functools import wraps, partial
5 | from math import floor
6 |
7 | import torch
8 | from torch import nn, einsum
9 | import torch.nn.functional as F
10 |
11 | from kornia import augmentation as augs
12 | from kornia import filters, color
13 |
14 | from einops import rearrange
15 |
16 | # helper functions
17 |
18 | def identity(t):
19 | return t
20 |
21 | def default(val, def_val):
22 | return def_val if val is None else val
23 |
24 | def rand_true(prob):
25 | return random.random() < prob
26 |
27 | def singleton(cache_key):
28 | def inner_fn(fn):
29 | @wraps(fn)
30 | def wrapper(self, *args, **kwargs):
31 | instance = getattr(self, cache_key)
32 | if instance is not None:
33 | return instance
34 |
35 | instance = fn(self, *args, **kwargs)
36 | setattr(self, cache_key, instance)
37 | return instance
38 | return wrapper
39 | return inner_fn
40 |
41 | def get_module_device(module):
42 | return next(module.parameters()).device
43 |
44 | def set_requires_grad(model, val):
45 | for p in model.parameters():
46 | p.requires_grad = val
47 |
48 | def cutout_coordinates(image, ratio_range = (0.6, 0.8)):
49 | _, _, orig_h, orig_w = image.shape
50 |
51 | ratio_lo, ratio_hi = ratio_range
52 | random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
53 | w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
54 | coor_x = floor((orig_w - w) * random.random())
55 | coor_y = floor((orig_h - h) * random.random())
56 | return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio
57 |
58 | def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
59 | shape = image.shape
60 | output_size = default(output_size, shape[2:])
61 | (y0, y1), (x0, x1) = coordinates
62 | cutout_image = image[:, :, y0:y1, x0:x1]
63 | return F.interpolate(cutout_image, size = output_size, mode = mode)
64 |
65 | # augmentation utils
66 |
67 | class RandomApply(nn.Module):
68 | def __init__(self, fn, p):
69 | super().__init__()
70 | self.fn = fn
71 | self.p = p
72 | def forward(self, x):
73 | if random.random() > self.p:
74 | return x
75 | return self.fn(x)
76 |
77 | # exponential moving average
78 |
79 | class EMA():
80 | def __init__(self, beta):
81 | super().__init__()
82 | self.beta = beta
83 |
84 | def update_average(self, old, new):
85 | if old is None:
86 | return new
87 | return old * self.beta + (1 - self.beta) * new
88 |
89 | def update_moving_average(ema_updater, ma_model, current_model):
90 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
91 | old_weight, up_weight = ma_params.data, current_params.data
92 | ma_params.data = ema_updater.update_average(old_weight, up_weight)
93 |
94 | # loss fn
95 |
96 | def loss_fn(x, y):
97 | x = F.normalize(x, dim=-1, p=2)
98 | y = F.normalize(y, dim=-1, p=2)
99 | return 2 - 2 * (x * y).sum(dim=-1)
100 |
101 | # classes
102 |
103 | class MLP(nn.Module):
104 | def __init__(self, chan, chan_out = 256, inner_dim = 2048):
105 | super().__init__()
106 | self.net = nn.Sequential(
107 | nn.Linear(chan, inner_dim),
108 | nn.BatchNorm1d(inner_dim),
109 | nn.ReLU(),
110 | nn.Linear(inner_dim, chan_out)
111 | )
112 |
113 | def forward(self, x):
114 | return self.net(x)
115 |
116 | class ConvMLP(nn.Module):
117 | def __init__(self, chan, chan_out = 256, inner_dim = 2048):
118 | super().__init__()
119 | self.net = nn.Sequential(
120 | nn.Conv2d(chan, inner_dim, 1),
121 | nn.BatchNorm2d(inner_dim),
122 | nn.ReLU(),
123 | nn.Conv2d(inner_dim, chan_out, 1)
124 | )
125 |
126 | def forward(self, x):
127 | return self.net(x)
128 |
129 | class PPM(nn.Module):
130 | def __init__(
131 | self,
132 | *,
133 | chan,
134 | num_layers = 1,
135 | gamma = 2):
136 | super().__init__()
137 | self.gamma = gamma
138 |
139 | if num_layers == 0:
140 | self.transform_net = nn.Identity()
141 | elif num_layers == 1:
142 | self.transform_net = nn.Conv2d(chan, chan, 1)
143 | elif num_layers == 2:
144 | self.transform_net = nn.Sequential(
145 | nn.Conv2d(chan, chan, 1),
146 | nn.BatchNorm2d(chan),
147 | nn.ReLU(),
148 | nn.Conv2d(chan, chan, 1)
149 | )
150 | else:
151 | raise ValueError('num_layers must be one of 0, 1, or 2')
152 |
153 | def forward(self, x):
154 | xi = x[:, :, :, :, None, None]
155 | xj = x[:, :, None, None, :, :]
156 | similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma # pairwise similarity between every pair of spatial positions, sharpened by gamma
157 |
158 | transform_out = self.transform_net(x)
159 | out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out) # each output pixel is a similarity-weighted sum of the transformed features
160 | return out
161 |
162 | # a wrapper class for the base neural network
163 | # will manage the interception of the hidden layer outputs via forward hooks
164 | # and pipe them into the pixel-level and instance-level projector nets
165 |
166 | class NetWrapper(nn.Module):
167 | def __init__(
168 | self,
169 | *,
170 | net,
171 | projection_size,
172 | projection_hidden_size,
173 | layer_pixel = -2,
174 | layer_instance = -2
175 | ):
176 | super().__init__()
177 | self.net = net
178 | self.layer_pixel = layer_pixel
179 | self.layer_instance = layer_instance
180 |
181 | self.pixel_projector = None
182 | self.instance_projector = None
183 |
184 | self.projection_size = projection_size
185 | self.projection_hidden_size = projection_hidden_size
186 |
187 | self.hidden_pixel = None
188 | self.hidden_instance = None
189 | self.hook_registered = False
190 |
191 | def _find_layer(self, layer_id):
192 | if type(layer_id) == str:
193 | modules = dict([*self.net.named_modules()])
194 | return modules.get(layer_id, None)
195 | elif type(layer_id) == int:
196 | children = [*self.net.children()]
197 | return children[layer_id]
198 | return None
199 |
200 | def _hook_pixel(self, _, __, output):
201 | setattr(self, 'hidden_pixel', output)
202 |
203 | def _hook_instance(self, _, __, output):
204 | setattr(self, 'hidden_instance', output)
205 |
206 | def _register_hook(self):
207 | pixel_layer = self._find_layer(self.layer_pixel)
208 | instance_layer = self._find_layer(self.layer_instance)
209 |
210 | assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
211 | assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'
212 |
213 | pixel_layer.register_forward_hook(self._hook_pixel)
214 | instance_layer.register_forward_hook(self._hook_instance)
215 | self.hook_registered = True
216 |
217 | @singleton('pixel_projector')
218 | def _get_pixel_projector(self, hidden):
219 | _, dim, *_ = hidden.shape
220 | projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size)
221 | return projector.to(hidden)
222 |
223 | @singleton('instance_projector')
224 | def _get_instance_projector(self, hidden):
225 | _, dim = hidden.shape
226 | projector = MLP(dim, self.projection_size, self.projection_hidden_size)
227 | return projector.to(hidden)
228 |
229 | def get_representation(self, x):
230 | if not self.hook_registered:
231 | self._register_hook()
232 |
233 | _ = self.net(x)
234 | hidden_pixel = self.hidden_pixel
235 | hidden_instance = self.hidden_instance
236 | self.hidden_pixel = None
237 | self.hidden_instance = None
238 | assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
239 | assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
240 | return hidden_pixel, hidden_instance
241 |
242 | def forward(self, x):
243 | pixel_representation, instance_representation = self.get_representation(x)
244 | instance_representation = instance_representation.flatten(1)
245 |
246 | pixel_projector = self._get_pixel_projector(pixel_representation)
247 | instance_projector = self._get_instance_projector(instance_representation)
248 |
249 | pixel_projection = pixel_projector(pixel_representation)
250 | instance_projection = instance_projector(instance_representation)
251 | return pixel_projection, instance_projection
252 |
253 | # main class
254 |
255 | class PixelCL(nn.Module):
256 | def __init__(
257 | self,
258 | net,
259 | image_size,
260 | hidden_layer_pixel = -2,
261 | hidden_layer_instance = -2,
262 | projection_size = 256,
263 | projection_hidden_size = 2048,
264 | augment_fn = None,
265 | augment_fn2 = None,
266 | prob_rand_hflip = 0.25,
267 | moving_average_decay = 0.99,
268 | ppm_num_layers = 1,
269 | ppm_gamma = 2,
270 | distance_thres = 0.7,
271 | similarity_temperature = 0.3,
272 | alpha = 1.,
273 | use_pixpro = True,
274 | cutout_ratio_range = (0.6, 0.8),
275 | cutout_interpolate_mode = 'nearest',
276 | coord_cutout_interpolate_mode = 'bilinear'
277 | ):
278 | super().__init__()
279 |
280 | DEFAULT_AUG = nn.Sequential(
281 | RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
282 | augs.RandomGrayscale(p=0.2),
283 | RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
284 | augs.RandomSolarize(p=0.5),
285 | augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
286 | )
287 |
288 | self.augment1 = default(augment_fn, DEFAULT_AUG)
289 | self.augment2 = default(augment_fn2, self.augment1)
290 | self.prob_rand_hflip = prob_rand_hflip
291 |
292 | self.online_encoder = NetWrapper(
293 | net = net,
294 | projection_size = projection_size,
295 | projection_hidden_size = projection_hidden_size,
296 | layer_pixel = hidden_layer_pixel,
297 | layer_instance = hidden_layer_instance
298 | )
299 |
300 | self.target_encoder = None
301 | self.target_ema_updater = EMA(moving_average_decay)
302 |
303 | self.distance_thres = distance_thres
304 | self.similarity_temperature = similarity_temperature
305 | self.alpha = alpha
306 |
307 | self.use_pixpro = use_pixpro
308 |
309 | if use_pixpro:
310 | self.propagate_pixels = PPM(
311 | chan = projection_size,
312 | num_layers = ppm_num_layers,
313 | gamma = ppm_gamma
314 | )
315 |
316 | self.cutout_ratio_range = cutout_ratio_range
317 | self.cutout_interpolate_mode = cutout_interpolate_mode
318 | self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode
319 |
320 | # instance level predictor
321 | self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)
322 |
323 | # get device of network and make wrapper same device
324 | device = get_module_device(net)
325 | self.to(device)
326 |
327 | # send a mock image tensor to instantiate singleton parameters
328 | self.forward(torch.randn(2, 3, image_size, image_size, device=device))
329 |
330 | @singleton('target_encoder')
331 | def _get_target_encoder(self):
332 | target_encoder = copy.deepcopy(self.online_encoder)
333 | set_requires_grad(target_encoder, False)
334 | return target_encoder
335 |
336 | def reset_moving_average(self):
337 | del self.target_encoder
338 | self.target_encoder = None
339 |
340 | def update_moving_average(self):
341 | assert self.target_encoder is not None, 'target encoder has not been created yet'
342 | update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)
343 |
344 | def forward(self, x, return_positive_pairs = False):
345 | shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip
346 |
347 | rand_flip_fn = lambda t: torch.flip(t, dims = (-1,))
348 |
349 | flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
350 | flip_image_one_fn = rand_flip_fn if flip_image_one else identity
351 | flip_image_two_fn = rand_flip_fn if flip_image_two else identity
352 |
353 | cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
354 | cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)
355 |
356 | image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode)
357 | image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode)
358 |
359 | image_one_cutout = flip_image_one_fn(image_one_cutout)
360 | image_two_cutout = flip_image_two_fn(image_two_cutout)
361 |
362 | image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)
363 |
364 | proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
365 | proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)
366 |
367 | image_h, image_w = shape[2:]
368 |
369 | proj_image_shape = proj_pixel_one.shape[2:]
370 | proj_image_h, proj_image_w = proj_image_shape
371 |
372 | coordinates = torch.meshgrid(
373 | torch.arange(image_h, device = device),
374 | torch.arange(image_w, device = device)
375 | )
376 |
377 | coordinates = torch.stack(coordinates).unsqueeze(0).float()
378 | coordinates /= math.sqrt(image_h ** 2 + image_w ** 2) # normalize coordinates by the image diagonal
379 | coordinates[:, 0] *= proj_image_h
380 | coordinates[:, 1] *= proj_image_w
381 |
382 | proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
383 | proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
384 |
385 | proj_coors_one = flip_image_one_fn(proj_coors_one)
386 | proj_coors_two = flip_image_two_fn(proj_coors_two)
387 |
388 | proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two))
389 | pdist = nn.PairwiseDistance(p = 2)
390 |
391 | num_pixels = proj_coors_one.shape[0]
392 |
393 | proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
394 | proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
395 |
396 | distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)
397 | distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)
398 |
399 | positive_mask_one_two = distance_matrix < self.distance_thres
400 | positive_mask_two_one = positive_mask_one_two.t()
401 |
402 | with torch.no_grad():
403 | target_encoder = self._get_target_encoder()
404 | target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
405 | target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)
406 |
407 | # flatten all the pixel projections
408 |
409 | flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')
410 |
411 | target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))
412 |
413 | # get total number of positive pixel pairs
414 |
415 | positive_pixel_pairs = positive_mask_one_two.sum()
416 |
417 | # get instance level loss
418 |
419 | pred_instance_one = self.online_predictor(proj_instance_one)
420 | pred_instance_two = self.online_predictor(proj_instance_two)
421 |
422 | loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
423 | loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())
424 |
425 | instance_loss = (loss_instance_one + loss_instance_two).mean()
426 |
427 | if positive_pixel_pairs == 0:
428 | ret = (instance_loss, 0) if return_positive_pairs else instance_loss
429 | return ret
430 |
431 | if not self.use_pixpro:
432 | # calculate pix contrast loss
433 |
434 | proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two)))
435 |
436 | similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature
437 | similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature
438 |
439 | loss_pix_one_two = -torch.log(
440 | similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() /
441 | similarity_one_two.exp().sum()
442 | )
443 |
444 | loss_pix_two_one = -torch.log(
445 | similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() /
446 | similarity_two_one.exp().sum()
447 | )
448 |
449 | pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
450 | else:
451 | # calculate pix pro loss
452 |
453 | propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
454 | propagated_pixels_two = self.propagate_pixels(proj_pixel_two)
455 |
456 | propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))
457 |
458 | propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
459 | propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)
460 |
461 | loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean()
462 | loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean()
463 |
464 | pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2
465 |
466 | # total loss
467 |
468 | loss = pix_loss * self.alpha + instance_loss
469 |
470 | ret = (loss, positive_pixel_pairs) if return_positive_pairs else loss
471 | return ret
472 |
--------------------------------------------------------------------------------
/propagate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/pixel-level-contrastive-learning/0c60e93df73b0ec351f2104839c4d4748f3d38a2/propagate.png
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'pixel-level-contrastive-learning',
5 | packages = find_packages(),
6 | version = '0.1.1',
7 | license='MIT',
8 | description = 'Pixel-Level Contrastive Learning',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | url = 'https://github.com/lucidrains/pixel-level-contrastive-learning',
12 | keywords = ['self-supervised learning', 'artificial intelligence'],
13 | install_requires=[
14 | 'einops',
15 | 'torch>=1.6',
16 | 'kornia>=0.4.0'
17 | ],
18 | classifiers=[
19 | 'Development Status :: 4 - Beta',
20 | 'Intended Audience :: Developers',
21 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
22 | 'License :: OSI Approved :: MIT License',
23 | 'Programming Language :: Python :: 3.6',
24 | ],
25 | )
--------------------------------------------------------------------------------