├── .flake8 ├── .gitignore ├── LICENSE ├── README.md ├── assets └── architecture.png ├── conf └── pretrain.yaml ├── detcon ├── __init__.py ├── datasets │ ├── __init__.py │ ├── coco.py │ ├── transforms.py │ └── voc.py ├── losses.py └── models.py ├── pretrain.py ├── pyproject.toml ├── requirements.txt └── tests ├── __init__.py ├── conftest.py ├── test_datasets.py ├── test_losses.py └── test_models.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = 4 | # See https://github.com/PyCQA/pycodestyle/issues/373 5 | E203, 6 | exclude = 7 | data/, 8 | logo/, 9 | logs/, 10 | lightning_logs/, 11 | output/, 12 | 13 | # Python 14 | build/, 15 | dist/, 16 | .cache/, 17 | .mypy_cache/, 18 | .pytest_cache/, 19 | __pycache__/, 20 | *.egg-info/ 21 | 22 | # Git 23 | .git/, 24 | .github/, -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | logs/ 2 | lightning_logs/ 3 | data/ 4 | data 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 isaac 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # detcon-pytorch 2 | (WIP) PyTorch implementation of DeepMind's DetCon from ["Efficient Visual Pretraining with Contrastive Detection" Henaff et al. (ICCV 2021)](https://arxiv.org/abs/2103.10957) 3 | 4 | 5 | -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/detcon-pytorch/a5e03faf0c27bdbe64b72625873c0b2d3a696f04/assets/architecture.png -------------------------------------------------------------------------------- /conf/pretrain.yaml: -------------------------------------------------------------------------------- 1 | module: 2 | backbone: resnet18 3 | pretrained: True 4 | num_classes: 21 5 | num_samples: 5 6 | downsample: 32 7 | proj_hidden_dim: 128 8 | proj_dim: 256 9 | 10 | datamodule: 11 | root: /mnt/e/data/ 12 | batch_size: 32 13 | num_workers: 4 14 | prefetch_factor: 2 15 | 16 | trainer: 17 | max_epochs: 5 18 | precision: 16 19 | gpus: 1 20 | log_every_n_steps: 25 21 | -------------------------------------------------------------------------------- /detcon/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
class CocoSegmentation(CocoDetection):
    """COCO dataset variant whose target is a per-pixel category-id mask.

    The default image / target transforms defined in this module are used
    unless the caller supplies their own, either by keyword or positionally.
    """

    def __init__(self, *args, **kwargs) -> None:
        # torchvision documents the positional order of CocoDetection as
        # (root, annFile, transform, target_transform, transforms). Only inject
        # a default when the caller did not already fill that slot
        # positionally; otherwise the parent constructor would raise
        # "got multiple values for argument".
        if len(args) < 3:
            kwargs.setdefault("transform", default_transform)
        if len(args) < 4:
            kwargs.setdefault("target_transform", default_target_transform)
        super().__init__(*args, **kwargs)

    def _load_target(self, id: int) -> torch.Tensor:
        """Rasterise all annotations of image ``id`` into one uint8 mask.

        Args:
            id: COCO image id (the name shadows the builtin, but it has to
                match the signature of the parent method being overridden).

        Returns:
            A ``(1, height, width)`` uint8 tensor holding the category id of
            each annotated pixel and 0 elsewhere. Later annotations overwrite
            earlier ones where they overlap. The leading channel dimension
            lets channel-first transforms such as the default ``T.Resize``
            consume the mask.
        """
        anns = self.coco.loadAnns(self.coco.getAnnIds(id))

        # Look the image up by its own id so that images without any
        # annotation still produce an (all-background) mask.
        info = self.coco.imgs[id]
        canvas = np.zeros((info["height"], info["width"]), dtype=np.uint8)

        for ann in anns:
            # annToMask yields a 0/1 integer array; it must be cast to bool,
            # otherwise numpy treats it as row indices instead of a mask.
            region = self.coco.annToMask(ann).astype(bool)
            canvas[region] = ann["category_id"]

        return torch.from_numpy(canvas).unsqueeze(0)
