├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── pixel_level_contrastive_learning
│   ├── __init__.py
│   └── pixel_level_contrastive_learning.py
├── propagate.png
└── setup.py

/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Upload Python Package

on:
  release:
    types: [created]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install setuptools wheel twine
    - name: Build and publish
      env:
        TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
      run: |
        python setup.py sdist bdist_wheel
        twine upload dist/*

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Phil Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## Pixel-level Contrastive Learning

Implementation of Pixel-level Contrastive Learning, proposed in the paper "Propagate Yourself", in PyTorch. In addition to carrying out contrastive learning at the pixel level, the online network passes the pixel-level representations through a Pixel Propagation Module and enforces a similarity loss against the target network. The authors report that this surpasses all prior unsupervised, and even supervised, pre-training methods on downstream segmentation tasks.

## Install

```bash
$ pip install pixel-level-contrastive-learning
```

## Usage

Below is an example of how you would use the framework for self-supervised training of a ResNet, taking the output of layer 4 (an 8 x 8 feature map of 'pixels').

```python
import torch
from pixel_level_contrastive_learning import PixelCL
from torchvision import models
from tqdm import tqdm

resnet = models.resnet50(pretrained=True)

learner = PixelCL(
    resnet,
    image_size = 256,
    hidden_layer_pixel = 'layer4',  # leads to output of 8x8 feature map for pixel-level learning
    hidden_layer_instance = -2,     # leads to output for instance-level learning
    projection_size = 256,          # size of projection output, 256 was used in the paper
    projection_hidden_size = 2048,  # size of projection hidden dimension, paper used 2048
    moving_average_decay = 0.99,    # exponential moving average decay of target encoder
    ppm_num_layers = 1,             # number of layers for the transform function in the pixel propagation module, 1 was optimal
    ppm_gamma = 2,                  # sharpness of the similarity in the pixel propagation module, already at optimal value of 2
    distance_thres = 0.7,           # ideal value is 0.7, as indicated in the paper, which assumes each feature map pixel's diagonal distance to be 1 (still unclear)
    similarity_temperature = 0.3,   # temperature for the cosine similarity for the pixel contrastive loss
    alpha = 1.,                     # weight of the pixel propagation loss (pixpro) vs pixel CL loss
    use_pixpro = True,              # do pixpro instead of the pixel contrast loss, defaults to pixpro, since it is the best one
    cutout_ratio_range = (0.6, 0.8) # a random ratio is selected from this range for the random cutout
).cuda()

opt = torch.optim.Adam(learner.parameters(), lr=1e-4)

def sample_batch_images():
    return torch.randn(10, 3, 256, 256).cuda()

for _ in tqdm(range(100000)):
    images = sample_batch_images()
    loss = learner(images) # if the number of positive pixel pairs is zero, the loss equals the instance-level loss

    opt.zero_grad()
    loss.backward()
    print(loss.item())
    opt.step()
    learner.update_moving_average() # update moving average of target encoder

# after much training, save the improved model for testing on downstream tasks
torch.save(resnet, 'improved-resnet.pt')
```

You can also return the number of positive pixel pairs on `forward`, for logging or other purposes.

```python
loss, positive_pairs = learner(images, return_positive_pairs = True)
```
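Once trained, the saved ResNet can serve as a drop-in backbone for a downstream task. A minimal sketch of fine-tuning a linear head on the pretrained features (the file name carries over from the script above; the 10-class head is an assumption for illustration, not part of this library):

```python
import torch
from torch import nn

resnet = torch.load('improved-resnet.pt')  # file name from the training script above

# drop the original classification head, keep pooled 2048-dim features
backbone = nn.Sequential(*list(resnet.children())[:-1], nn.Flatten())
classifier = nn.Sequential(backbone, nn.Linear(2048, 10))  # 10 classes is an assumption

logits = classifier(torch.randn(1, 3, 256, 256))  # (1, 10)
```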
## Citations

```bibtex
@misc{xie2020propagate,
    title={Propagate Yourself: Exploring Pixel-Level Consistency for Unsupervised Visual Representation Learning},
    author={Zhenda Xie and Yutong Lin and Zheng Zhang and Yue Cao and Stephen Lin and Han Hu},
    year={2020},
    eprint={2011.10043},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
```

--------------------------------------------------------------------------------
/pixel_level_contrastive_learning/__init__.py:
--------------------------------------------------------------------------------
from pixel_level_contrastive_learning.pixel_level_contrastive_learning import PPM, PixelCL

--------------------------------------------------------------------------------
/pixel_level_contrastive_learning/pixel_level_contrastive_learning.py:
--------------------------------------------------------------------------------
import math
import copy
import random
from functools import wraps, partial
from math import floor

import torch
from torch import nn, einsum
import torch.nn.functional as F

from kornia import augmentation as augs
from kornia import filters, color

from einops import rearrange

# helper functions

def identity(t):
    return t

def default(val, def_val):
    return def_val if val is None else val

def rand_true(prob):
    return random.random() < prob

def singleton(cache_key):
    def inner_fn(fn):
        @wraps(fn)
        def wrapper(self, *args, **kwargs):
            instance = getattr(self, cache_key)
            if instance is not None:
                return instance

            instance = fn(self, *args, **kwargs)
            setattr(self, cache_key, instance)
            return instance
        return wrapper
    return inner_fn

def get_module_device(module):
    return next(module.parameters()).device

def set_requires_grad(model, val):
    for p in model.parameters():
        p.requires_grad = val

def cutout_coordinates(image, ratio_range = (0.6, 0.8)):
    _, _, orig_h, orig_w = image.shape

    ratio_lo, ratio_hi = ratio_range
    random_ratio = ratio_lo + random.random() * (ratio_hi - ratio_lo)
    w, h = floor(random_ratio * orig_w), floor(random_ratio * orig_h)
    coor_x = floor((orig_w - w) * random.random())
    coor_y = floor((orig_h - h) * random.random())
    return ((coor_y, coor_y + h), (coor_x, coor_x + w)), random_ratio

def cutout_and_resize(image, coordinates, output_size = None, mode = 'nearest'):
    shape = image.shape
    output_size = default(output_size, shape[2:])
    (y0, y1), (x0, x1) = coordinates
    cutout_image = image[:, :, y0:y1, x0:x1]
    return F.interpolate(cutout_image, size = output_size, mode = mode)
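# worked example (illustrative numbers): for a 256x256 image and ratio_range (0.6, 0.8),
# a random_ratio of 0.7 gives a floor(0.7 * 256) = 179x179 crop at a random offset,
# e.g. coordinates ((23, 202), (40, 219)); cutout_and_resize then interpolates that
# crop back up to 256x256, so two such views cover different, possibly overlapping
# regions of the original image at slightly different scales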
# augmentation utils

class RandomApply(nn.Module):
    def __init__(self, fn, p):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x):
        if random.random() > self.p:
            return x
        return self.fn(x)

# exponential moving average

class EMA():
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

def update_moving_average(ema_updater, ma_model, current_model):
    for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
        old_weight, up_weight = ma_params.data, current_params.data
        ma_params.data = ema_updater.update_average(old_weight, up_weight)

# loss fn

def loss_fn(x, y):
    x = F.normalize(x, dim=-1, p=2)
    y = F.normalize(y, dim=-1, p=2)
    return 2 - 2 * (x * y).sum(dim=-1)
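# note: for unit vectors, ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y> = 2 - 2<x, y>,
# so loss_fn above is exactly the squared error between the L2-normalized vectors,
# i.e. the BYOL-style regression loss; minimizing it maximizes cosine similarity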
# classes

class MLP(nn.Module):
    def __init__(self, chan, chan_out = 256, inner_dim = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(chan, inner_dim),
            nn.BatchNorm1d(inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, chan_out)
        )

    def forward(self, x):
        return self.net(x)

class ConvMLP(nn.Module):
    def __init__(self, chan, chan_out = 256, inner_dim = 2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(chan, inner_dim, 1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(),
            nn.Conv2d(inner_dim, chan_out, 1)
        )

    def forward(self, x):
        return self.net(x)

class PPM(nn.Module):
    def __init__(
        self,
        *,
        chan,
        num_layers = 1,
        gamma = 2):
        super().__init__()
        self.gamma = gamma

        if num_layers == 0:
            self.transform_net = nn.Identity()
        elif num_layers == 1:
            self.transform_net = nn.Conv2d(chan, chan, 1)
        elif num_layers == 2:
            self.transform_net = nn.Sequential(
                nn.Conv2d(chan, chan, 1),
                nn.BatchNorm2d(chan),
                nn.ReLU(),
                nn.Conv2d(chan, chan, 1)
            )
        else:
            raise ValueError('num_layers must be one of 0, 1, or 2')

    def forward(self, x):
        xi = x[:, :, :, :, None, None]
        xj = x[:, :, None, None, :, :]
        similarity = F.relu(F.cosine_similarity(xi, xj, dim = 1)) ** self.gamma

        transform_out = self.transform_net(x)
        out = einsum('b x y h w, b c h w -> b c x y', similarity, transform_out)
        return out
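# how PPM propagates: `similarity` has shape (b, x, y, h, w) and holds, for each
# output position (x, y), a non-negative cosine similarity to every position (h, w),
# sharpened by gamma; the einsum then computes each output pixel as the
# similarity-weighted sum of the transformed features over all spatial positions,
# smoothing every pixel's vector with those of similar pixels
#
# minimal shape check (illustrative):
#   ppm = PPM(chan = 256)
#   ppm(torch.randn(2, 256, 8, 8)).shape  # -> torch.Size([2, 256, 8, 8])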
# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projector and predictor nets

class NetWrapper(nn.Module):
    def __init__(
        self,
        *,
        net,
        projection_size,
        projection_hidden_size,
        layer_pixel = -2,
        layer_instance = -2
    ):
        super().__init__()
        self.net = net
        self.layer_pixel = layer_pixel
        self.layer_instance = layer_instance

        self.pixel_projector = None
        self.instance_projector = None

        self.projection_size = projection_size
        self.projection_hidden_size = projection_hidden_size

        self.hidden_pixel = None
        self.hidden_instance = None
        self.hook_registered = False

    def _find_layer(self, layer_id):
        if type(layer_id) == str:
            modules = dict([*self.net.named_modules()])
            return modules.get(layer_id, None)
        elif type(layer_id) == int:
            children = [*self.net.children()]
            return children[layer_id]
        return None

    def _hook_pixel(self, _, __, output):
        setattr(self, 'hidden_pixel', output)

    def _hook_instance(self, _, __, output):
        setattr(self, 'hidden_instance', output)

    def _register_hook(self):
        pixel_layer = self._find_layer(self.layer_pixel)
        instance_layer = self._find_layer(self.layer_instance)

        assert pixel_layer is not None, f'hidden layer ({self.layer_pixel}) not found'
        assert instance_layer is not None, f'hidden layer ({self.layer_instance}) not found'

        pixel_layer.register_forward_hook(self._hook_pixel)
        instance_layer.register_forward_hook(self._hook_instance)
        self.hook_registered = True

    @singleton('pixel_projector')
    def _get_pixel_projector(self, hidden):
        _, dim, *_ = hidden.shape
        projector = ConvMLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    @singleton('instance_projector')
    def _get_instance_projector(self, hidden):
        _, dim = hidden.shape
        projector = MLP(dim, self.projection_size, self.projection_hidden_size)
        return projector.to(hidden)

    def get_representation(self, x):
        if not self.hook_registered:
            self._register_hook()

        _ = self.net(x)
        hidden_pixel = self.hidden_pixel
        hidden_instance = self.hidden_instance
        self.hidden_pixel = None
        self.hidden_instance = None
        assert hidden_pixel is not None, f'hidden pixel layer {self.layer_pixel} never emitted an output'
        assert hidden_instance is not None, f'hidden instance layer {self.layer_instance} never emitted an output'
        return hidden_pixel, hidden_instance

    def forward(self, x):
        pixel_representation, instance_representation = self.get_representation(x)
        instance_representation = instance_representation.flatten(1)

        pixel_projector = self._get_pixel_projector(pixel_representation)
        instance_projector = self._get_instance_projector(instance_representation)

        pixel_projection = pixel_projector(pixel_representation)
        instance_projection = instance_projector(instance_representation)
        return pixel_projection, instance_projection

# main class

class PixelCL(nn.Module):
    def __init__(
        self,
        net,
        image_size,
        hidden_layer_pixel = -2,
        hidden_layer_instance = -2,
        projection_size = 256,
        projection_hidden_size = 2048,
        augment_fn = None,
        augment_fn2 = None,
        prob_rand_hflip = 0.25,
        moving_average_decay = 0.99,
        ppm_num_layers = 1,
        ppm_gamma = 2,
        distance_thres = 0.7,
        similarity_temperature = 0.3,
        alpha = 1.,
        use_pixpro = True,
        cutout_ratio_range = (0.6, 0.8),
        cutout_interpolate_mode = 'nearest',
        coord_cutout_interpolate_mode = 'bilinear'
    ):
        super().__init__()

        DEFAULT_AUG = nn.Sequential(
            RandomApply(augs.ColorJitter(0.8, 0.8, 0.8, 0.2), p=0.8),
            augs.RandomGrayscale(p=0.2),
            RandomApply(filters.GaussianBlur2d((3, 3), (1.5, 1.5)), p=0.1),
            augs.RandomSolarize(p=0.5),
            augs.Normalize(mean=torch.tensor([0.485, 0.456, 0.406]), std=torch.tensor([0.229, 0.224, 0.225]))
        )

        self.augment1 = default(augment_fn, DEFAULT_AUG)
        self.augment2 = default(augment_fn2, self.augment1)
        self.prob_rand_hflip = prob_rand_hflip

        self.online_encoder = NetWrapper(
            net = net,
            projection_size = projection_size,
            projection_hidden_size = projection_hidden_size,
            layer_pixel = hidden_layer_pixel,
            layer_instance = hidden_layer_instance
        )

        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)

        self.distance_thres = distance_thres
        self.similarity_temperature = similarity_temperature
        self.alpha = alpha

        self.use_pixpro = use_pixpro

        if use_pixpro:
            self.propagate_pixels = PPM(
                chan = projection_size,
                num_layers = ppm_num_layers,
                gamma = ppm_gamma
            )

        self.cutout_ratio_range = cutout_ratio_range
        self.cutout_interpolate_mode = cutout_interpolate_mode
        self.coord_cutout_interpolate_mode = coord_cutout_interpolate_mode

        # instance level predictor
        self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

        # get device of network and make wrapper same device
        device = get_module_device(net)
        self.to(device)

        # send a mock image tensor to instantiate singleton parameters
        self.forward(torch.randn(2, 3, image_size, image_size, device=device))

    @singleton('target_encoder')
    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder

    def reset_moving_average(self):
        del self.target_encoder
        self.target_encoder = None

    def update_moving_average(self):
        assert self.target_encoder is not None, 'target encoder has not been created yet'
        update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

    def forward(self, x, return_positive_pairs = False):
        shape, device, prob_flip = x.shape, x.device, self.prob_rand_hflip

        rand_flip_fn = lambda t: torch.flip(t, dims = (-1,))

        flip_image_one, flip_image_two = rand_true(prob_flip), rand_true(prob_flip)
        flip_image_one_fn = rand_flip_fn if flip_image_one else identity
        flip_image_two_fn = rand_flip_fn if flip_image_two else identity

        cutout_coordinates_one, _ = cutout_coordinates(x, self.cutout_ratio_range)
        cutout_coordinates_two, _ = cutout_coordinates(x, self.cutout_ratio_range)

        image_one_cutout = cutout_and_resize(x, cutout_coordinates_one, mode = self.cutout_interpolate_mode)
        image_two_cutout = cutout_and_resize(x, cutout_coordinates_two, mode = self.cutout_interpolate_mode)

        image_one_cutout = flip_image_one_fn(image_one_cutout)
        image_two_cutout = flip_image_two_fn(image_two_cutout)

        image_one_cutout, image_two_cutout = self.augment1(image_one_cutout), self.augment2(image_two_cutout)

        proj_pixel_one, proj_instance_one = self.online_encoder(image_one_cutout)
        proj_pixel_two, proj_instance_two = self.online_encoder(image_two_cutout)

        image_h, image_w = shape[2:]

        proj_image_shape = proj_pixel_one.shape[2:]
        proj_image_h, proj_image_w = proj_image_shape

        coordinates = torch.meshgrid(
            torch.arange(image_h, device = device),
            torch.arange(image_w, device = device)
        )

        coordinates = torch.stack(coordinates).unsqueeze(0).float()
        coordinates /= math.sqrt(image_h ** 2 + image_w ** 2)
        coordinates[:, 0] *= proj_image_h
        coordinates[:, 1] *= proj_image_w

        proj_coors_one = cutout_and_resize(coordinates, cutout_coordinates_one, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)
        proj_coors_two = cutout_and_resize(coordinates, cutout_coordinates_two, output_size = proj_image_shape, mode = self.coord_cutout_interpolate_mode)

        proj_coors_one = flip_image_one_fn(proj_coors_one)
        proj_coors_two = flip_image_two_fn(proj_coors_two)

        proj_coors_one, proj_coors_two = map(lambda t: rearrange(t, 'b c h w -> (b h w) c'), (proj_coors_one, proj_coors_two))
        pdist = nn.PairwiseDistance(p = 2)

        num_pixels = proj_coors_one.shape[0]

        proj_coors_one_expanded = proj_coors_one[:, None].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)
        proj_coors_two_expanded = proj_coors_two[None, :].expand(num_pixels, num_pixels, -1).reshape(num_pixels * num_pixels, 2)

        distance_matrix = pdist(proj_coors_one_expanded, proj_coors_two_expanded)
        distance_matrix = distance_matrix.reshape(num_pixels, num_pixels)

        positive_mask_one_two = distance_matrix < self.distance_thres
        positive_mask_two_one = positive_mask_one_two.t()
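        # positive-pair logic: every original-image pixel gets a (y, x) coordinate,
        # normalized by the image diagonal and scaled to feature-map units, and the
        # coordinate map undergoes the same cutout, resize, and flip as its view;
        # a feature-map pixel from view one and one from view two then count as a
        # positive pair when their mapped coordinates lie within `distance_thres`
        # of each other (0.7 in these feature-map-scaled units, per the paper)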
        with torch.no_grad():
            target_encoder = self._get_target_encoder()
            target_proj_pixel_one, target_proj_instance_one = target_encoder(image_one_cutout)
            target_proj_pixel_two, target_proj_instance_two = target_encoder(image_two_cutout)

        # flatten all the pixel projections

        flatten = lambda t: rearrange(t, 'b c h w -> b c (h w)')

        target_proj_pixel_one, target_proj_pixel_two = list(map(flatten, (target_proj_pixel_one, target_proj_pixel_two)))

        # get total number of positive pixel pairs

        positive_pixel_pairs = positive_mask_one_two.sum()

        # get instance level loss

        pred_instance_one = self.online_predictor(proj_instance_one)
        pred_instance_two = self.online_predictor(proj_instance_two)

        loss_instance_one = loss_fn(pred_instance_one, target_proj_instance_two.detach())
        loss_instance_two = loss_fn(pred_instance_two, target_proj_instance_one.detach())

        instance_loss = (loss_instance_one + loss_instance_two).mean()

        if positive_pixel_pairs == 0:
            ret = (instance_loss, 0) if return_positive_pairs else instance_loss
            return ret

        if not self.use_pixpro:
            # calculate pix contrast loss

            proj_pixel_one, proj_pixel_two = list(map(flatten, (proj_pixel_one, proj_pixel_two)))

            similarity_one_two = F.cosine_similarity(proj_pixel_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1) / self.similarity_temperature
            similarity_two_one = F.cosine_similarity(proj_pixel_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1) / self.similarity_temperature

            loss_pix_one_two = -torch.log(
                similarity_one_two.masked_select(positive_mask_one_two[None, ...]).exp().sum() /
                similarity_one_two.exp().sum()
            )

            loss_pix_two_one = -torch.log(
                similarity_two_one.masked_select(positive_mask_two_one[None, ...]).exp().sum() /
                similarity_two_one.exp().sum()
            )

            pix_loss = (loss_pix_one_two + loss_pix_two_one) / 2
        else:
            # calculate pix pro loss

            propagated_pixels_one = self.propagate_pixels(proj_pixel_one)
            propagated_pixels_two = self.propagate_pixels(proj_pixel_two)

            propagated_pixels_one, propagated_pixels_two = list(map(flatten, (propagated_pixels_one, propagated_pixels_two)))

            propagated_similarity_one_two = F.cosine_similarity(propagated_pixels_one[..., :, None], target_proj_pixel_two[..., None, :], dim = 1)
            propagated_similarity_two_one = F.cosine_similarity(propagated_pixels_two[..., :, None], target_proj_pixel_one[..., None, :], dim = 1)

            loss_pixpro_one_two = - propagated_similarity_one_two.masked_select(positive_mask_one_two[None, ...]).mean()
            loss_pixpro_two_one = - propagated_similarity_two_one.masked_select(positive_mask_two_one[None, ...]).mean()

            pix_loss = (loss_pixpro_one_two + loss_pixpro_two_one) / 2

        # total loss

        loss = pix_loss * self.alpha + instance_loss

        ret = (loss, positive_pixel_pairs) if return_positive_pairs else loss
        return ret
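# loss summary (as implemented above): with use_pixpro = False, the pixel-level term
# is an InfoNCE-style contrastive loss with temperature `similarity_temperature`,
# pulling together feature-map pixels marked positive and pushing apart the rest;
# with use_pixpro = True (the default), it instead maximizes cosine similarity between
# each view's PPM-propagated pixels and the other view's target-encoder pixels at
# positive locations, with no negatives; either term is added to the BYOL-style
# instance-level loss, weighted by `alpha`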
--------------------------------------------------------------------------------
/propagate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/pixel-level-contrastive-learning/0c60e93df73b0ec351f2104839c4d4748f3d38a2/propagate.png

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

setup(
  name = 'pixel-level-contrastive-learning',
  packages = find_packages(),
  version = '0.1.1',
  license='MIT',
  description = 'Pixel-Level Contrastive Learning',
  author = 'Phil Wang',
  author_email = 'lucidrains@gmail.com',
  url = 'https://github.com/lucidrains/pixel-level-contrastive-learning',
  keywords = ['self-supervised learning', 'artificial intelligence'],
  install_requires=[
    'einops',
    'torch>=1.6',
    'kornia>=0.4.0'
  ],
  classifiers=[
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.6',
  ],
)
--------------------------------------------------------------------------------