├── PLRC ├── __init__.py ├── models │ ├── utils.py │ ├── resnet.py │ ├── builder.py │ └── plrc_loss.py └── datasets │ ├── transforms.py │ └── imagenet.py ├── CONTRIBUTING.md ├── README.md ├── CODE_OF_CONDUCT.md ├── main_plrc.py └── LICENSE /PLRC/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to PLRC 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to PLRC, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Point-Level Region Contrast for Object Detection Pre-Training 2 | 3 | 4 | This is a PyTorch implementation of the [PLRC paper](https://arxiv.org/abs/2202.04639): 5 | ``` 6 | @inproceedings{bai2022point, 7 | title={Point-Level Region Contrast for Object Detection Pre-Training}, 8 | author={Bai, Yutong and Chen, Xinlei and Kirillov, Alexander and Yuille, Alan and Berg, Alexander C}, 9 | booktitle={CVPR}, 10 | year={2022} 11 | } 12 | ``` 13 | 14 | 15 | ### Preparation 16 | 17 | Install PyTorch and prepare the ImageNet dataset following the [official PyTorch ImageNet training code](https://github.com/pytorch/examples/tree/master/imagenet). 18 | 19 | 20 | 21 | ### Unsupervised Training 22 | 23 | This implementation only supports **multi-gpu**, **DistributedDataParallel** training, which is faster and simpler; single-gpu or DataParallel training is not supported. 24 | 25 | To do unsupervised pre-training of a ResNet-50 model on ImageNet on an 8-gpu machine, run: 26 | ``` 27 | python main_plrc.py [your imagenet-folder with train and val folders] \ 28 | --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 29 | ``` 30 | This script uses all the default hyper-parameters as described in the PLRC paper.
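The same script also exposes the usual MoCo-style options defined in `main_plrc.py` (`--epochs`, `--batch-size`, `--lr`, `--workers`, `--resume`, ...). As an illustrative sketch rather than an official recipe (the dataset path and the checkpoint name below are placeholders; checkpoints are written by the script as `checkpoint_XXXX.pth.tar`), an interrupted run could be resumed with:
```
python main_plrc.py [your imagenet-folder with train and val folders] \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --resume checkpoint_0049.pth.tar
```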
31 | 32 | 33 | ### Models 34 | 35 | Our pre-trained ResNet-50 model and finetuned checkpoints on object detection can be downloaded as follows: 36 | 37 | 38 | | | Pretrained Model | Epoch | 39 | | ----------- | :-------------------------------------------------------------------------------------------------: | :------: | 40 | | Res50 | [download link](https://dl.fbaipublicfiles.com/plrc/pre-train/model_final.pth) | 100 | 41 | 42 | 43 | 44 | | | Finetuned Model | AP | AP50 | AP75 | 45 | | ----------- | :-------------------------------------------------------------------------------------------------: | :------: | :--------: | :--------: | 46 | | Res50 | [download link](https://dl.fbaipublicfiles.com/plrc/fine-tune/model_final.pth) | 58.2 | 82.7 | 65.1 | 47 | 48 | The APs on Pascal VOC are averaged over 5 runs. 49 | 50 | 51 | ### Detection 52 | 53 | Object detection transfer is the same as in [MoCo](https://github.com/facebookresearch/moco); please see [moco/detection](https://github.com/facebookresearch/moco/tree/master/detection). 54 | 55 | 56 | ### Visualization 57 | 58 | For model visualization, we provide a [Google Colab notebook](https://colab.research.google.com/drive/172dmSGYAzEgiMJ1RFyuStrj_YOQVvFpQ?usp=sharing) for better illustration. 59 | 60 | 61 | 62 | 63 | ### License 64 | 65 | This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details. 66 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior.
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /PLRC/models/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import math 8 | from scipy.stats import ortho_group 9 | import torch.distributed as du 10 | import sys 11 | import numpy as np 12 | from PIL import Image 13 | import random 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | import PLRC.models.plrc_loss as plrc_loss 20 | from PLRC.models.resnet import ResNetPLRC 21 | import PLRC.models.builder as builder 22 | 23 | 24 | def build_mlp1(input_dim, output_dim): 25 | 26 | return build_more_fcs( 27 | input_dim, output_dim, False, 1, 2048, "none", True, "none", True, True 28 | ) 29 | 30 | 31 | def append_fc_layers(fcs, norm_fc, use_bias, input_dim, dim_fc): 32 | 33 | fcs.append(nn.Linear(input_dim, dim_fc, bias=use_bias)) 34 | fcs.append(nn.ReLU(inplace=True)) 35 | 36 | return fcs 37 | 38 | 39 | def append_output_layer( 40 | fcs, norm_out, use_bias_out, use_weight_out, input_dim, output_dim 41 | ): 42 | 43 | fcs.append(nn.Linear(input_dim, output_dim, bias=use_bias_out)) 44 | 45 | return fcs 46 | 47 | 48 | def build_more_fcs( 49 | input_dim, 50 | output_dim, 51 | first_relu, 52 | more_fc, 53 | dim_fc, 54 | norm_fc, 55 | use_bias, 56 | norm_out, 57 | use_bias_out, 58 | use_weight_out, 59 | ): 60 | fcs = [] 61 | for _ in range(more_fc): 62 | fcs = append_fc_layers(fcs, norm_fc, use_bias, input_dim, dim_fc) 63 | input_dim = dim_fc 64 | fcs = append_output_layer( 65 | fcs, norm_out, use_bias_out, use_weight_out, input_dim, output_dim 66 | ) 67 | fcs = nn.Sequential(*fcs) 68 | 69 | return fcs 70 | 71 | 72 | def overlap(q, k, coord_q, coord_k, mask_size, pos_ratio=0.5): 73 | """q, k: N * C * H * W 74 | coord_q, coord_k: N * 4 (x_upper_left, y_upper_left, x_lower_right, y_lower_right) 75 | """ 76 | N, _, C = q.shape # -1 49 c 77 | H, W = mask_size, mask_size 78 | # [1, 7, 7] 79 | x_array = ( 80 | torch.arange(0.0, float(W), dtype=coord_q.dtype, device=coord_q.device) 81 | .view(1, 1, -1) 82 | .repeat(1, H, 1) 83 | ) 84 | y_array = ( 85 | torch.arange(0.0, float(H), dtype=coord_q.dtype, device=coord_q.device) 86 | .view(1, -1, 1) 87 | .repeat(1, 1, W) 88 | ) 89 | # [bs, 1, 1] 90 | q_bin_width = ((coord_q[:, 2] - coord_q[:, 0]) / W).view(-1, 1, 1) 91 | q_bin_height = ((coord_q[:, 3] - coord_q[:, 1]) / H).view(-1, 1, 1) 92 | k_bin_width = ((coord_k[:, 2] - coord_k[:, 0]) / W).view(-1, 1, 1) 93 | k_bin_height = ((coord_k[:, 3] - coord_k[:, 1]) / H).view(-1, 1, 1) 94 | # [bs, 1, 1] 95 | q_start_x = coord_q[:, 0].view(-1, 1, 1) 96 | q_start_y = coord_q[:, 1].view(-1, 1, 1) 97 | k_start_x = coord_k[:, 0].view(-1, 1, 1) 98 | k_start_y = coord_k[:, 1].view(-1, 1, 1) 99 | 100 | # [bs, 1, 1] 101 | q_bin_diag = torch.sqrt(q_bin_width**2 + q_bin_height**2) 102 | k_bin_diag = torch.sqrt(k_bin_width**2 + k_bin_height**2) 103 | max_bin_diag = torch.max(q_bin_diag, k_bin_diag) 104 | 105 | # [bs, 7, 7] 106 | center_q_x = (x_array + 0.5) * q_bin_width + q_start_x 107 | center_q_y = (y_array + 0.5) * q_bin_height + q_start_y 108 | center_k_x = (x_array + 0.5) * k_bin_width + k_start_x 109 | center_k_y = (y_array + 0.5) * k_bin_height + k_start_y 110 | 111 | # [bs, 49, 49] 112 | dist_center = ( 113 | torch.sqrt( 114 | (center_q_x.view(-1, H * W, 1) - center_k_x.view(-1, 1, H * W)) ** 2 115 | + (center_q_y.view(-1, H * W, 1) - center_k_y.view(-1, 1, H * W)) ** 2 116 | ) 117 | / max_bin_diag 118 | ) 119 | pos_mask = (dist_center < pos_ratio).float().detach() 120 | return pos_mask 121 | -------------------------------------------------------------------------------- /PLRC/models/resnet.py: 
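Before moving on to resnet.py, here is a minimal sketch (not part of the repository) of how the `overlap` helper defined above in `PLRC/models/utils.py` could be exercised. It assumes the full PLRC package (including `plrc_loss.py`, which is not shown in this dump) and its dependencies are importable; the crop boxes are made-up values. `overlap` only reads the shape of `q`, so dummy tensors are enough to inspect the positive-pair mask it produces.

```python
import torch

import PLRC.models.builder  # noqa: F401 -- load builder first, as main_plrc.py does
from PLRC.models.utils import overlap

N, C, mask_size = 2, 128, 7                   # two images, 7x7 feature grid
q = torch.randn(N, mask_size * mask_size, C)  # only q.shape is read inside overlap
k = torch.randn(N, mask_size * mask_size, C)

# Crop boxes in normalized image coordinates: (x_upper_left, y_upper_left, x_lower_right, y_lower_right).
coord_q = torch.tensor([[0.10, 0.10, 0.60, 0.60],
                        [0.00, 0.00, 1.00, 1.00]])
coord_k = torch.tensor([[0.30, 0.30, 0.90, 0.90],
                        [0.25, 0.25, 0.75, 0.75]])

# pos_mask[b, i, j] == 1 when grid cell i of the first crop and cell j of the second
# crop land on roughly the same image location (center distance < pos_ratio * bin diagonal).
pos_mask = overlap(q, k, coord_q, coord_k, mask_size, pos_ratio=0.5)
print(pos_mask.shape)            # torch.Size([2, 49, 49])
print(pos_mask.sum(dim=(1, 2)))  # number of positive cell pairs per image
```

During training the same mask would be computed from the two crop boxes returned by the dataset, so only spatially overlapping locations of the two views are treated as positives.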
-------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from __future__ import print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import math 13 | 14 | import PLRC.models.builder as builder 15 | 16 | import numpy as np 17 | 18 | 19 | __all__ = ["ResNet", "resnet50"] 20 | 21 | 22 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 23 | def norm_cdf(x): 24 | return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 25 | 26 | if (mean < a - 2 * std) or (mean > b + 2 * std): 27 | warnings.warn( 28 | "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 29 | "The distribution of values may be incorrect.", 30 | stacklevel=2, 31 | ) 32 | 33 | with torch.no_grad(): 34 | l = norm_cdf((a - mean) / std) 35 | u = norm_cdf((b - mean) / std) 36 | tensor.uniform_(2 * l - 1, 2 * u - 1) 37 | tensor.erfinv_() 38 | tensor.mul_(std * math.sqrt(2.0)) 39 | tensor.add_(mean) 40 | tensor.clamp_(min=a, max=b) 41 | return tensor 42 | 43 | 44 | def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): 45 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 46 | 47 | 48 | def conv3x3(in_planes, out_planes, stride=1): 49 | return nn.Conv2d( 50 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False 51 | ) 52 | 53 | 54 | class ProjHead(nn.Module): 55 | def __init__( 56 | self, 57 | in_dim, 58 | out_dim=4096, 59 | use_bn=True, 60 | norm_last_layer=True, 61 | nlayers=3, 62 | hidden_dim=2048, 63 | bottleneck_dim=256, 64 | ): 65 | super().__init__() 66 | nlayers = max(nlayers, 1) 67 | if nlayers == 1: 68 | self.mlp = nn.Linear(in_dim, bottleneck_dim) 69 | else: 70 | layers = [nn.Linear(in_dim, hidden_dim)] 71 | if use_bn: 72 | layers.append(nn.BatchNorm1d(hidden_dim)) 73 | layers.append(nn.ReLU()) 74 | for _ in range(nlayers - 2): 75 | layers.append(nn.Linear(hidden_dim, hidden_dim)) 76 | if use_bn: 77 | layers.append(nn.BatchNorm1d(hidden_dim)) 78 | layers.append(nn.ReLU()) 79 | layers.append(nn.Linear(hidden_dim, bottleneck_dim)) 80 | self.mlp = nn.Sequential(*layers) 81 | self.apply(self._init_weights) 82 | 83 | self.last_layer = nn.utils.weight_norm( 84 | nn.Linear(bottleneck_dim, out_dim, bias=False) 85 | ) 86 | 87 | self.last_layer.weight_g.data.fill_(1) 88 | if norm_last_layer: 89 | self.last_layer.weight_g.requires_grad = False 90 | 91 | def _init_weights(self, m): 92 | if isinstance(m, nn.Linear): 93 | trunc_normal_(m.weight, std=0.02) 94 | if isinstance(m, nn.Linear) and m.bias is not None: 95 | nn.init.constant_(m.bias, 0) 96 | 97 | def forward(self, x, dy=False): 98 | x = self.mlp(x) 99 | if dy == True: 100 | x_f = nn.functional.normalize(x, dim=-1, p=1) 101 | else: 102 | x_f = nn.functional.normalize(x, dim=-1, p=2) 103 | 104 | x_l = self.last_layer(x_f) 105 | return x_f, x_l 106 | 107 | 108 | class BasicBlock(nn.Module): 109 | expansion = 1 110 | 111 | def __init__(self, inplanes, planes, stride=1, downsample=None): 112 | super(BasicBlock, self).__init__() 113 | self.conv1 = conv3x3(inplanes, planes, stride) 114 | self.bn1 = nn.BatchNorm2d(planes) 115 | self.relu = nn.ReLU(inplace=True) 116 | self.conv2 = conv3x3(planes, planes) 117 | self.bn2 = nn.BatchNorm2d(planes) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x): 122 | residual = 
x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | 131 | if self.downsample is not None: 132 | residual = self.downsample(x) 133 | 134 | out += residual 135 | out = self.relu(out) 136 | 137 | return out 138 | 139 | 140 | class Bottleneck(nn.Module): 141 | expansion = 4 142 | 143 | def __init__(self, inplanes, planes, stride=1, downsample=None): 144 | super(Bottleneck, self).__init__() 145 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 146 | self.bn1 = nn.BatchNorm2d(planes) 147 | self.conv2 = nn.Conv2d( 148 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=False 149 | ) 150 | self.bn2 = nn.BatchNorm2d(planes) 151 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 152 | self.bn3 = nn.BatchNorm2d(planes * 4) 153 | self.relu = nn.ReLU(inplace=True) 154 | self.downsample = downsample 155 | self.stride = stride 156 | 157 | def forward(self, x): 158 | residual = x 159 | 160 | out = self.conv1(x) 161 | out = self.bn1(out) 162 | out = self.relu(out) 163 | 164 | out = self.conv2(out) 165 | out = self.bn2(out) 166 | out = self.relu(out) 167 | 168 | out = self.conv3(out) 169 | out = self.bn3(out) 170 | 171 | if self.downsample is not None: 172 | residual = self.downsample(x) 173 | 174 | out += residual 175 | out = self.relu(out) 176 | 177 | return out 178 | 179 | 180 | class ResNet(nn.Module): 181 | def __init__(self, block, layers, low_dim=128, in_channel=3, width_scale=1.0): 182 | self.inplanes = 64 183 | super(ResNet, self).__init__() 184 | self.conv1 = nn.Conv2d( 185 | in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False 186 | ) 187 | self.bn1 = nn.BatchNorm2d(64) 188 | self.relu = nn.ReLU(inplace=True) 189 | self.base = int(64 * width_scale) 190 | 191 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 192 | self.layer1 = self._make_layer(block, self.base, layers[0]) 193 | self.layer2 = self._make_layer(block, self.base * 2, layers[1], stride=2) 194 | self.layer3 = self._make_layer(block, self.base * 4, layers[2], stride=2) 195 | self.layer4 = self._make_layer(block, self.base * 8, layers[3], stride=2) 196 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 197 | current_dim = self.base * 8 * block.expansion 198 | 199 | last_bn_name = "bn2" if isinstance(block, BasicBlock) else "bn3" 200 | for name, m in self.named_modules(): 201 | if isinstance(m, nn.Conv2d): 202 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 203 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 204 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 205 | if m.weight is not None: 206 | if name.startswith("layer") and name.endswith(last_bn_name): 207 | m.weight.data.zero_() 208 | else: 209 | m.weight.data.fill_(1) 210 | if m.bias is not None: 211 | m.bias.data.zero_() 212 | 213 | self.fcs = builder.build_mlp1(current_dim, low_dim) 214 | self.fcs_2 = ProjHead(current_dim, 4096) 215 | 216 | def _make_layer(self, block, planes, blocks, stride=1): 217 | downsample = None 218 | if stride != 1 or self.inplanes != planes * block.expansion: 219 | downsample = nn.Sequential( 220 | nn.Conv2d( 221 | self.inplanes, 222 | planes * block.expansion, 223 | kernel_size=1, 224 | stride=stride, 225 | bias=False, 226 | ), 227 | nn.BatchNorm2d(planes * block.expansion), 228 | ) 229 | 230 | layers = [] 231 | layers.append(block(self.inplanes, planes, stride, downsample)) 232 | self.inplanes = planes * block.expansion 233 | for _ in range(1, blocks): 234 | 
layers.append(block(self.inplanes, planes)) 235 | 236 | return nn.Sequential(*layers) 237 | 238 | def forward(self, x, mode="image", dy=False): 239 | x = self.conv1(x) 240 | x = self.bn1(x) 241 | x = self.relu(x) 242 | x = self.maxpool(x) 243 | x = self.layer1(x) 244 | x = self.layer2(x) 245 | x = self.layer3(x) 246 | x = self.layer4(x) 247 | if mode == "point": 248 | B, C, H, W = x.shape 249 | feat = x 250 | feat = feat.view(feat.size(0), -1) 251 | x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) 252 | y_1 = self.fcs(x).reshape(B, H, W, -1).permute(0, 3, 1, 2) 253 | y_2_f, y_2_l = self.fcs_2(x) 254 | y_2_f = y_2_f.reshape(B, H, W, -1).permute(0, 3, 1, 2) 255 | y_2_l = y_2_l.reshape(B, H, W, -1).permute(0, 3, 1, 2) 256 | 257 | elif mode == "image": 258 | feat = self.avgpool(x) 259 | feat = feat.view(feat.size(0), -1) 260 | x = self.fcs(feat) 261 | 262 | return y_1, y_2_f, y_2_l 263 | 264 | 265 | def resnet50(**kwargs): 266 | """Constructs a ResNet-50 model.""" 267 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 268 | return model 269 | 270 | 271 | class ResNetPLRC(nn.Module): 272 | def __init__(self, name="resnet50", feat_dim=128): 273 | super(ResNetPLRC, self).__init__() 274 | width_scale = 1 275 | feat_dim = 128 276 | self.name = name 277 | self.enc = resnet50(in_channel=3, low_dim=feat_dim, width_scale=width_scale) 278 | 279 | def forward(self, x, mode="image", dy=False): 280 | feat, med_feat, feat_2 = self.enc(x, mode=mode, dy=dy) 281 | return feat, med_feat, feat_2, None 282 | -------------------------------------------------------------------------------- /PLRC/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
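Before the image transforms below, a short sketch (again not part of the repository) of what the `ResNetPLRC` encoder defined above returns in `"point"` mode. It assumes the full PLRC package and its dependencies are importable and that a 224x224 input is used, so the backbone grid is 7x7; the batch size and the `eval()` call are only for illustration.

```python
import torch

import PLRC.models.builder  # noqa: F401 -- load builder first so the cross-module imports resolve
from PLRC.models.resnet import ResNetPLRC

encoder = ResNetPLRC()           # ResNet-50 backbone plus the two projection heads (fcs, fcs_2)
encoder.eval()

x = torch.randn(2, 3, 224, 224)  # two augmented crops
with torch.no_grad():
    feat, med_feat, feat_2, _ = encoder(x, mode="point")

print(feat.shape)      # torch.Size([2, 128, 7, 7])  - per-location embeddings from the MLP head
print(med_feat.shape)  # torch.Size([2, 256, 7, 7])  - normalized bottleneck features from ProjHead
print(feat_2.shape)    # torch.Size([2, 4096, 7, 7]) - ProjHead output logits
```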
6 | 7 | """Image transformations.""" 8 | 9 | import cv2 10 | from PIL import Image, ImageFilter, ImageOps 11 | import math 12 | import numpy as np 13 | import random 14 | import torchvision.transforms.functional as tf 15 | import torch 16 | 17 | 18 | _pil_interpolation_to_str = { 19 | Image.NEAREST: "PIL.Image.NEAREST", 20 | Image.BILINEAR: "PIL.Image.BILINEAR", 21 | Image.BICUBIC: "PIL.Image.BICUBIC", 22 | Image.LANCZOS: "PIL.Image.LANCZOS", 23 | Image.HAMMING: "PIL.Image.HAMMING", 24 | Image.BOX: "PIL.Image.BOX", 25 | } 26 | 27 | 28 | def CHW2HWC(image): 29 | return image.transpose([1, 2, 0]) 30 | 31 | 32 | def HWC2CHW(image): 33 | return image.transpose([2, 0, 1]) 34 | 35 | 36 | def color_normalization(image, mean, std): 37 | """Expects image in CHW format.""" 38 | assert len(mean) == image.shape[0] 39 | assert len(std) == image.shape[0] 40 | for i in range(image.shape[0]): 41 | image[i] = image[i] - mean[i] 42 | image[i] = image[i] / std[i] 43 | return image 44 | 45 | 46 | def zero_pad(image, pad_size, order="CHW"): 47 | assert order in ["CHW", "HWC"] 48 | if order == "CHW": 49 | pad_width = ((0, 0), (pad_size, pad_size), (pad_size, pad_size)) 50 | else: 51 | pad_width = ((pad_size, pad_size), (pad_size, pad_size), (0, 0)) 52 | return np.pad(image, pad_width, mode="constant") 53 | 54 | 55 | def horizontal_flip(image, prob, order="CHW"): 56 | assert order in ["CHW", "HWC"] 57 | if np.random.uniform() < prob: 58 | if order == "CHW": 59 | image = image[:, :, ::-1] 60 | else: 61 | image = image[:, ::-1, :] 62 | return image 63 | 64 | 65 | def random_crop(image, size): 66 | if image.shape[0] == size and image.shape[1] == size: 67 | return image 68 | height = image.shape[0] 69 | width = image.shape[1] 70 | y_offset = 0 71 | if height > size: 72 | y_offset = int(np.random.randint(0, height - size)) 73 | x_offset = 0 74 | if width > size: 75 | x_offset = int(np.random.randint(0, width - size)) 76 | cropped = image[y_offset : y_offset + size, x_offset : x_offset + size, :] 77 | assert cropped.shape[0] == size, "Image not cropped properly" 78 | assert cropped.shape[1] == size, "Image not cropped properly" 79 | return cropped 80 | 81 | 82 | def scale(size, image): 83 | # TODO(ilijar): Refactor 84 | height = image.shape[0] 85 | width = image.shape[1] 86 | if (width <= height and width == size) or (height <= width and height == size): 87 | return image 88 | new_width = size 89 | new_height = size 90 | if width < height: 91 | new_height = int(math.floor((float(height) / width) * size)) 92 | else: 93 | new_width = int(math.floor((float(width) / height) * size)) 94 | img = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR) 95 | return img.astype(np.float32) 96 | 97 | 98 | def center_crop(size, image): 99 | height = image.shape[0] 100 | width = image.shape[1] 101 | y_offset = int(math.ceil((height - size) / 2)) 102 | x_offset = int(math.ceil((width - size) / 2)) 103 | cropped = image[y_offset : y_offset + size, x_offset : x_offset + size, :] 104 | assert cropped.shape[0] == size, "Image height not cropped properly" 105 | assert cropped.shape[1] == size, "Image width not cropped properly" 106 | return cropped 107 | 108 | 109 | def random_sized_crop(image, size, area_frac=0.08): 110 | for _ in range(0, 10): 111 | height = image.shape[0] 112 | width = image.shape[1] 113 | area = height * width 114 | target_area = np.random.uniform(area_frac, 1.0) * area 115 | aspect_ratio = np.random.uniform(3.0 / 4.0, 4.0 / 3.0) 116 | w = int(round(math.sqrt(float(target_area) * aspect_ratio))) 
117 | h = int(round(math.sqrt(float(target_area) / aspect_ratio))) 118 | if np.random.uniform() < 0.5: 119 | w, h = h, w 120 | if h <= height and w <= width: 121 | if height == h: 122 | y_offset = 0 123 | else: 124 | y_offset = np.random.randint(0, height - h) 125 | if width == w: 126 | x_offset = 0 127 | else: 128 | x_offset = np.random.randint(0, width - w) 129 | y_offset = int(y_offset) 130 | x_offset = int(x_offset) 131 | cropped = image[y_offset : y_offset + h, x_offset : x_offset + w, :] 132 | assert cropped.shape[0] == h and cropped.shape[1] == w, "Wrong crop size" 133 | cropped = cv2.resize(cropped, (size, size), interpolation=cv2.INTER_LINEAR) 134 | return cropped.astype(np.float32) 135 | return center_crop(size, scale(size, image)) 136 | 137 | 138 | def lighting(img, alphastd, eigval, eigvec): 139 | # TODO(ilijar): Refactor 140 | if alphastd == 0: 141 | return img 142 | # generate alpha1, alpha2, alpha3 143 | alpha = np.random.normal(0, alphastd, size=(1, 3)) 144 | eig_vec = np.array(eigvec) 145 | eig_val = np.reshape(eigval, (1, 3)) 146 | rgb = np.sum( 147 | eig_vec * np.repeat(alpha, 3, axis=0) * np.repeat(eig_val, 3, axis=0), axis=1 148 | ) 149 | for idx in range(img.shape[0]): 150 | img[idx] = img[idx] + rgb[2 - idx] 151 | return img 152 | 153 | 154 | class GaussianBlurSimple(object): 155 | """Gaussian blur augmentation from SimCLR https://arxiv.org/abs/2002.05709""" 156 | 157 | def __init__(self, sigma=[0.1, 2.0]): 158 | self.sigma = sigma 159 | 160 | def __call__(self, x): 161 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 162 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 163 | return x 164 | 165 | 166 | def _get_image_size(img): 167 | if tf._is_pil_image(img): 168 | return img.size 169 | elif isinstance(img, torch.Tensor) and img.dim() > 2: 170 | return img.shape[-2:][::-1] 171 | else: 172 | raise TypeError("Unexpected type {}".format(type(img))) 173 | 174 | 175 | class RandomResizedCrop(object): 176 | """Crop the given PIL Image to random size and aspect ratio. 177 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 178 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 179 | is finally resized to given size. 180 | This is popularly used to train the Inception networks. 181 | Args: 182 | size: expected output size of each edge 183 | scale: range of size of the origin size cropped 184 | ratio: range of aspect ratio of the origin aspect ratio cropped 185 | interpolation: Default: PIL.Image.BILINEAR 186 | """ 187 | 188 | def __init__( 189 | self, 190 | size, 191 | scale=(0.08, 1.0), 192 | ratio=(3.0 / 4.0, 4.0 / 3.0), 193 | interpolation=Image.BILINEAR, 194 | ): 195 | if isinstance(size, (tuple, list)): 196 | self.size = size 197 | else: 198 | self.size = (size, size) 199 | 200 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 201 | warnings.warn("range should be of kind (min, max)") 202 | 203 | self.interpolation = interpolation 204 | self.scale = scale 205 | self.ratio = ratio 206 | 207 | @staticmethod 208 | def get_params(img, scale, ratio): 209 | """Get parameters for ``crop`` for a random sized crop. 210 | Args: 211 | img (PIL Image): Image to be cropped. 212 | scale (tuple): range of size of the origin size cropped 213 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 214 | Returns: 215 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 216 | sized crop. 
217 | """ 218 | width, height = _get_image_size(img) 219 | area = height * width 220 | 221 | for attempt in range(10): 222 | target_area = random.uniform(*scale) * area 223 | log_ratio = (math.log(ratio[0]), math.log(ratio[1])) 224 | aspect_ratio = math.exp(random.uniform(*log_ratio)) 225 | 226 | w = int(round(math.sqrt(target_area * aspect_ratio))) 227 | h = int(round(math.sqrt(target_area / aspect_ratio))) 228 | 229 | if 0 < w <= width and 0 < h <= height: 230 | i = random.randint(0, height - h) 231 | j = random.randint(0, width - w) 232 | return i, j, h, w, height, width 233 | 234 | # Fallback to central crop 235 | in_ratio = float(width) / float(height) 236 | if in_ratio < min(ratio): 237 | w = width 238 | h = int(round(w / min(ratio))) 239 | elif in_ratio > max(ratio): 240 | h = height 241 | w = int(round(h * max(ratio))) 242 | else: # whole image 243 | w = width 244 | h = height 245 | i = (height - h) // 2 246 | j = (width - w) // 2 247 | return i, j, h, w, height, width 248 | 249 | def __call__(self, img): 250 | """ 251 | Args: 252 | img (PIL Image): Image to be cropped and resized. 253 | Returns: 254 | PIL Image: Randomly cropped and resized image. 255 | """ 256 | i, j, h, w, height, width = self.get_params(img, self.scale, self.ratio) 257 | coord = torch.Tensor( 258 | [ 259 | float(j) / (width - 1), 260 | float(i) / (height - 1), 261 | float(j + w - 1) / (width - 1), 262 | float(i + h - 1) / (height - 1), 263 | ] 264 | ) 265 | return tf.resized_crop(img, i, j, h, w, self.size, self.interpolation), coord 266 | 267 | def resized_crop(self, img, i, j, h, w, height, width): 268 | 269 | coord = torch.Tensor( 270 | [ 271 | float(j) / (width - 1), 272 | float(i) / (height - 1), 273 | float(j + w - 1) / (width - 1), 274 | float(i + h - 1) / (height - 1), 275 | ] 276 | ) 277 | return tf.resized_crop(img, i, j, h, w, self.size, self.interpolation), coord 278 | 279 | def __repr__(self): 280 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 281 | format_string = self.__class__.__name__ + "(size={0}".format(self.size) 282 | format_string += ", scale={0}".format(tuple(round(s, 4) for s in self.scale)) 283 | format_string += ", ratio={0}".format(tuple(round(r, 4) for r in self.ratio)) 284 | format_string += ", interpolation={0})".format(interpolate_str) 285 | return format_string 286 | 287 | 288 | class RandomHorizontalFlip(object): 289 | """Horizontally flip the given PIL Image randomly with a given probability. 290 | Args: 291 | p (float): probability of the image being flipped. Default value is 0.5 292 | """ 293 | 294 | def __init__(self, p=0.5): 295 | self.p = p 296 | 297 | def __call__(self, img, coord): 298 | """ 299 | Args: 300 | img (PIL Image): Image to be flipped. 301 | Returns: 302 | PIL Image: Randomly flipped image. 303 | """ 304 | if random.random() < self.p: 305 | coord_new = coord.clone() 306 | coord_new[0] = coord[2] 307 | coord_new[2] = coord[0] 308 | return tf.hflip(img), coord_new 309 | return img, coord 310 | 311 | def __repr__(self): 312 | return self.__class__.__name__ + "(p={})".format(self.p) 313 | -------------------------------------------------------------------------------- /PLRC/datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
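Before the ImageNet dataset module below, a minimal sketch (not part of the repository) of the coordinate-returning transforms defined above in `PLRC/datasets/transforms.py`. It assumes the repository's dependencies (torchvision, Pillow, OpenCV) are installed; the gray test image and the printed coordinate values are illustrative only. Each crop is returned together with its box in normalized [0, 1] image coordinates, which is what the positive-pair matching in `PLRC/models/utils.overlap` consumes.

```python
from PIL import Image

import PLRC.datasets.transforms as T

img = Image.new("RGB", (640, 480), color=(128, 128, 128))  # dummy 640x480 image

rrc = T.RandomResizedCrop(224, scale=(0.2, 1.0))
hflip = T.RandomHorizontalFlip(p=0.5)

crop, coord = rrc(img)            # coord = (x_ul, y_ul, x_lr, y_lr), each in [0, 1]
crop, coord = hflip(crop, coord)  # a flip swaps the two x entries of the box

print(crop.size)  # (224, 224)
print(coord)      # e.g. tensor([0.1520, 0.3028, 0.7489, 0.8643])
```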
6 | 7 | """ImageNet dataset.""" 8 | 9 | import cv2 10 | import numpy as np 11 | import os 12 | 13 | import torch 14 | import torch.utils.data 15 | import PLRC.datasets.transforms as transforms 16 | import torchvision as tv 17 | import torchvision.transforms.functional as tvf 18 | import io 19 | from PIL import Image 20 | from shapely.geometry import Polygon 21 | import json 22 | 23 | # Data location 24 | 25 | rng = np.random.default_rng() 26 | 27 | 28 | class ImageNet(torch.utils.data.Dataset): 29 | def __init__(self, split, args, path=""): 30 | self.path = path 31 | self._split = split 32 | self._first_k = 0 33 | self.random_resizedcrop = None 34 | self._construct_imdb() 35 | 36 | self.mask_size = 56 37 | self.mask_neg_num = 15 38 | self.mask_pos_num = 1 39 | self.mask_grid_num = 4 40 | self.mask_area_avgnum = 32 41 | self.im_size = 224 42 | 43 | normalize = tv.transforms.Normalize( 44 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 45 | ) 46 | mask_np = np.zeros((self.im_size, self.im_size)) 47 | for hh in range(self.mask_grid_num): 48 | for ww in range(self.mask_grid_num): 49 | start_h, end_h = ( 50 | hh * (self.im_size // self.mask_grid_num), 51 | (hh + 1) * (self.im_size // self.mask_grid_num) 52 | if hh != self.mask_grid_num - 1 53 | else self.im_size, 54 | ) 55 | start_w, end_w = ( 56 | ww * (self.im_size // self.mask_grid_num), 57 | (ww + 1) * (self.im_size // self.mask_grid_num) 58 | if ww != self.mask_grid_num - 1 59 | else self.im_size, 60 | ) 61 | mask_np[start_h:end_h, start_w:end_w] = hh * self.mask_grid_num + ww + 1 62 | self.pano_mask = Image.fromarray(np.uint8(mask_np), "L") # 63 | 64 | self.random_resizedcrop = transforms.RandomResizedCropCoord( 65 | self.im_size, scale=(0.2, 1.0) 66 | ) 67 | 68 | self.random_resizedcrop_mask = transforms.RandomResizedCropCoord( 69 | self.mask_size, scale=(0.2, 1.0), interpolation=Image.NEAREST 70 | ) 71 | self.randomhflip = transforms.RandomHorizontalFlipCoord() 72 | 73 | color_jitter = tv.transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) 74 | rnd_color_jitter = tv.transforms.RandomApply([color_jitter], p=0.8) 75 | self._transform = tv.transforms.Compose( 76 | [ 77 | rnd_color_jitter, 78 | tv.transforms.RandomGrayscale(p=0.2), 79 | tv.transforms.RandomApply( 80 | [transforms.GaussianBlurSimple([0.1, 2.0])], p=0.5 81 | ), 82 | tv.transforms.ToTensor(), 83 | normalize, 84 | ] 85 | ) 86 | self._transform_mask = tv.transforms.Compose( 87 | [ 88 | tv.transforms.ToTensor(), 89 | ] 90 | ) 91 | 92 | def _apply_single_transformation(self, n, im): 93 | if n % 2 == 1 and hasattr(self, "_transform_prime"): 94 | return self._transform_prime(im) 95 | else: 96 | return self._transform(im) 97 | 98 | def _construct_imdb(self): 99 | """Constructs the imdb.""" 100 | data_dir = os.path.join(self.path, self._split) 101 | assert os.path.exists(data_dir), "{} dir not found".format(data_dir) 102 | 103 | # Map ImageNet class ids to contiguous ids 104 | self._class_ids = os.listdir(data_dir) 105 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)} 106 | 107 | # Construct the image db 108 | self._imdb = [] 109 | for class_id in self._class_ids: 110 | cont_id = self._class_id_cont_id[class_id] 111 | im_dir = os.path.join(data_dir, class_id) 112 | for ii, im_name in enumerate(sorted(os.listdir(im_dir))): 113 | if self._first_k and ii >= self._first_k: 114 | break 115 | 116 | self._imdb.append( 117 | { 118 | "im_path": os.path.join(im_dir, im_name), 119 | "class": cont_id, 120 | } 121 | ) 122 | 123 | self.num_classes = len(self._imdb) 124 | 125 | 
def __getitem__(self, index): 126 | if isinstance(index, tuple): 127 | index, scales, repeats = index 128 | else: 129 | scales = [None, None] 130 | repeats = [1] 131 | flag = 1 132 | anno_mask = True 133 | im = tv.datasets.folder.default_loader(self._imdb[index]["im_path"]) 134 | 135 | width, height = im.size 136 | im_name = self._imdb[index]["im_path"].split("/")[-1].split(".")[0] 137 | pano_mask = self.pano_mask.resize((width, height), resample=Image.NEAREST) 138 | 139 | im_multiv = [] 140 | im_multiv_2 = [] 141 | mask_multiv = [] 142 | 143 | obj_list_multiv = [] 144 | 145 | coord_multiv = [] 146 | coord_multiv_2 = [] 147 | 148 | # augmentations 149 | for n, s in enumerate(scales): 150 | if s is not None: 151 | self._set_crop_size(s) 152 | 153 | iii = 0 154 | while True: 155 | im_ = im 156 | im_2 = im 157 | pano_mask_ = pano_mask 158 | 159 | i, j, h, w, height, width = self.random_resizedcrop.get_params( 160 | im_, self.random_resizedcrop.scale, self.random_resizedcrop.ratio 161 | ) 162 | im_, coord = self.random_resizedcrop.resized_crop( 163 | im_, i, j, h, w, height, width 164 | ) 165 | 166 | pano_mask_, _ = self.random_resizedcrop_mask.resized_crop( 167 | pano_mask_, i, j, h, w, height, width 168 | ) 169 | 170 | ( 171 | i_2, 172 | j_2, 173 | h_2, 174 | w_2, 175 | height_2, 176 | width_2, 177 | ) = self.random_resizedcrop.get_params( 178 | im_2, self.random_resizedcrop.scale, self.random_resizedcrop.ratio 179 | ) 180 | im_2, coord_2 = self.random_resizedcrop.resized_crop( 181 | im_2, i_2, j_2, h_2, w_2, height_2, width_2 182 | ) 183 | 184 | polygon = Polygon([(i, j), (i + h, j), (i + h, j + w), (i, j + w)]) 185 | other_polygon = Polygon( 186 | [ 187 | (i_2, j_2), 188 | (i_2 + h_2, j_2), 189 | (i_2 + h_2, j_2 + w_2), 190 | (i_2, j_2 + w_2), 191 | ] 192 | ) 193 | intersection = polygon.intersection(other_polygon) 194 | 195 | iii += 1 196 | 197 | if intersection.area / max(h * w, h_2 * w_2) >= 0.5 or iii > 100: 198 | break 199 | 200 | if torch.rand(1) < 0.5: 201 | im_, coord = self.randomhflip(im_, coord) 202 | pano_mask_, _ = self.randomhflip(pano_mask_, coord) 203 | 204 | if torch.rand(1) < 0.5: 205 | im_2, coord_2 = self.randomhflip(im_2, coord_2) 206 | 207 | coord_multiv.append(coord) 208 | coord_multiv_2.append(coord_2) 209 | 210 | pano_mask_np_ = np.array(pano_mask_) 211 | 212 | obj_list_ = np.unique(pano_mask_np_) 213 | 214 | str_ = str(len(obj_list_)) 215 | 216 | if len(obj_list_) <= 1: 217 | flag = -1 218 | 219 | obj_list_multiv.append(obj_list_) 220 | mask_multiv.append(pano_mask_np_) 221 | 222 | im_multiv.append(np.array(self._apply_single_transformation(n, im_))) 223 | im_multiv_2.append(np.array(self._apply_single_transformation(n, im_2))) 224 | 225 | if flag == 1: 226 | common_objects = list( 227 | set(obj_list_multiv[0]).intersection(set(obj_list_multiv[1])) 228 | ) 229 | if len(common_objects) == 0: 230 | flag = -1 231 | else: 232 | obj = common_objects[np.random.randint(0, len(common_objects))] 233 | 234 | multiple_mask_multiv = [] 235 | 236 | for n, s in enumerate(scales): 237 | if flag == 1: 238 | masks_list = [] 239 | obj_list_ = obj_list_multiv[n] 240 | pano_mask_np_ = mask_multiv[n] 241 | 242 | bg_list = obj_list_[obj_list_ != obj] 243 | mask_np_pos_ = (pano_mask_np_ == obj).astype(np.int32) 244 | 245 | xs, ys = np.nonzero(mask_np_pos_) 246 | tmp_points = np.stack((xs, ys), axis=-1).astype(np.int32) 247 | mask_np_pos_ = rng.choice( 248 | tmp_points, self.mask_pos_num * self.mask_area_avgnum 249 | ) 250 | 251 | mask_np_pos_ = mask_np_pos_.reshape( 252 | 
self.mask_pos_num, self.mask_area_avgnum, 2 253 | ) 254 | 255 | masks_list.append(mask_np_pos_) 256 | 257 | for time_ in range(self.mask_neg_num): 258 | obj_neg = bg_list[np.random.randint(0, len(bg_list))] 259 | mask_np_neg = (pano_mask_np_ == obj_neg).astype(np.int32) 260 | xs, ys = np.nonzero(mask_np_neg) 261 | tmp_points = np.stack((xs, ys), axis=-1).astype(np.int32) 262 | mask_np_neg = rng.choice( 263 | tmp_points, self.mask_pos_num * self.mask_area_avgnum 264 | ) # 1+Negative, Avg_num, 2 265 | 266 | mask_np_neg = mask_np_neg.reshape( 267 | self.mask_pos_num, self.mask_area_avgnum, 2 268 | ) 269 | 270 | masks_list.append(mask_np_neg) 271 | 272 | else: 273 | masks_list = [] 274 | obj_list_ = obj_list_multiv[n] 275 | pano_mask_np_ = mask_multiv[n] 276 | 277 | for time_ in range(1 + self.mask_neg_num): 278 | obj = obj_list_[np.random.randint(0, len(obj_list_))] 279 | mask_np_ = (pano_mask_np_ == obj).astype(np.int32) 280 | 281 | xs, ys = np.nonzero(mask_np_) 282 | tmp_points = np.stack((xs, ys), axis=-1).astype(np.int32) 283 | mask_np_ = rng.choice( 284 | tmp_points, self.mask_pos_num * self.mask_area_avgnum 285 | ) 286 | mask_np_ = mask_np_.reshape( 287 | self.mask_pos_num, self.mask_area_avgnum, 2 288 | ) 289 | masks_list.append(mask_np_) 290 | 291 | masks_list = np.stack(masks_list, axis=0) 292 | masks_list = masks_list.reshape( 293 | (1 + self.mask_neg_num) * self.mask_pos_num, self.mask_area_avgnum, 2 294 | ) # Negative, Avg_num, 2 -- N, POS_NUM, AVG_NUM, 2 295 | multiple_mask_multiv.append(masks_list) 296 | 297 | cls_labels = self._imdb[index]["class"] 298 | return ( 299 | im_multiv, 300 | index, 301 | cls_labels, 302 | multiple_mask_multiv, 303 | flag, 304 | im_multiv_2, 305 | coord_multiv, 306 | coord_multiv_2, 307 | ) 308 | 309 | def __len__(self): 310 | return len(self._imdb) 311 | -------------------------------------------------------------------------------- /PLRC/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
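Before the builder module below, a simplified, self-contained sketch (not the dataset class itself) of two ingredients used by the `ImageNet` dataset above: the fixed 4x4 "panoptic" grid mask whose cells act as pseudo regions, and the sampling of a fixed number of points from one region. The grid size, image size, and point count mirror the class defaults (4, 224, 32); everything else is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)

im_size, grid, n_points = 224, 4, 32

# (1) Label every pixel with its grid-cell id in 1..16, analogous to self.pano_mask.
cell = im_size // grid
ys, xs = np.meshgrid(np.arange(im_size), np.arange(im_size), indexing="ij")
pano_mask = (ys // cell).clip(max=grid - 1) * grid + (xs // cell).clip(max=grid - 1) + 1

# (2) Pick one region id and sample n_points pixel coordinates from it, the same way
# __getitem__ builds each (mask_pos_num, mask_area_avgnum, 2) point set for a region.
region_id = 6
points = np.stack(np.nonzero(pano_mask == region_id), axis=-1)  # (num_pixels, 2), rows are (row, col)
sampled = rng.choice(points, n_points)                          # (32, 2) random points inside the region

print(np.unique(pano_mask))  # the 16 region ids 1..16
print(sampled.shape)         # (32, 2)
```

In the real dataset, 15 additional "negative" regions are sampled the same way (`mask_neg_num = 15`) and stacked into a (16, 32, 2) array per view before being returned to the training loop.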
6 | 7 | import torch 8 | import torch.nn as nn 9 | import PLRC.models.plrc_loss as plrc_loss 10 | from PLRC.models.resnet import ResNetPLRC 11 | from PLRC.models.utils import * 12 | import torch.nn.functional as F 13 | 14 | 15 | class PLRC(nn.Module): 16 | def __init__(self, args, dim=128, K=65536, m=0.999, T=0.07, mlp=False): 17 | 18 | super(PLRC, self).__init__() 19 | 20 | self.K = K 21 | self.m = m 22 | self.T = T 23 | 24 | self.encoder_q = ResNetPLRC() 25 | self.encoder_k = ResNetPLRC() 26 | 27 | for param_q, param_k in zip( 28 | self.encoder_q.parameters(), self.encoder_k.parameters() 29 | ): 30 | param_k.data.copy_(param_q.data) 31 | param_k.requires_grad = False 32 | self.plrc_loss = plrc_loss.SupConLoss(args) 33 | self.plrc_loss_image = plrc_loss.NCEAverage(args) 34 | 35 | self.mask_size = 56 36 | self.mask_neg_num = 15 37 | self.mask_pos_num = 1 38 | self.mask_grid_num = 4 39 | self.mask_area_avgnum = 32 40 | 41 | self.keep_negs = True 42 | self.logis_sum = "exp_logits" 43 | self.logis_avg = False 44 | self.scale = False 45 | self.keep_point = 16 46 | 47 | @torch.no_grad() 48 | def _momentum_update_key_encoder(self): 49 | """ 50 | Momentum update of the key encoder 51 | """ 52 | for param_q, param_k in zip( 53 | self.encoder_q.parameters(), self.encoder_k.parameters() 54 | ): 55 | param_k.data = param_k.data * self.m + param_q.data * (1.0 - self.m) 56 | 57 | @torch.no_grad() 58 | def _batch_shuffle_ddp(self, x): 59 | """ 60 | Batch shuffle, for making use of BatchNorm. 61 | *** Only support DistributedDataParallel (DDP) model. *** 62 | """ 63 | # gather from all gpus 64 | batch_size_this = x.shape[0] 65 | x_gather = concat_all_gather(x) 66 | batch_size_all = x_gather.shape[0] 67 | 68 | num_gpus = batch_size_all // batch_size_this 69 | 70 | # random shuffle index 71 | idx_shuffle = torch.randperm(batch_size_all).cuda() 72 | 73 | # broadcast to all gpus 74 | torch.distributed.broadcast(idx_shuffle, src=0) 75 | 76 | # index for restoring 77 | idx_unshuffle = torch.argsort(idx_shuffle) 78 | 79 | # shuffled index for this gpu 80 | gpu_idx = torch.distributed.get_rank() 81 | idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx] 82 | 83 | return x_gather[idx_this], idx_unshuffle 84 | 85 | @torch.no_grad() 86 | def _batch_unshuffle_ddp(self, x, idx_unshuffle): 87 | """ 88 | Undo batch shuffle. 89 | *** Only support DistributedDataParallel (DDP) model. 
*** 90 | """ 91 | batch_size_this = x.shape[0] 92 | x_gather = concat_all_gather(x) 93 | batch_size_all = x_gather.shape[0] 94 | 95 | num_gpus = batch_size_all // batch_size_this 96 | gpu_idx = torch.distributed.get_rank() 97 | idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx] 98 | 99 | return x_gather[idx_this] 100 | 101 | def forward( 102 | self, 103 | xs, 104 | ids, 105 | cur_epoch, 106 | masks_all, 107 | obj, 108 | xs_2=None, 109 | coord_multiv=None, 110 | coord_multiv_2=None, 111 | ): 112 | 113 | input_shuffle = True 114 | fxs, fxs_pre, fxs_hist, fxs_hist_pre = [], [], [], [] 115 | gxs, pxs = [], [] 116 | fxs_moco, fxs_pre_moco, fxs_hist_moco, fxs_hist_pre_moco = [], [], [], [] 117 | gxs_moco, pxs_moco = [], [] 118 | 119 | fxs_ts, fxs_pre_ts, fxs_hist_ts, fxs_hist_pre_ts = [], [], [], [] 120 | fxs_ts_f, fxs_pre_ts_f, fxs_hist_ts_f, fxs_hist_pre_ts_f = [], [], [], [] 121 | 122 | for gx, gx_2, is_key, masks in zip(xs, xs_2, [0, 1], masks_all): 123 | 124 | if is_key: 125 | with torch.no_grad(): 126 | self._momentum_update_key_encoder() 127 | 128 | x, idx_unshuffle = self._batch_shuffle_ddp(gx) 129 | 130 | fx, fx_2_f, fx_2, px = self.encoder_k(x, mode="point", dy=True) 131 | 132 | fx = self._batch_unshuffle_ddp(fx, idx_unshuffle) 133 | 134 | if fx_2 is not None: 135 | fx_2 = self._batch_unshuffle_ddp(fx_2, idx_unshuffle) 136 | fx_2_f = self._batch_unshuffle_ddp(fx_2_f, idx_unshuffle) 137 | else: 138 | fx_2 = fx 139 | 140 | fx_moco = torch.mean(fx, dim=(2, 3)) 141 | 142 | H, W = self.mask_size, self.mask_size 143 | grids = masks 144 | 145 | grids[:, :, :, 0], grids[:, :, :, 1] = ( 146 | grids[:, :, :, 0] / (H - 1) * 2 - 1, 147 | grids[:, :, :, 1] / (W - 1) * 2 - 1, 148 | ) 149 | pixel_feat = nn.functional.grid_sample(fx, grids) 150 | 151 | B_pf, C_pf, N_pf, P_pf = pixel_feat.shape 152 | 153 | pixel = F.normalize( 154 | pixel_feat_q_arch.permute(0, 1, 3, 2)[:, :, :, 0].permute( 155 | 0, 2, 1 156 | ), 157 | dim=-1, 158 | p=2, 159 | ) 160 | pixel_top = F.normalize( 161 | pixel_feat[:, :, 0:1, :] 162 | .permute(0, 1, 3, 2)[:, :, :, 0] 163 | .permute(0, 2, 1), 164 | dim=-1, 165 | p=2, 166 | ) 167 | m = ( 168 | torch.matmul(pixel_top, pixel.permute(0, 2, 1)) 169 | .max(dim=2)[1] 170 | .unsqueeze(-1) 171 | .repeat(1, 1, C_pf) 172 | .cuda() 173 | ) 174 | pixel_feat_top = ( 175 | pixel_feat[:, :, 0:1, :] 176 | .permute(0, 1, 3, 2)[:, :, :, 0] 177 | .permute(0, 2, 1) 178 | .gather(dim=1, index=m) 179 | .unsqueeze(0) 180 | .permute(1, 3, 0, 2)[:, :, :, : self.keep_point] 181 | ) 182 | 183 | fx_append = pixel_feat[:, :, 0, :] 184 | 185 | fx_append = fx_append.permute(0, 2, 1).reshape(-1, C_pf) 186 | 187 | fxs_hist_pre.append(fx_append) 188 | fxs_hist_pre_moco.append(fx_moco) 189 | 190 | fxs_pre.append(pixel_feat_top) 191 | fxs.append(pixel_feat_top) 192 | fxs_hist.append(pixel_feat_top) 193 | 194 | B_fx, C_fx, H_fx, W_fx = fx.shape 195 | fx = fx.permute(0, 2, 3, 1).reshape(B_fx * H_fx * W_fx, C_fx) 196 | 197 | fx_2 = torch.nn.functional.interpolate( 198 | fx_2, size=(self.mask_size, self.mask_size), mode="bilinear" 199 | ) 200 | B_fx_, C_fx_, H_fx_, W_fx_ = fx_2.shape 201 | 202 | fx_2 = fx_2.permute(0, 2, 3, 1).reshape(B_fx_ * H_fx_ * W_fx_, C_fx_) 203 | 204 | fxs_pre_ts.append(fx_2) 205 | fxs_ts.append(fx_2) 206 | fxs_hist_ts.append(fx_2) 207 | 208 | fx_2_f = torch.nn.functional.interpolate( 209 | fx_2_f, size=(self.mask_size, self.mask_size), mode="bilinear" 210 | ) 211 | B_fx_f, C_fx_f, H_fx_f, W_fx_f = fx_2_f.shape 212 | 213 | fx_2_f = fx_2_f.permute(0, 2, 3, 1).reshape( 214 | B_fx_f * H_fx_f * 
W_fx_f, C_fx_f 215 | ) 216 | 217 | fxs_pre_ts_f.append(fx_2_f) 218 | fxs_ts_f.append(fx_2_f) 219 | fxs_hist_ts_f.append(fx_2_f) 220 | 221 | fxs_pre_moco.append(fx_moco) 222 | fxs_moco.append(fx_moco) 223 | fxs_hist_moco.append(fx_moco) 224 | 225 | else: 226 | fx_q, fx_q_2_f, fx_q_2, px = self.encoder_q(gx, mode="point") 227 | 228 | if fx_q_2 is None: 229 | fx_q_2 = fx_q 230 | fx_moco = torch.mean(fx_q, dim=(2, 3)) 231 | H, W = self.mask_size, self.mask_size 232 | grids_q = masks.clone() 233 | grids_q[:, :, :, 0], grids_q[:, :, :, 1] = ( 234 | grids_q[:, :, :, 0] / (H - 1) * 2 - 1, 235 | grids_q[:, :, :, 1] / (W - 1) * 2 - 1, 236 | ) 237 | pixel_feat_q = nn.functional.grid_sample(fx_q, grids_q) 238 | B_pf, C_pf, N_pf, P_pf = pixel_feat_q.shape 239 | pixel_feat_q_arch = pixel_feat_q[:, :, 0:1, :] 240 | 241 | fxs_pre.append(pixel_feat_q[:, :, 0:1, : self.keep_point]) 242 | 243 | fxs.append(pixel_feat_q[:, :, 0:1, : self.keep_point]) 244 | 245 | fx_q_2 = torch.nn.functional.interpolate( 246 | fx_q_2, size=(self.mask_size, self.mask_size), mode="bilinear" 247 | ) 248 | B_fxq, C_fxq, H_fxq, W_fxq = fx_q_2.shape 249 | fx_q_2 = fx_q_2.permute(0, 2, 3, 1).reshape( 250 | B_fxq * H_fxq * W_fxq, C_fxq 251 | ) 252 | 253 | fxs_pre_ts.append(fx_q_2) 254 | 255 | fxs_ts.append(fx_q_2) 256 | 257 | fx_q_2_f = torch.nn.functional.interpolate( 258 | fx_q_2_f, size=(self.mask_size, self.mask_size), mode="bilinear" 259 | ) 260 | B_fxq_f, C_fxq_f, H_fxq_f, W_fxq_f = fx_q_2_f.shape 261 | fx_q_2_f = fx_q_2_f.permute(0, 2, 3, 1).reshape( 262 | B_fxq_f * H_fxq_f * W_fxq_f, C_fxq_f 263 | ) 264 | fxs_pre_ts_f.append(fx_q_2_f) 265 | 266 | fxs_ts_f.append(fx_q_2_f) 267 | 268 | gxs.append(gx) 269 | pxs.append(px) 270 | 271 | with torch.no_grad(): 272 | if input_shuffle and self.training: 273 | x, idx_restore = self._batch_shuffle_ddp(gx_2) 274 | else: 275 | x = gx_2 276 | fx, fx_2_f, fx_2, px = self.encoder_k(x, mode="point") 277 | 278 | if input_shuffle and self.training: 279 | fx = self._batch_unshuffle_ddp(fx, idx_restore) 280 | if fx_2 is not None: 281 | fx_2 = self._batch_unshuffle_ddp(fx_2, idx_restore) 282 | 283 | fx_2_f = self._batch_unshuffle_ddp(fx_2_f, idx_restore) 284 | else: 285 | fx_2 = fx 286 | 287 | grids = masks.clone() 288 | 289 | grids[:, :, :, 0], grids[:, :, :, 1] = ( 290 | grids[:, :, :, 0] / (H - 1) * 2 - 1, 291 | grids[:, :, :, 1] / (W - 1) * 2 - 1, 292 | ) 293 | 294 | pixel_feat = nn.functional.grid_sample(fx, grids) 295 | B_pf, C_pf, N_pf, P_pf = pixel_feat.shape 296 | 297 | fxs_pre.append(pixel_feat) 298 | fxs.append(pixel_feat) 299 | fxs_hist.append(pixel_feat) 300 | 301 | B_fx, C_fx, H_fx, W_fx = fx.shape 302 | fx = fx.permute(0, 2, 3, 1).reshape(B_fx * H_fx * W_fx, C_fx) 303 | 304 | fx_2 = torch.nn.functional.interpolate( 305 | fx_2, size=(self.mask_size, self.mask_size), mode="bilinear" 306 | ) 307 | B_fx_, C_fx_, H_fx_, W_fx_ = fx_2.shape 308 | fx_2 = fx_2.permute(0, 2, 3, 1).reshape(B_fx_ * H_fx_ * W_fx_, C_fx_) 309 | 310 | fxs_pre_ts.append(fx_2) 311 | fxs_ts.append(fx_2) 312 | fxs_hist_ts.append(fx_2) 313 | 314 | fx_2_f = torch.nn.functional.interpolate( 315 | fx_2_f, size=(self.mask_size, self.mask_size), mode="bilinear" 316 | ) 317 | B_fx_f, C_fx_f, H_fx_f, W_fx_f = fx_2_f.shape 318 | fx_2_f = fx_2_f.permute(0, 2, 3, 1).reshape( 319 | B_fx_f * H_fx_f * W_fx_f, C_fx_f 320 | ) 321 | 322 | fxs_pre_ts_f.append(fx_2_f) 323 | fxs_ts_f.append(fx_2_f) 324 | fxs_hist_ts_f.append(fx_2_f) 325 | fxs_pre_moco.append(fx_moco) 326 | fxs_moco.append(fx_moco) 327 | 328 | loss_image = 
self.plrc_loss_image( 329 | fxs_moco, 330 | fxs_pre_moco, 331 | fxs_hist_moco, 332 | fxs_hist_pre_moco, 333 | ids, 334 | gxs_moco, 335 | pxs_moco, 336 | None, 337 | self.training, 338 | cur_epoch, 339 | mode="image", 340 | ) 341 | 342 | loss_point = self.plrc_loss( 343 | fxs, 344 | fxs_pre, 345 | fxs_hist, 346 | fxs_hist_pre, 347 | ids, 348 | gxs, 349 | pxs, 350 | None, 351 | self.training, 352 | cur_epoch, 353 | mode="point", 354 | len_obj=obj, 355 | fxs_ts=fxs_ts, 356 | fxs_pre_ts=fxs_pre_ts, 357 | fxs_hist_ts=fxs_hist_ts, 358 | fxs_hist_pre_ts=fxs_hist_pre_ts, 359 | coord_multiv=coord_multiv, 360 | coord_multiv_2=coord_multiv_2, 361 | fxs_ts_f=fxs_ts_f, 362 | fxs_pre_ts_f=fxs_pre_ts_f, 363 | fxs_hist_ts_f=fxs_hist_ts_f, 364 | fxs_hist_pre_ts_f=fxs_hist_pre_ts_f, 365 | ) 366 | 367 | return loss_image, loss_point 368 | 369 | 370 | # utils 371 | @torch.no_grad() 372 | def concat_all_gather(tensor): 373 | """ 374 | Performs all_gather operation on the provided tensors. 375 | *** Warning ***: torch.distributed.all_gather has no gradient. 376 | """ 377 | tensors_gather = [ 378 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 379 | ] 380 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 381 | 382 | output = torch.cat(tensors_gather, dim=0) 383 | return output 384 | -------------------------------------------------------------------------------- /main_plrc.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | import argparse 10 | import builtins 11 | import math 12 | import os 13 | import random 14 | import shutil 15 | import time 16 | import warnings 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.parallel 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | import torch.optim 24 | import torch.multiprocessing as mp 25 | import torch.utils.data 26 | import torch.utils.data.distributed 27 | import torchvision.transforms as transforms 28 | import torchvision.datasets as datasets 29 | import torchvision.models as models 30 | 31 | from PLRC.datasets.imagenet import ImageNet 32 | from PLRC.models.builder import PLRC 33 | 34 | 35 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") 36 | parser.add_argument("data", metavar="DIR", help="path to dataset") 37 | parser.add_argument( 38 | "-j", 39 | "--workers", 40 | default=32, 41 | type=int, 42 | metavar="N", 43 | help="number of data loading workers (default: 32)", 44 | ) 45 | parser.add_argument( 46 | "--epochs", default=100, type=int, metavar="N", help="number of total epochs to run" 47 | ) 48 | parser.add_argument( 49 | "--start-epoch", 50 | default=0, 51 | type=int, 52 | metavar="N", 53 | help="manual epoch number (useful on restarts)", 54 | ) 55 | parser.add_argument( 56 | "-b", 57 | "--batch-size", 58 | default=256, 59 | type=int, 60 | metavar="N", 61 | help="mini-batch size (default: 256), this is the total " 62 | "batch size of all GPUs on the current node when " 63 | "using Data Parallel or Distributed Data Parallel", 64 | ) 65 | parser.add_argument( 66 | "--lr", 67 | "--learning-rate", 68 | default=0.03, 69 | type=float, 70 | metavar="LR", 71 | help="initial learning rate", 72 | dest="lr", 73 | ) 74 | parser.add_argument( 75 | 
"--schedule", 76 | default=[120, 160], 77 | nargs="*", 78 | type=int, 79 | help="learning rate schedule (when to drop lr by 10x)", 80 | ) 81 | parser.add_argument( 82 | "--momentum", default=0.9, type=float, metavar="M", help="momentum of SGD solver" 83 | ) 84 | parser.add_argument( 85 | "--wd", 86 | "--weight-decay", 87 | default=1e-4, 88 | type=float, 89 | metavar="W", 90 | help="weight decay (default: 1e-4)", 91 | dest="weight_decay", 92 | ) 93 | parser.add_argument( 94 | "-p", 95 | "--print-freq", 96 | default=10, 97 | type=int, 98 | metavar="N", 99 | help="print frequency (default: 10)", 100 | ) 101 | parser.add_argument( 102 | "--resume", 103 | default="", 104 | type=str, 105 | metavar="PATH", 106 | help="path to latest checkpoint (default: none)", 107 | ) 108 | parser.add_argument( 109 | "--world-size", 110 | default=-1, 111 | type=int, 112 | help="number of nodes for distributed training", 113 | ) 114 | parser.add_argument( 115 | "--rank", default=-1, type=int, help="node rank for distributed training" 116 | ) 117 | parser.add_argument( 118 | "--dist-url", 119 | default="tcp://224.66.41.62:23456", 120 | type=str, 121 | help="url used to set up distributed training", 122 | ) 123 | parser.add_argument( 124 | "--dist-backend", default="nccl", type=str, help="distributed backend" 125 | ) 126 | parser.add_argument( 127 | "--seed", default=None, type=int, help="seed for initializing training. " 128 | ) 129 | parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.") 130 | parser.add_argument( 131 | "--multiprocessing-distributed", 132 | action="store_true", 133 | help="Use multi-processing distributed training to launch " 134 | "N processes per node, which has N GPUs. This is the " 135 | "fastest way to use PyTorch for either single node or " 136 | "multi node data parallel training", 137 | ) 138 | parser.add_argument( 139 | "--moco-dim", default=128, type=int, help="feature dimension (default: 128)" 140 | ) 141 | parser.add_argument( 142 | "--moco-k", 143 | default=1024, 144 | type=int, 145 | help="queue size; number of negative keys (default: 65536)", 146 | ) 147 | parser.add_argument( 148 | "--moco-m", 149 | default=0.999, 150 | type=float, 151 | help="moco momentum of updating key encoder (default: 0.999)", 152 | ) 153 | parser.add_argument( 154 | "--moco-t", default=0.07, type=float, help="softmax temperature (default: 0.07)" 155 | ) 156 | 157 | 158 | # plrc configs: 159 | parser.add_argument( 160 | "--ts_ratio", 161 | default=0.3, 162 | type=float, 163 | help="balancing factor for self-distillation loss", 164 | ) 165 | parser.add_argument( 166 | "--cl_ratio", 167 | default=0.3, 168 | type=float, 169 | help="balancing factor for point-level contrastive loss", 170 | ) 171 | parser.add_argument( 172 | "--im_ratio", 173 | default=0.7, 174 | type=float, 175 | help="balancing factor for image-level contrastive loss", 176 | ) 177 | 178 | 179 | def main(): 180 | args = parser.parse_args() 181 | 182 | if args.seed is not None: 183 | random.seed(args.seed) 184 | torch.manual_seed(args.seed) 185 | cudnn.deterministic = True 186 | warnings.warn( 187 | "You have chosen to seed training. " 188 | "This will turn on the CUDNN deterministic setting, " 189 | "which can slow down your training considerably! " 190 | "You may see unexpected behavior when restarting " 191 | "from checkpoints." 192 | ) 193 | 194 | if args.gpu is not None: 195 | warnings.warn( 196 | "You have chosen a specific GPU. This will completely " 197 | "disable data parallelism." 
198 | ) 199 | 200 | if args.dist_url == "env://" and args.world_size == -1: 201 | args.world_size = int(os.environ["WORLD_SIZE"]) 202 | 203 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 204 | 205 | ngpus_per_node = torch.cuda.device_count() 206 | if args.multiprocessing_distributed: 207 | # Since we have ngpus_per_node processes per node, the total world_size 208 | # needs to be adjusted accordingly 209 | args.world_size = ngpus_per_node * args.world_size 210 | # Use torch.multiprocessing.spawn to launch distributed processes: the 211 | # main_worker process function 212 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 213 | else: 214 | # Simply call main_worker function 215 | main_worker(args.gpu, ngpus_per_node, args) 216 | 217 | 218 | def main_worker(gpu, ngpus_per_node, args): 219 | args.gpu = gpu 220 | 221 | # suppress printing if not master 222 | if args.multiprocessing_distributed and args.gpu != 0: 223 | 224 | def print_pass(*args): 225 | pass 226 | 227 | builtins.print = print_pass 228 | 229 | if args.gpu is not None: 230 | print("Use GPU: {} for training".format(args.gpu)) 231 | 232 | if args.distributed: 233 | if args.dist_url == "env://" and args.rank == -1: 234 | args.rank = int(os.environ["RANK"]) 235 | if args.multiprocessing_distributed: 236 | # For multiprocessing distributed training, rank needs to be the 237 | # global rank among all the processes 238 | args.rank = args.rank * ngpus_per_node + gpu 239 | dist.init_process_group( 240 | backend=args.dist_backend, 241 | init_method=args.dist_url, 242 | world_size=args.world_size, 243 | rank=args.rank, 244 | ) 245 | # create model 246 | model = PLRC(args) 247 | print(model) 248 | 249 | if args.distributed: 250 | if args.gpu is not None: 251 | torch.cuda.set_device(args.gpu) 252 | model.cuda(args.gpu) 253 | args.batch_size = int(args.batch_size / ngpus_per_node) 254 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 255 | model = torch.nn.parallel.DistributedDataParallel( 256 | model, device_ids=[args.gpu] 257 | ) 258 | else: 259 | model.cuda() 260 | model = torch.nn.parallel.DistributedDataParallel(model) 261 | elif args.gpu is not None: 262 | torch.cuda.set_device(args.gpu) 263 | model = model.cuda(args.gpu) 264 | raise NotImplementedError("Only DistributedDataParallel is supported.") 265 | else: 266 | raise NotImplementedError("Only DistributedDataParallel is supported.") 267 | 268 | # define loss function (criterion) and optimizer 269 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 270 | 271 | optimizer = torch.optim.SGD( 272 | model.parameters(), 273 | args.lr, 274 | momentum=args.momentum, 275 | weight_decay=args.weight_decay, 276 | ) 277 | 278 | # optionally resume from a checkpoint 279 | if args.resume: 280 | if os.path.isfile(args.resume): 281 | print("=> loading checkpoint '{}'".format(args.resume)) 282 | if args.gpu is None: 283 | checkpoint = torch.load(args.resume) 284 | else: 285 | # Map model to be loaded to specified single gpu. 
286 | loc = "cuda:{}".format(args.gpu) 287 | checkpoint = torch.load(args.resume, map_location=loc) 288 | args.start_epoch = checkpoint["epoch"] 289 | model.load_state_dict(checkpoint["state_dict"]) 290 | optimizer.load_state_dict(checkpoint["optimizer"]) 291 | print( 292 | "=> loaded checkpoint '{}' (epoch {})".format( 293 | args.resume, checkpoint["epoch"] 294 | ) 295 | ) 296 | else: 297 | print("=> no checkpoint found at '{}'".format(args.resume)) 298 | 299 | cudnn.benchmark = True 300 | 301 | train_dataset = ImageNet(split="train", args=args, path=args.data) 302 | 303 | if args.distributed: 304 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 305 | else: 306 | train_sampler = None 307 | 308 | train_loader = torch.utils.data.DataLoader( 309 | train_dataset, 310 | batch_size=args.batch_size, 311 | shuffle=(train_sampler is None), 312 | num_workers=args.workers, 313 | pin_memory=True, 314 | sampler=train_sampler, 315 | drop_last=True, 316 | ) 317 | 318 | for epoch in range(args.start_epoch, args.epochs): 319 | if args.distributed: 320 | train_sampler.set_epoch(epoch) 321 | adjust_learning_rate(optimizer, epoch, args) 322 | 323 | # train for one epoch 324 | train(train_loader, model, criterion, optimizer, epoch, args) 325 | 326 | if not args.multiprocessing_distributed or ( 327 | args.multiprocessing_distributed and args.rank % ngpus_per_node == 0 328 | ): 329 | save_checkpoint( 330 | { 331 | "epoch": epoch + 1, 332 | "arch": "resnet50", 333 | "state_dict": model.state_dict(), 334 | "optimizer": optimizer.state_dict(), 335 | }, 336 | is_best=False, 337 | filename="checkpoint_{:04d}.pth.tar".format(epoch), 338 | ) 339 | 340 | 341 | def train(train_loader, model, criterion, optimizer, cur_epoch, args): 342 | batch_time = AverageMeter("Time", ":6.3f") 343 | data_time = AverageMeter("Data", ":6.3f") 344 | losses = AverageMeter("Loss", ":.4e") 345 | progress = ProgressMeter( 346 | len(train_loader), 347 | [batch_time, data_time, losses], 348 | prefix="Epoch: [{}]".format(cur_epoch), 349 | ) 350 | 351 | # switch to train mode 352 | model.train() 353 | 354 | end = time.time() 355 | for cur_iter, ( 356 | inputs, 357 | ids, 358 | cls_labels, 359 | masks, 360 | obj, 361 | inputs_2, 362 | coord_multiv, 363 | coord_multiv_2, 364 | ) in enumerate(train_loader): 365 | 366 | data_time.update(time.time() - end) 367 | 368 | if args.gpu is not None: 369 | 370 | inputs_cuda = [] 371 | masks_cuda = [] 372 | inputs_2_cuda = [] 373 | for i in range(len(inputs)): 374 | ii = inputs[i].float().cuda(non_blocking=True) 375 | jj = masks[i].float().cuda(non_blocking=True) 376 | kk = inputs_2[i].float().cuda(non_blocking=True) 377 | # hack to squeeze the extra dimension; need to think about loss balancing if there are multiple 378 | if len(ii.shape) == 5: 379 | _, dm, dc, dh, dw = ii.shape 380 | _, dm, dc, dh, dw = jj.shape 381 | ii = torch.reshape(ii, (-1, dc, dh, dw)) 382 | jj = torch.reshape(jj, (-1, dc, dh, dw)) 383 | kk = torch.reshape(kk, (-1, dc, dh, dw)) 384 | inputs_cuda.append(ii) 385 | masks_cuda.append(jj) 386 | inputs_2_cuda.append(kk) 387 | del inputs, masks, inputs_2 388 | inputs = inputs_cuda 389 | masks = masks_cuda 390 | inputs_2 = inputs_2_cuda 391 | ids = ids.cuda(non_blocking=True) 392 | 393 | loss_point, loss_moco = model( 394 | inputs, ids, cur_epoch, masks, obj, inputs_2, coord_multiv, coord_multiv_2 395 | ) 396 | 397 | loss_vec_point, to_vis_point = loss_point 398 | 399 | loss_vec_moco, to_vis_moco = loss_moco 400 | 401 | loss = loss_vec_point.sum() + 
loss_vec_moco.sum() * args.im_ratio 402 | 403 | losses.update(loss.item(), cls_labels.shape[0]) 404 | 405 | optimizer.zero_grad() 406 | loss.backward() 407 | optimizer.step() 408 | 409 | # measure elapsed time 410 | batch_time.update(time.time() - end) 411 | end = time.time() 412 | 413 | if cur_iter % args.print_freq == 0: 414 | progress.display(cur_iter) 415 | 416 | 417 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): 418 | torch.save(state, filename) 419 | if is_best: 420 | shutil.copyfile(filename, "model_best.pth.tar") 421 | 422 | 423 | class AverageMeter(object): 424 | """Computes and stores the average and current value""" 425 | 426 | def __init__(self, name, fmt=":f"): 427 | self.name = name 428 | self.fmt = fmt 429 | self.reset() 430 | 431 | def reset(self): 432 | self.val = 0 433 | self.avg = 0 434 | self.sum = 0 435 | self.count = 0 436 | 437 | def update(self, val, n=1): 438 | self.val = val 439 | self.sum += val * n 440 | self.count += n 441 | self.avg = self.sum / self.count 442 | 443 | def __str__(self): 444 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 445 | return fmtstr.format(**self.__dict__) 446 | 447 | 448 | class ProgressMeter(object): 449 | def __init__(self, num_batches, meters, prefix=""): 450 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 451 | self.meters = meters 452 | self.prefix = prefix 453 | 454 | def display(self, batch): 455 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 456 | entries += [str(meter) for meter in self.meters] 457 | print("\t".join(entries)) 458 | 459 | def _get_batch_fmtstr(self, num_batches): 460 | num_digits = len(str(num_batches // 1)) 461 | fmt = "{:" + str(num_digits) + "d}" 462 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 463 | 464 | 465 | def adjust_learning_rate(optimizer, epoch, args): 466 | """Cosine-decay the learning rate over the training epochs""" 467 | lr = args.lr 468 | lr *= 0.5 * (1.0 + math.cos(math.pi * epoch / args.epochs)) 469 | for param_group in optimizer.param_groups: 470 | param_group["lr"] = lr 471 | 472 | 473 | if __name__ == "__main__": 474 | main() 475 | -------------------------------------------------------------------------------- /PLRC/models/plrc_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | from __future__ import absolute_import, division, print_function, unicode_literals 8 | 9 | import math 10 | from scipy.stats import ortho_group 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | 16 | import PLRC.models.builder as builder 17 | 18 | import torch.distributed as du 19 | 20 | import sys 21 | import numpy as np 22 | 23 | import matplotlib.pyplot as plt 24 | 25 | from PIL import Image 26 | import random 27 | from PLRC.models.utils import * 28 | 29 | 30 | class General_Loss(nn.Module): 31 | def __init__(self, args, apply=False): 32 | super().__init__() 33 | self.skip_first_few = False 34 | self.T = 0.2 35 | self.total_batch_size = args.batch_size 36 | self.world_size = args.world_size 37 | self.per_gpu_batch_size = self.total_batch_size // self.world_size 38 | self.gpu_rank = args.rank 39 | self.stop_grad_query = False 40 | self.stop_grad_key = False 41 | assert not self.stop_grad_key or not self.stop_grad_query 42 | self.loss_input_dim = 128 43 | self.crops_per_iter = 2 44 | self.apply = apply 45 | 46 | self.mask_size = 56 47 | self.mask_neg_num = 15 48 | self.mask_pos_num = 1 49 | self.mask_grid_num = 4 50 | self.mask_area_avgnum = 32 51 | 52 | self.ts_ratio = args.ts_ratio 53 | self.cl_ratio = args.cl_ratio 54 | self.im_ratio = args.im_ratio 55 | 56 | self.keep_negs = True 57 | self.logis_sum = "exp_logits" 58 | self.logis_avg = False 59 | self.scale = False 60 | self.keep_point = 16 61 | 62 | def forward( 63 | self, 64 | fxs, 65 | fxs_pre, 66 | fxs_hist, 67 | fxs_hist_pre, 68 | ids, 69 | gxs, 70 | pxs, 71 | rxs, 72 | is_train, 73 | cur_epoch, 74 | fxs_moco=None, 75 | fxs_pre_moco=None, 76 | fxs_hist_moco=None, 77 | fxs_hist_pre_moco=None, 78 | mix_image=None, 79 | mode="point", 80 | len_obj=None, 81 | fxs_ts=None, 82 | fxs_pre_ts=None, 83 | fxs_hist_ts=None, 84 | fxs_hist_pre_ts=None, 85 | coord_multiv=None, 86 | coord_multiv_2=None, 87 | fxs_ts_f=None, 88 | fxs_pre_ts_f=None, 89 | fxs_hist_ts_f=None, 90 | fxs_hist_pre_ts_f=None, 91 | ): 92 | if mode == "point": 93 | losses = self.forward_cont( 94 | fxs, 95 | fxs_pre, 96 | fxs_hist, 97 | fxs_hist_pre, 98 | ids, 99 | is_train, 100 | cur_epoch, 101 | mode=mode, 102 | len_obj=len_obj, 103 | ) 104 | to_vis = None 105 | losses_point, losses_moco = losses 106 | losses_point.append(torch.zeros(1, dtype=torch.float32).cuda()) 107 | losses_point.append(torch.zeros(1, dtype=torch.float32).cuda()) 108 | 109 | losses_moco.append(torch.zeros(1, dtype=torch.float32).cuda()) 110 | losses_moco.append(torch.zeros(1, dtype=torch.float32).cuda()) 111 | return losses_point, losses_moco, to_vis 112 | 113 | else: 114 | losses = self.forward_cont( 115 | fxs, 116 | fxs_pre, 117 | fxs_hist, 118 | fxs_hist_pre, 119 | ids, 120 | is_train, 121 | cur_epoch, 122 | mode=mode, 123 | len_obj=len_obj, 124 | fxs_ts=fxs_ts, 125 | fxs_pre_ts=fxs_pre_ts, 126 | fxs_hist_ts=fxs_hist_ts, 127 | fxs_hist_pre_ts=fxs_hist_pre_ts, 128 | coord_multiv=coord_multiv, 129 | coord_multiv_2=coord_multiv_2, 130 | fxs_ts_f=fxs_ts_f, 131 | fxs_pre_ts_f=fxs_pre_ts_f, 132 | fxs_hist_ts_f=fxs_hist_ts_f, 133 | fxs_hist_pre_ts_f=fxs_hist_pre_ts_f, 134 | ) 135 | to_vis = None 136 | 137 | losses.append(torch.zeros(1, dtype=torch.float32).cuda()) 138 | losses.append(torch.zeros(1, dtype=torch.float32).cuda()) 139 | return losses, to_vis 140 | 141 | def forward_cont(self, fxs, fxs_pre, fxs_hist, fxs_hist_pre, ids, is_train): 142 | raise NotImplementedError 143 | 144 | 145 | class Memory_Loss(General_Loss): 146 | def __init__(self, args, 
apply=False): 147 | super().__init__(args) 148 | self.skip_first_few = True 149 | self.dim = 128 150 | dataset_size = 1281167 151 | self.Ks = [65536] 152 | self.build_loss() 153 | self.self_distillation = SelfDistillation( 154 | args, 155 | out_dim=self.mask_size * self.mask_size, 156 | warmup_teacher_temp=0.04, 157 | teacher_temp=0.07, 158 | warmup_teacher_temp_epochs=30, 159 | nepochs=100, 160 | student_temp=0.1, 161 | center_momentum=0.9, 162 | ) 163 | 164 | stdv = 1.0 / math.sqrt(self.dim / 3) 165 | 166 | for ni, ns in enumerate([65536]): 167 | append = "" 168 | if ni > 0: 169 | append = "_%d" % ni 170 | self.register_buffer( 171 | "ptr" + append, 172 | torch.zeros( 173 | [ 174 | 1, 175 | ] 176 | ), 177 | ) 178 | self.register_buffer( 179 | "queue_x" + append, torch.rand(ns, self.dim).mul_(2 * stdv).add_(-stdv) 180 | ) 181 | getattr(self, "ptr" + append).requires_grad = False 182 | getattr(self, "queue_x" + append).requires_grad = False 183 | 184 | def build_loss(self): 185 | raise NotImplementedError 186 | 187 | def compute_cont_loss(self, out_pos, out_neg, targets): 188 | raise NotImplementedError 189 | 190 | def build_positive(self, q, p, nq, np, ncrops): 191 | return torch.einsum("nbc,nc->nb", [q.view(np, ncrops, -1), p]).view(nq, 1) 192 | 193 | def build_positive_points(self, q, p, nq, np, ncrops): 194 | assert ncrops == 1 195 | assert nq == np 196 | B, N, C = q.shape 197 | assert B == p.shape[0] and C == p.shape[2] and N == p.shape[1] 198 | return torch.matmul(q, p.permute(0, 2, 1)) 199 | 200 | def build_negative(self, q, append): 201 | return torch.einsum( 202 | "nc,kc->nk", [q, getattr(self, "queue_x" + append).clone().detach()] 203 | ) 204 | 205 | def update_memory(self, fxs_hist): 206 | with torch.no_grad(): 207 | 208 | for ki, ni in zip(fxs_hist, [0]): 209 | 210 | append = "" 211 | if ni > 0: 212 | append = "_%d" % ni 213 | 214 | ptr = int(getattr(self, "ptr" + append).item()) 215 | if self.world_size > 1: 216 | ki = builder.concat_all_gather(ki) 217 | 218 | num_items = ki.size(0) 219 | K = self.Ks[ni] # 66 220 | 221 | if num_items > K - ptr: 222 | num_items = K - ptr 223 | getattr(self, "queue_x" + append)[ptr : (ptr + num_items), :] = ki[ 224 | :num_items 225 | ] 226 | ptr += num_items 227 | if ptr == K: 228 | ptr = 0 229 | 230 | getattr(self, "ptr" + append)[0] = ptr 231 | 232 | def forward( 233 | self, 234 | fxs, 235 | fxs_pre, 236 | fxs_hist, 237 | fxs_hist_pre, 238 | ids, 239 | gxs, 240 | pxs, 241 | rxs, 242 | is_train, 243 | cur_epoch, 244 | fxs_moco=None, 245 | fxs_pre_moco=None, 246 | fxs_hist_moco=None, 247 | fxs_hist_pre_moco=None, 248 | mix_image=None, 249 | mode="point", 250 | len_obj=None, 251 | fxs_ts=None, 252 | fxs_pre_ts=None, 253 | fxs_hist_ts=None, 254 | fxs_hist_pre_ts=None, 255 | coord_multiv=None, 256 | coord_multiv_2=None, 257 | fxs_ts_f=None, 258 | fxs_pre_ts_f=None, 259 | fxs_hist_ts_f=None, 260 | fxs_hist_pre_ts_f=None, 261 | ): 262 | 263 | if mode == "point": 264 | # Still needs to change 265 | fxs = [F.normalize(fx, dim=1) for fx in fxs] 266 | fxs_pre = [F.normalize(fx, dim=1) for fx in fxs_pre] 267 | fxs_hist_pre = [F.normalize(fx, dim=1) for fx in fxs_hist_pre] 268 | 269 | loss = torch.zeros(1, dtype=torch.float32).cuda() 270 | loss_ts = torch.zeros(1, dtype=torch.float32).cuda() 271 | loss_cl = torch.zeros(1, dtype=torch.float32).cuda() 272 | 273 | for ii, tri in enumerate([[0, 1, 0]]): 274 | qi, p_2i, pi, ni = 0, 1, 2, 0 275 | q = fxs[qi].detach() if self.stop_grad_query else fxs[qi] 276 | p_2 = fxs_pre[p_2i].detach() if self.stop_grad_key 
else fxs_pre[p_2i] 277 | p = fxs_pre[pi].detach() if self.stop_grad_key else fxs_pre[pi] 278 | 279 | q_ts = ( 280 | fxs_pre_ts[qi].detach() if self.stop_grad_query else fxs_pre_ts[qi] 281 | ) 282 | 283 | p_2_ts = ( 284 | fxs_pre_ts[p_2i].detach() 285 | if self.stop_grad_key 286 | else fxs_pre_ts[p_2i] 287 | ) 288 | p_ts = fxs_pre_ts[pi].detach() if self.stop_grad_key else fxs_pre_ts[pi] 289 | 290 | q_ts_f = ( 291 | fxs_pre_ts_f[qi].detach() 292 | if self.stop_grad_query 293 | else fxs_pre_ts_f[qi] 294 | ) 295 | p_2_ts_f = ( 296 | fxs_pre_ts_f[p_2i].detach() 297 | if self.stop_grad_key 298 | else fxs_pre_ts_f[p_2i] 299 | ) 300 | p_ts_f = ( 301 | fxs_pre_ts_f[pi].detach() 302 | if self.stop_grad_key 303 | else fxs_pre_ts_f[pi] 304 | ) 305 | 306 | append = "" 307 | if ni > 0: 308 | append = "_%d" % ni 309 | nq = q.shape[0] 310 | np = p.shape[0] 311 | assert nq >= np 312 | ncrops = nq // np 313 | 314 | loss_cl += self.compute_cont_loss(p, q) 315 | 316 | _, c_ = q_ts.shape 317 | 318 | q_ts = q_ts.reshape(-1, self.mask_size * self.mask_size, c_) # B, 49, C 319 | 320 | p_ts = p_ts.reshape(-1, self.mask_size * self.mask_size, c_) 321 | 322 | p_2_ts = p_2_ts.reshape(-1, self.mask_size * self.mask_size, c_) 323 | 324 | _, c_f = q_ts_f.shape 325 | q_ts_f = q_ts_f.reshape( 326 | -1, self.mask_size * self.mask_size, c_f 327 | ) # B, 49, C 328 | 329 | p_ts_f = p_ts_f.reshape(-1, self.mask_size * self.mask_size, c_f) 330 | 331 | p_2_ts_f = p_2_ts_f.reshape(-1, self.mask_size * self.mask_size, c_f) 332 | 333 | overlap_mask = overlap( 334 | q_ts, p_2_ts, coord_multiv[0], coord_multiv_2[0], self.mask_size 335 | ) 336 | 337 | b_, hw, hw_ = torch.nonzero( 338 | overlap_mask, as_tuple=True 339 | ) # B, 49, 49 --- B, q_ts, p_2_ts 340 | 341 | out_pos_f = self.build_positive_points( 342 | q_ts_f, p_ts_f, nq, np, ncrops 343 | ) # Bx49*49 344 | 345 | out_pos_feat_f = out_pos_f[b_, hw, :] # Mx49 346 | 347 | out_pos_t_f = self.build_positive_points( 348 | p_2_ts_f, p_ts_f, nq, np, ncrops 349 | ) # BxB 350 | 351 | out_pos_t_feat_f = out_pos_t_f[b_, hw_, :] # Mx49 352 | 353 | out_pos_feat_f = out_pos_feat_f / 128.0 354 | out_pos_t_feat_f = out_pos_t_feat_f / 128.0 355 | 356 | q_ts_feat = q_ts[b_, hw, :] 357 | 358 | q_ts_feat = torch.cat([out_pos_feat_f, q_ts_feat], dim=-1) 359 | 360 | p_2_ts_feat = p_2_ts[b_, hw_, :] 361 | 362 | p_2_ts_feat = torch.cat([out_pos_t_feat_f, p_2_ts_feat], dim=-1) 363 | 364 | loss_ts += self.self_distillation( 365 | student_output=q_ts_feat, 366 | teacher_output=p_2_ts_feat, 367 | epoch=cur_epoch, 368 | ) 369 | 370 | loss += self.ts_ratio * loss_ts + self.cl_ratio * loss_cl 371 | 372 | if is_train: 373 | self.update_memory(fxs_hist_pre) 374 | 375 | return [loss, torch.zeros(1, dtype=torch.float32).cuda()] 376 | 377 | elif mode == "image": 378 | 379 | fxs = [F.normalize(fx, dim=1) for fx in fxs] 380 | fxs_pre = [F.normalize(fx, dim=1) for fx in fxs_pre] 381 | fxs_hist_pre = [F.normalize(fx, dim=1) for fx in fxs_hist_pre] 382 | 383 | loss = torch.zeros(1, dtype=torch.float32).cuda() 384 | 385 | for ii, tri in enumerate([[0, 1, 0]]): 386 | qi, pi, ni = tri[0], tri[1], tri[2] 387 | q = fxs[qi].detach() if self.stop_grad_query else fxs[qi] 388 | p = fxs_pre[pi].detach() if self.stop_grad_key else fxs_pre[pi] 389 | 390 | append = "" 391 | if ni > 0: 392 | append = "_%d" % ni 393 | nq = q.shape[0] 394 | np = p.shape[0] 395 | assert nq >= np 396 | ncrops = nq // np 397 | out_pos = self.build_positive(q, p, nq, np, ncrops) # BxB 398 | out_neg = self.build_negative(q, append) # BxN 399 | targets = 
torch.tensor(list(range(nq)), dtype=torch.long).cuda() 400 | 401 | loss += self.compute_cont_loss(out_pos, out_neg, targets) 402 | 403 | if is_train: 404 | self.update_memory(fxs_hist_pre) 405 | 406 | return [loss, torch.zeros(1, dtype=torch.float32).cuda()] 407 | 408 | 409 | class SelfDistillation(nn.Module): 410 | def __init__( 411 | self, 412 | args, 413 | out_dim, 414 | warmup_teacher_temp, 415 | teacher_temp, 416 | warmup_teacher_temp_epochs, 417 | nepochs, 418 | student_temp=0.1, 419 | center_momentum=0.9, 420 | ): 421 | super().__init__() 422 | self.student_temp = student_temp 423 | self.center_momentum = center_momentum 424 | 425 | self.register_buffer("center", torch.zeros(1, out_dim + 4096)) 426 | 427 | self.teacher_temp_schedule = np.concatenate( 428 | ( 429 | np.linspace( 430 | warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs 431 | ), 432 | np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp, 433 | ) 434 | ) 435 | self.teacher_temp = teacher_temp 436 | 437 | def forward(self, student_output, teacher_output, epoch): 438 | """ 439 | Cross-entropy between softmax outputs of the teacher and student networks. 440 | """ 441 | student_out = student_output / self.student_temp 442 | temp = self.teacher_temp_schedule[epoch] 443 | teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1) 444 | teacher_out = teacher_out.detach() # .chunk(2) 445 | 446 | total_loss = 0 447 | n_loss_terms = 0 448 | loss = torch.sum(-teacher_out * F.log_softmax(student_out, dim=-1), dim=-1) 449 | total_loss += loss.mean() 450 | self.update_center(teacher_output) 451 | 452 | return total_loss 453 | 454 | @torch.no_grad() 455 | def update_center(self, teacher_output): 456 | """ 457 | Update center used for teacher output. 458 | """ 459 | batch_center = torch.sum(teacher_output, dim=0, keepdim=True) 460 | du.all_reduce(batch_center) 461 | batch_center = batch_center / (len(teacher_output) * du.get_world_size()) 462 | self.center = self.center * self.center_momentum + batch_center * ( 463 | 1 - self.center_momentum 464 | ) 465 | 466 | 467 | class NCEAverage(Memory_Loss): 468 | def build_loss(self): 469 | self.loss_fn = nn.CrossEntropyLoss() 470 | self.loss_fn_mask = nn.CrossEntropyLoss(reduction="none") 471 | 472 | def compute_cont_loss(self, out_pos, out_neg, targets, T=0.2): 473 | out_x = torch.cat([out_pos, out_neg], dim=1) 474 | out_x = torch.div(out_x, T) 475 | 476 | return self.loss_fn(out_x, targets) 477 | 478 | def compute_cont_loss_mask(self, out_pos, out_neg, targets, mask, T=0.2): 479 | out_x = torch.cat([out_pos, out_neg], dim=1) 480 | if out_x.shape[1] > len(mask): 481 | mask = torch.cat( 482 | [mask, torch.ones(out_x.shape[1] - len(mask)).cuda()], dim=0 483 | ) 484 | out_x = torch.div(out_x, T) 485 | loss_fn = nn.CrossEntropyLoss(weight=mask) 486 | loss = loss_fn(out_x, targets) 487 | 488 | return loss 489 | 490 | 491 | class SupConLoss(Memory_Loss): 492 | def build_loss(self): 493 | self.temperature = 0.2 494 | self.contrast_mode = "all" 495 | self.base_temperature = 0.07 496 | self.loss_fn = nn.CrossEntropyLoss() 497 | 498 | def compute_cont_loss(self, p, q): 499 | 500 | B, C, N, P = p.shape 501 | 502 | features = torch.cat( 503 | [ 504 | p.permute(0, 2, 3, 1).reshape(-1, P, C), 505 | q.permute(0, 2, 3, 1).reshape(-1, P, C), 506 | ], 507 | dim=1, 508 | ) 509 | 510 | batch_size = features.shape[0] 511 | 512 | mask = torch.eye(batch_size, dtype=torch.float32).cuda() 513 | 514 | contrast_count = features.shape[1] 515 | 516 | contrast_feature = 
torch.cat(torch.unbind(features, dim=1), dim=0) 517 | if self.contrast_mode == "one": 518 | anchor_feature = features[:, 0] 519 | anchor_count = 1 520 | elif self.contrast_mode == "all": 521 | anchor_feature = contrast_feature 522 | anchor_count = contrast_count 523 | else: 524 | raise ValueError("Unknown mode: {}".format(self.contrast_mode)) 525 | 526 | # compute logits 527 | anchor_dot_contrast = torch.div( 528 | torch.matmul(anchor_feature, contrast_feature.T), self.temperature 529 | ) 530 | logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True) 531 | logits = anchor_dot_contrast - logits_max.detach() 532 | 533 | # tile mask 534 | mask = mask.repeat(anchor_count, contrast_count) 535 | # mask-out self-contrast cases 536 | logits_mask = torch.scatter( 537 | torch.ones_like(mask), 538 | 1, 539 | torch.arange(batch_size * anchor_count).view(-1, 1).cuda(), 540 | 0, 541 | ) 542 | mask = mask * logits_mask 543 | 544 | # compute log_prob 545 | exp_logits = torch.exp(logits) * logits_mask 546 | log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True)) 547 | 548 | # compute mean of log-likelihood over positive 549 | mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1) 550 | 551 | # loss 552 | loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos 553 | loss = loss.view(anchor_count, batch_size).mean() 554 | 555 | return loss 556 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. 
Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. 
The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. 
indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. 
The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. 
Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | --------------------------------------------------------------------------------