class VOCSegmentationBaseDataModule(pl.LightningDataModule):
    """Shared configuration and training loader for the VOC data modules.

    Subclasses are expected to create ``self.train_dataset`` (it is read by
    ``train_dataloader`` but never assigned in this class).
    """

    classes = [
        "background",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = default_transform,
        target_transform: Optional[Callable] = default_target_transform,
        transforms: Optional[Callable] = None,
        augmentations: K.AugmentationSequential = default_augs,
        batch_size: int = 1,
        num_workers: int = 0,
        prefetch_factor: Optional[int] = 2,
        pin_memory: bool = False,
    ):
        """Store dataset and loader settings.

        Args:
            root: dataset root directory.
            transform: per-image transform.
            target_transform: per-mask transform.
            transforms: joint image/mask transform.
            augmentations: batch-level augmentation pipeline stored for use
                by subclasses.
            batch_size: samples per batch.
            num_workers: number of loader worker processes.
            prefetch_factor: batches prefetched per worker; only forwarded to
                the loader when worker processes are actually used.
            pin_memory: whether the loader pins host memory.
        """
        super().__init__()
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.transforms = transforms
        self.augmentations = augmentations
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        self.pin_memory = pin_memory
        self.num_classes = len(self.classes)
        self.idx2class = dict(enumerate(self.classes))
        self.idx2class[255] = "ignore"

    def _loader_kwargs(self) -> Dict:
        """Build the DataLoader keyword arguments shared by all loaders."""
        loader_kwargs = {
            "batch_size": self.batch_size,
            "num_workers": self.num_workers,
            "pin_memory": self.pin_memory,
        }
        # DataLoader rejects an explicit prefetch_factor when loading happens
        # in the main process (num_workers == 0), so only pass it along when
        # worker processes exist and a value was actually requested.
        if self.num_workers > 0 and self.prefetch_factor is not None:
            loader_kwargs["prefetch_factor"] = self.prefetch_factor
        return loader_kwargs

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling loader over ``self.train_dataset``."""
        return DataLoader(self.train_dataset, shuffle=True, **self._loader_kwargs())

    def on_before_batch_transfer(
        self, batch: Tuple[torch.Tensor, torch.Tensor], dataloader_idx: int
    ) -> Dict[str, torch.Tensor]:
        """Convert an ``(image, mask)`` batch into a dict.

        Mask pixels equal to 255 (labelled "ignore" in ``idx2class``) are
        rewritten in place to 0, i.e. the "background" index.
        """
        x, y = batch
        y[y == 255] = 0
        return {"image": x, "mask": y}
root=self.root, 102 | year="2012", 103 | image_set="train", 104 | transform=self.transform, 105 | target_transform=self.target_transform, 106 | transforms=self.transforms, 107 | ) 108 | self.val_dataset = VOCSegmentation( 109 | root=self.root, 110 | year="2012", 111 | image_set="val", 112 | transform=self.transform, 113 | target_transform=self.target_transform, 114 | transforms=self.transforms, 115 | ) 116 | self.test_dataset = VOCSegmentation( 117 | root=self.root, 118 | year="2007", 119 | image_set="test", 120 | transform=self.transform, 121 | target_transform=self.target_transform, 122 | transforms=self.transforms, 123 | ) 124 | 125 | def val_dataloader(self) -> DataLoader: 126 | return DataLoader( 127 | self.val_dataset, 128 | batch_size=self.batch_size, 129 | num_workers=self.num_workers, 130 | prefetch_factor=self.prefetch_factor, 131 | pin_memory=self.pin_memory, 132 | ) 133 | 134 | def test_dataloader(self) -> DataLoader: 135 | return DataLoader( 136 | self.test_dataset, 137 | batch_size=self.batch_size, 138 | num_workers=self.num_workers, 139 | prefetch_factor=self.prefetch_factor, 140 | pin_memory=self.pin_memory, 141 | ) 142 | 143 | def on_after_batch_transfer( 144 | self, batch: Dict[str, torch.Tensor], dataloader_idx: int 145 | ) -> Dict[str, torch.Tensor]: 146 | batch["mask"] = batch["mask"].to(torch.float) 147 | batch["image"], batch["mask"] = self.augmentations( 148 | batch["image"], batch["mask"] 149 | ) 150 | batch["mask"] = batch["mask"].to(torch.long) 151 | batch["mask"] = batch["mask"].squeeze(dim=1) 152 | return batch 153 | 154 | 155 | class VOCSSLDataModule(VOCSegmentationBaseDataModule): 156 | def __init__(self, *args, **kwargs) -> None: 157 | kwargs["augmentations"] = default_ssl_augs 158 | super().__init__(*args, **kwargs) 159 | 160 | def setup(self, stage: Optional[str] = None) -> None: 161 | self.train_dataset = VOCSegmentation( 162 | root=self.root, 163 | year="2012", 164 | image_set="train", 165 | transform=self.transform, 166 | 
def manual_cross_entropy(
    logits: torch.Tensor, labels: torch.Tensor, weight: torch.Tensor
) -> torch.Tensor:
    """Return the mean weighted cross entropy against soft ``labels``.

    ``labels`` and ``logits`` share their shape; the class axis is the last
    one. ``weight`` has the shape of ``logits`` without that last axis.
    """
    ce = -weight * torch.sum(labels * F.log_softmax(logits, dim=-1), dim=-1)
    return torch.mean(ce)


class DetConBLoss(nn.Module):
    """Modified from https://github.com/deepmind/detcon/blob/main/utils/losses.py."""

    def __init__(self, temperature: float = 0.1) -> None:
        """Store the softmax temperature used to scale the similarities."""
        super().__init__()
        self.temperature = torch.tensor(temperature)

    def forward(
        self,
        pred1: torch.Tensor,
        pred2: torch.Tensor,
        target1: torch.Tensor,
        target2: torch.Tensor,
        pind1: torch.Tensor,
        pind2: torch.Tensor,
        tind1: torch.Tensor,
        tind2: torch.Tensor,
        local_negatives: bool = True,
    ) -> torch.Tensor:
        """Compute the NCE scores from pairs of predictions and targets.

        This implements the batched form of the loss described in
        Section 3.1, Equation 3 in https://arxiv.org/pdf/2103.10957.pdf.

        Args:
            pred1: (b, num_samples, d) the prediction from first view.
            pred2: (b, num_samples, d) the prediction from second view.
            target1: (b, num_samples, d) the projection from first view.
            target2: (b, num_samples, d) the projection from second view.
            pind1: (b, num_samples) mask indices for first view's prediction.
            pind2: (b, num_samples) mask indices for second view's prediction.
            tind1: (b, num_samples) mask indices for first view's projection.
            tind2: (b, num_samples) mask indices for second view's projection.
            local_negatives (bool): whether to include local negatives

        Returns:
            A single scalar loss for the XT-NCE objective.
        """
        bs, num_samples, _ = pred1.shape
        infinity_proxy = 1e9  # Used for masks to proxy a very large number.
        eps = 1e-11

        def make_same_obj(ind_0: torch.Tensor, ind_1: torch.Tensor) -> torch.Tensor:
            # (b, n, 1, n): 1.0 where the two sampled mask ids of the same
            # image coincide, broadcastable over the "other image" axis.
            same_obj = torch.eq(
                ind_0.reshape(bs, num_samples, 1), ind_1.reshape(bs, 1, num_samples)
            )
            return same_obj.unsqueeze(2).to(torch.float)

        def similarity(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
            # (b, n, b, n) temperature-scaled similarity of every prediction
            # with every target in the batch.
            return torch.einsum("abk,uvk->abuv", pred, target) / (
                self.temperature + eps
            )

        same_obj_aa = make_same_obj(pind1, tind1)
        same_obj_ab = make_same_obj(pind1, tind2)
        same_obj_ba = make_same_obj(pind2, tind1)
        same_obj_bb = make_same_obj(pind2, tind2)

        # L2 normalize the tensors to use for the cosine-similarity
        pred1 = F.normalize(pred1, dim=-1)
        pred2 = F.normalize(pred2, dim=-1)
        target1 = F.normalize(target1, dim=-1)
        target2 = F.normalize(target2, dim=-1)

        # (b, 1, b, 1): 1.0 where prediction and target come from the same
        # image of the batch.
        same_image = torch.eye(bs, device=pred1.device).reshape(bs, 1, bs, 1)

        logits_ab = similarity(pred1, target2)
        logits_ba = similarity(pred2, target1)
        # Within a view, same-image/same-object pairs are neither positives
        # nor negatives: push their logits to -inf and give them zero label.
        logits_aa = similarity(pred1, target1) - infinity_proxy * same_image * same_obj_aa
        logits_bb = similarity(pred2, target2) - infinity_proxy * same_image * same_obj_bb

        labels_ab = same_image * same_obj_ab
        labels_ba = same_image * same_obj_ba
        labels_aa = torch.zeros_like(labels_ab)
        labels_bb = torch.zeros_like(labels_ba)

        if not local_negatives:
            # Also mask out different-object pairs from the same image.
            logits_aa = logits_aa - infinity_proxy * same_image * (1 - same_obj_aa)
            logits_ab = logits_ab - infinity_proxy * same_image * (1 - same_obj_ab)
            logits_ba = logits_ba - infinity_proxy * same_image * (1 - same_obj_ba)
            logits_bb = logits_bb - infinity_proxy * same_image * (1 - same_obj_bb)

        def branch_loss(
            logits_cross: torch.Tensor,
            logits_within: torch.Tensor,
            labels_cross: torch.Tensor,
            labels_within: torch.Tensor,
            pind: torch.Tensor,
        ) -> torch.Tensor:
            # One direction of the symmetric loss: cross-view and within-view
            # candidates are concatenated into a single softmax.
            labels = torch.cat([labels_cross, labels_within], dim=2)
            labels = labels.reshape(bs, num_samples, -1)
            logits = torch.cat([logits_cross, logits_within], dim=2)
            logits = logits.reshape(bs, num_samples, -1)

            num_positives = torch.sum(labels, dim=-1, keepdim=True)
            # Rows without any positive keep all-zero labels.
            labels = labels / num_positives.clamp(min=1.0)

            # Down-weight ids that were sampled several times and drop rows
            # that have no positive at all.
            obj_area = torch.sum(make_same_obj(pind, pind), dim=(2, 3))
            weights = torch.greater(num_positives[..., 0], 1e-3).to(torch.float)
            weights = weights / obj_area

            return manual_cross_entropy(logits, labels, weight=weights)

        loss_a = branch_loss(logits_ab, logits_aa, labels_ab, labels_aa, pind1)
        loss_b = branch_loss(logits_ba, logits_bb, labels_ba, labels_bb, pind2)
        return loss_a + loss_b
42 | def pool_masks(self, masks: torch.Tensor) -> torch.Tensor: 43 | """Create binary masks and performs mask pooling 44 | 45 | Args: 46 | masks: (b, 1, h, w) 47 | 48 | Returns: 49 | masks: (b, num_classes, d) 50 | """ 51 | if masks.ndim < 4: 52 | masks = masks.unsqueeze(dim=1) 53 | 54 | masks = masks == self.mask_ids[None, :, None, None].to(masks.device) 55 | masks = self.pool(masks.to(torch.float)) 56 | masks = rearrange(masks, "b c h w -> b c (h w)") 57 | masks = torch.argmax(masks, dim=1) 58 | masks = torch.eye(self.num_classes).to(masks.device)[masks] 59 | masks = rearrange(masks, "b d c -> b c d") 60 | return masks 61 | 62 | def sample_masks(self, masks: torch.Tensor) -> torch.Tensor: 63 | """Samples which binary masks to use in the loss. 64 | 65 | Args: 66 | masks: (b, num_classes, d) 67 | 68 | Returns: 69 | masks: (b, num_samples, d) 70 | """ 71 | bs = masks.shape[0] 72 | mask_exists = torch.greater(masks.sum(dim=-1), 1e-3) 73 | sel_masks = mask_exists.to(torch.float) + 1e-11 74 | # torch.multinomial handles normalizing 75 | # sel_masks = sel_masks / sel_masks.sum(dim=1, keepdim=True) 76 | # sel_masks = torch.softmax(sel_masks, dim=-1) 77 | mask_ids = torch.multinomial(sel_masks, num_samples=self.num_samples) 78 | sampled_masks = torch.stack([masks[b][mask_ids[b]] for b in range(bs)]) 79 | return sampled_masks, mask_ids 80 | 81 | def forward(self, masks: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 82 | binary_masks = self.pool_masks(masks) 83 | sampled_masks, sampled_mask_ids = self.sample_masks(binary_masks) 84 | area = sampled_masks.sum(dim=-1, keepdim=True) 85 | sampled_masks = sampled_masks / torch.maximum(area, torch.tensor(1.0)) 86 | return sampled_masks, sampled_mask_ids 87 | 88 | 89 | class Network(nn.Module): 90 | def __init__( 91 | self, 92 | backbone: str = "resnet50", 93 | pretrained: bool = False, 94 | hidden_dim: int = 128, 95 | output_dim: int = 256, 96 | num_classes: int = 10, 97 | downsample: int = 32, 98 | num_samples: int = 16, 99 
class DetConB(pl.LightningModule):
    """Lightning module training an online network against its EMA copy.

    The online ``Network`` output is passed through a predictor MLP and
    compared by ``loss_fn`` with the projections obtained from the same
    network evaluated under exponentially averaged parameters.
    """

    def __init__(
        self,
        num_classes: int = 21,
        num_samples: int = 5,
        backbone: str = "resnet50",
        pretrained: bool = False,
        downsample: int = 32,
        proj_hidden_dim: int = 128,
        proj_dim: int = 256,
        loss_fn: nn.Module = DetConBLoss(),
    ) -> None:
        """Build the network, its EMA tracker and the predictor head.

        Args:
            num_classes: number of mask ids, forwarded to ``Network``.
            num_samples: masks sampled per image, forwarded to ``Network``.
            backbone: encoder name, forwarded to ``Network``.
            pretrained: forwarded to ``Network``.
            downsample: mask pooling factor, forwarded to ``Network``.
            proj_hidden_dim: hidden width of projector and predictor MLPs.
            proj_dim: output width of projector and predictor MLPs.
            loss_fn: module called with the predictions, targets and mask ids;
                excluded from the saved hyperparameters.
        """
        super().__init__()
        self.save_hyperparameters(ignore=["loss_fn"])
        self.loss_fn = loss_fn
        self.network = Network(
            backbone=backbone,
            pretrained=pretrained,
            hidden_dim=proj_hidden_dim,
            output_dim=proj_dim,
            num_classes=num_classes,
            downsample=downsample,
            num_samples=num_samples,
        )
        self.ema = ExponentialMovingAverage(self.network.parameters(), decay=0.995)
        self.predictor = MLP(proj_dim, proj_hidden_dim, proj_dim)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters (lr=1e-3)."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def on_before_zero_grad(self, *args, **kwargs):
        """See https://forums.pytorchlightning.ai/t/adopting-exponential-moving-average-ema-for-pl-pipeline/488"""  # noqa: E501
        # Keep the EMA shadow weights on the same device as the network
        # before folding the current weights into the running average.
        self.ema.to(device=next(self.network.parameters()).device)
        self.ema.update(self.network.parameters())

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> Sequence[torch.Tensor]:
        """Run the network on images ``x`` and masks ``y``."""
        return self.network(x, y)

    def training_step(self, batch: Dict, batch_idx: int) -> torch.Tensor:
        """Compute and log the loss for one batch of two augmented views.

        Args:
            batch: dict whose "image" and "mask" entries are each a pair
                (view 1, view 2).
            batch_idx: unused.

        Returns:
            The scalar training loss.
        """
        (x1, x2), (y1, y2) = batch["image"], batch["mask"]

        # encode and project
        _, p1, _, ids1 = self(x1, y1)
        _, p2, _, ids2 = self(x2, y2)

        # ema encode and project; the targets are detached below, so no
        # autograd graph is needed for this branch.
        with torch.no_grad(), self.ema.average_parameters():
            _, ema_p1, _, ema_ids1 = self(x1, y1)
            _, ema_p2, _, ema_ids2 = self(x2, y2)

        # predict
        q1, q2 = self.predictor(p1), self.predictor(p2)

        # compute loss
        loss = self.loss_fn(
            pred1=q1,
            pred2=q2,
            target1=ema_p1.detach(),
            target2=ema_p2.detach(),
            pind1=ids1,
            pind2=ids2,
            tind1=ema_ids1,
            tind2=ema_ids2,
        )
        self.log("loss", loss)
        return loss
true 4 | skip_magic_trailing_comma = true 5 | exclude = ''' 6 | /( 7 | | data 8 | | logo 9 | | logs 10 | | lightning_logs 11 | | output 12 | # Python 13 | | build 14 | | dist 15 | | \.cache 16 | | \.mypy_cache 17 | | \.pytest_cache 18 | | __pycache__ 19 | | .*\.egg-info 20 | # Git 21 | | \.git 22 | | \.github 23 | )/ 24 | ''' 25 | 26 | [tool.isort] 27 | profile = "black" 28 | known_first_party = ["tests", "detcon"] 29 | extend_skip = ["data", "logs", "lightning_logs"] 30 | skip_gitignore = true 31 | color_output = true 32 | 33 | [tool.pytest.ini_options] 34 | # Skip slow tests by default 35 | addopts = "-m 'not slow'" 36 | filterwarnings = [ 37 | "ignore:.*Create unlinked descriptors is going to go away.*:DeprecationWarning", 38 | # https://github.com/tensorflow/tensorboard/pull/5138 39 | "ignore:.*is a deprecated alias for the builtin.*:DeprecationWarning", 40 | ] 41 | markers = [ 42 | "slow: marks tests as slow", 43 | ] 44 | norecursedirs = [ 45 | ".ipynb_checkpoints", 46 | "data", 47 | "__pycache__", 48 | ] 49 | testpaths = [ 50 | "tests", 51 | ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | kornia 4 | einops 5 | pytorch_lightning 6 | torch-ema 7 | pytest 8 | black 9 | isort[colors] 10 | flake8 11 | omegaconf 12 | pycocotools -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/isaaccorley/detcon-pytorch/a5e03faf0c27bdbe64b72625873c0b2d3a696f04/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pytorch_lightning as pl 3 | 4 | from detcon.datasets import 
def test_voc_segmentation_datamodule() -> None:
    """Smoke-test one batch from each split through both batch hooks."""
    datamodule = VOCSegmentationDataModule(root=ROOT)
    datamodule.setup()

    loader_factories = (
        datamodule.train_dataloader,
        datamodule.val_dataloader,
        datamodule.test_dataloader,
    )
    for make_loader in loader_factories:
        first_batch = next(iter(make_loader()))
        first_batch = datamodule.on_before_batch_transfer(first_batch, 0)
        first_batch = datamodule.on_after_batch_transfer(first_batch, 0)
def test_detconb_loss(loss_inputs: Dict[str, torch.Tensor]) -> None:
    """The loss must be a finite scalar with and without local negatives."""
    loss_fn = DetConBLoss()

    loss = loss_fn(**loss_inputs)
    assert loss.ndim == 0
    assert torch.isfinite(loss)

    loss_without_local = loss_fn(**loss_inputs, local_negatives=False)
    assert loss_without_local.ndim == 0
    assert torch.isfinite(loss_without_local)