├── .gitignore ├── LICENSE ├── README.md ├── assets └── can.png ├── requirements.txt ├── scripts └── extract_encoder_weights.py ├── src ├── data.py ├── loss.py ├── model.py ├── network │ ├── decoder.py │ ├── encoder.py │ └── pos_embed.py └── pl_utils.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | output/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ben Conrad 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CAN: Contrastive Masked Autoencoders and Noise Prediction Pretraining 2 | 3 | PyTorch reimplementation of ["A simple, efficient and scalable contrastive masked autoencoder for learning visual representations"](https://arxiv.org/abs/2210.16870). 4 | 5 | 6 |

7 | <img src="assets/can.png" width="100%"/>
8 | 
9 | 
10 | ## Requirements
11 | - Python 3.8+
12 | - `pip install -r requirements.txt`
13 | 
14 | ## Usage
15 | To pretrain a ViT-B/16 network run:
16 | ```
17 | python train.py --accelerator gpu --devices 1 --precision 16 --data.root path/to/data/ \
18 |     --max_epochs 1000 --data.batch_size 256 --model.encoder_name vit_base_patch16 \
19 |     --model.mask_ratio 0.5 --model.weight_contrast 0.03 --model.weight_recon 0.67 \
20 |     --model.weight_denoise 0.3
21 | ```
22 | - Run `python train.py --help` for descriptions of all options.
23 | - `--model.encoder_name` can be one of `vit_tiny_patch16, vit_small_patch16, vit_base_patch16, vit_large_patch16, vit_huge_patch14`.
24 | 
25 | ### Using a Pretrained Model
26 | Encoder weights can be extracted from a pretraining checkpoint file by running:
27 | ```
28 | python scripts/extract_encoder_weights.py -c path/to/checkpoint/file
29 | ```
30 | You can then initialize a ViT model with these weights with the following:
31 | ```python
32 | import torch
33 | from timm.models.vision_transformer import VisionTransformer
34 | 
35 | weights = torch.load("path/to/weights/file")
36 | 
37 | # Assuming weights are for a ViT-B/16 model
38 | model = VisionTransformer(
39 |     patch_size=16,
40 |     embed_dim=768,
41 |     depth=12,
42 |     num_heads=12,
43 | )
44 | model.load_state_dict(weights)
45 | ```
46 | - __Note__: `VisionTransformer` arguments should match those used during pretraining (e.g. ViT-B/16, ViT-L/16, etc.).
47 | 
48 | ## Citation
49 | ```bibtex
50 | @article{mishra2022simple,
51 |   title={A simple, efficient and scalable contrastive masked autoencoder for learning visual representations},
52 |   author={Mishra, Shlok and Robinson, Joshua and Chang, Huiwen and Jacobs, David and Sarna, Aaron and Maschinot, Aaron and Krishnan, Dilip},
53 |   journal={arXiv preprint arXiv:2210.16870},
54 |   year={2022}
55 | }
56 | ```
57 | 
--------------------------------------------------------------------------------
/assets/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bwconrad/can/f75905ca8f388a04e74b3e3373fa75fece7d801a/assets/can.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.0
2 | numpy==1.23.4
3 | Pillow==9.3.0
4 | pytorch_lightning[extra]==1.8.1
5 | timm==0.6.11
6 | torch==1.13.0
7 | torchvision==0.14.0
8 | transformers==4.21.0
9 | 
--------------------------------------------------------------------------------
/scripts/extract_encoder_weights.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to extract the encoder's state_dict from a checkpoint file
3 | """
4 | 
5 | from argparse import ArgumentParser
6 | 
7 | import torch
8 | 
9 | if __name__ == "__main__":
10 |     parser = ArgumentParser()
11 |     parser.add_argument("--checkpoint", "-c", type=str, required=True)
12 |     parser.add_argument("--output", "-o", type=str, default="weights.pt")
13 |     parser.add_argument("--prefix", "-p", type=str, default="encoder")
14 | 
15 |     args = parser.parse_args()
16 | 
17 |     checkpoint = torch.load(args.checkpoint, map_location="cpu")
18 |     checkpoint = checkpoint["state_dict"]
19 | 
20 |     newmodel = {}
21 |     for k, v in checkpoint.items():
22 |         if not k.startswith(args.prefix):
23 |             continue
24 | 
25 |         k = k.replace(args.prefix + ".", "")
26 |         newmodel[k] = v
27 | 
28 |     with open(args.output, "wb") as f:
29 |         torch.save(newmodel, f)
30 | 
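    # A sketch of the key renaming performed above, assuming a CANModel checkpoint and the
    # default prefix "encoder"; non-encoder submodules ("decoder.*", "projector.*",
    # "noise_embed.*") are simply dropped:
    #   "encoder.patch_embed.proj.weight"   -> "patch_embed.proj.weight"
    #   "encoder.blocks.0.attn.qkv.weight"  -> "blocks.0.attn.qkv.weight"
    # The resulting state_dict can then be loaded into a timm VisionTransformer with a
    # matching configuration (see "Using a Pretrained Model" in the README).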
-------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from typing import Callable, Tuple 4 | 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.utils.data as data 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import (ColorJitter, Compose, GaussianBlur, 11 | Normalize, RandomApply, RandomGrayscale, 12 | RandomHorizontalFlip, RandomResizedCrop, 13 | ToTensor) 14 | 15 | 16 | class DuelViewDataModule(pl.LightningDataModule): 17 | def __init__( 18 | self, 19 | root: str, 20 | batch_size: int = 256, 21 | workers: int = 4, 22 | num_val_samples: int = 1000, 23 | crop_size: int = 224, 24 | min_scale: float = 0.08, 25 | max_scale: float = 1.0, 26 | brightness: float = 0.8, 27 | contrast: float = 0.8, 28 | saturation: float = 0.8, 29 | hue: float = 0.2, 30 | color_jitter_prob: float = 0.8, 31 | gray_scale_prob: float = 0.2, 32 | flip_prob: float = 0.5, 33 | gaussian_prob: float = 0.5, 34 | mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), 35 | std: Tuple[float, float, float] = (0.228, 0.224, 0.225), 36 | ): 37 | """Duel view data module 38 | 39 | Args: 40 | root: Path to image directory 41 | batch_size: Number of batch samples 42 | workers: Number of data workers 43 | num_val_samples: Number of samples to leave out for a validation set 44 | crop_size: Size of image crop 45 | min_scale: Minimum crop scale 46 | max_scale: Maximum crop scale 47 | brightness: Brightness intensity 48 | contrast: Contast intensity 49 | saturation: Saturation intensity 50 | hue: Hue intensity 51 | color_jitter_prob: Probability of applying color jitter 52 | gray_scale_prob: Probability of converting to grayscale 53 | flip_prob: Probability of applying horizontal flip 54 | gaussian_prob: Probability of applying Gaussian blurring 55 | mean: Image normalization channel means 56 | std: Image normalization channel standard deviations 57 | """ 58 | super().__init__() 59 | self.save_hyperparameters() 60 | self.root = root 61 | self.batch_size = batch_size 62 | self.workers = workers 63 | self.num_val_samples = num_val_samples 64 | self.crop_size = crop_size 65 | self.min_scale = min_scale 66 | self.max_scale = max_scale 67 | self.brightness = brightness 68 | self.contrast = contrast 69 | self.saturation = saturation 70 | self.hue = hue 71 | self.color_jitter_prob = color_jitter_prob 72 | self.gray_scale_prob = gray_scale_prob 73 | self.flip_prob = flip_prob 74 | self.gaussian_prob = gaussian_prob 75 | self.mean = mean 76 | self.std = std 77 | 78 | self.transforms = MultiViewTransform( 79 | Transforms( 80 | crop_size=self.crop_size, 81 | min_scale=self.min_scale, 82 | max_scale=self.max_scale, 83 | brightness=self.brightness, 84 | contrast=self.contrast, 85 | saturation=self.saturation, 86 | hue=self.hue, 87 | color_jitter_prob=self.color_jitter_prob, 88 | gray_scale_prob=self.gray_scale_prob, 89 | gaussian_prob=self.gaussian_prob, 90 | flip_prob=self.flip_prob, 91 | mean=self.mean, 92 | std=self.std, 93 | ), 94 | n_views=2, 95 | ) 96 | 97 | def setup(self, stage: str = "fit"): 98 | if stage == "fit": 99 | dataset = SimpleDataset(self.root, self.transforms) 100 | 101 | # Randomly take num_val_samples images for a validation set 102 | self.train_dataset, self.val_dataset = data.random_split( 103 | dataset, 104 | [len(dataset) - self.num_val_samples, self.num_val_samples], 105 | 
generator=torch.Generator().manual_seed(42), # Fixed seed 106 | ) 107 | 108 | def train_dataloader(self): 109 | return DataLoader( 110 | self.train_dataset, 111 | batch_size=self.batch_size, 112 | shuffle=True, 113 | num_workers=self.workers, 114 | pin_memory=True, 115 | drop_last=True, 116 | persistent_workers=True, 117 | ) 118 | 119 | def val_dataloader(self): 120 | return DataLoader( 121 | self.val_dataset, 122 | batch_size=self.batch_size, 123 | shuffle=False, 124 | num_workers=self.workers, 125 | pin_memory=True, 126 | drop_last=False, 127 | persistent_workers=True, 128 | ) 129 | 130 | 131 | class SimpleDataset(data.Dataset): 132 | def __init__(self, root: str, transforms: Callable): 133 | """Image dataset from nested directory 134 | 135 | Args: 136 | root: Path to directory 137 | transforms: Image augmentations 138 | """ 139 | super().__init__() 140 | self.root = root 141 | self.paths = [ 142 | f for f in glob(f"{root}/**/*", recursive=True) if os.path.isfile(f) 143 | ] 144 | self.transforms = transforms 145 | 146 | print(f"Loaded {len(self.paths)} images from {root}") 147 | 148 | def __getitem__(self, index: int): 149 | img = Image.open(self.paths[index]).convert("RGB") 150 | img = self.transforms(img) 151 | return img 152 | 153 | def __len__(self): 154 | return len(self.paths) 155 | 156 | 157 | class MultiViewTransform: 158 | def __init__(self, transforms: Callable, n_views: int = 2): 159 | """Wrapper class to apply transforms multiple times on an image 160 | 161 | Args: 162 | transforms: Image augmentation pipeline 163 | n_views: Number of augmented views to return 164 | """ 165 | self.transforms = transforms 166 | self.n_views = n_views 167 | 168 | def __call__(self, img: Image.Image): 169 | return [self.transforms(img) for _ in range(self.n_views)] 170 | 171 | 172 | class Transforms: 173 | def __init__( 174 | self, 175 | crop_size: int = 224, 176 | min_scale: float = 0.08, 177 | max_scale: float = 1.0, 178 | brightness: float = 0.8, 179 | contrast: float = 0.8, 180 | saturation: float = 0.8, 181 | hue: float = 0.2, 182 | color_jitter_prob: float = 0.8, 183 | gray_scale_prob: float = 0.2, 184 | gaussian_prob: float = 0.5, 185 | flip_prob: float = 0.5, 186 | mean: Tuple[float, float, float] = (0.485, 0.456, 0.406), 187 | std: Tuple[float, float, float] = (0.228, 0.224, 0.225), 188 | ): 189 | """Augmentation pipeline for contrastive learning 190 | 191 | Args: 192 | crop_size: Size of image crop 193 | min_scale: Minimum crop scale 194 | max_scale: Maximum crop scale 195 | brightness: Brightness intensity 196 | contast: Contast intensity 197 | saturation: Saturation intensity 198 | hue: Hue intensity 199 | color_jitter_prob: Probability of applying color jitter 200 | gray_scale_prob: Probability of converting to grayscale 201 | gaussian_prob: Probability of applying Gausian blurring 202 | flip_prob: Probability of applying horizontal flip 203 | mean: Image normalization means 204 | std: Image normalization standard deviations 205 | """ 206 | super().__init__() 207 | 208 | self.transforms = Compose( 209 | [ 210 | RandomResizedCrop(size=crop_size, scale=(min_scale, max_scale)), 211 | RandomApply( 212 | [ 213 | ColorJitter( 214 | brightness=brightness, # type:ignore 215 | contrast=contrast, # type:ignore 216 | saturation=saturation, # type:ignore 217 | hue=hue, # type:ignore 218 | ) 219 | ], 220 | p=color_jitter_prob, 221 | ), 222 | RandomGrayscale(p=gray_scale_prob), 223 | RandomApply([GaussianBlur(kernel_size=23)], p=gaussian_prob), 224 | RandomHorizontalFlip(p=flip_prob), 225 | 
ToTensor(), 226 | Normalize(mean=mean, std=std), 227 | ] 228 | ) 229 | 230 | def __call__(self, img: Image.Image): 231 | return self.transforms(img) 232 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch.nn.functional as F 4 | 5 | 6 | def masked_mse_loss( 7 | pred: torch.Tensor, 8 | target: torch.Tensor, 9 | mask: torch.Tensor, 10 | normalize_targets: bool = False, 11 | ): 12 | """MSE loss on masked patches 13 | 14 | Args: 15 | pred: B x num_patches x D tensor of predict patches 16 | target: B x num_patches x D tensor of target patch values 17 | mask: B x num_patches binary mask with masked patches marked with 1 18 | 19 | Return: 20 | loss: Masked mean square error loss 21 | """ 22 | # Normalize target pixel values 23 | if normalize_targets: 24 | mean = target.mean(dim=-1, keepdim=True) 25 | var = target.var(dim=-1, keepdim=True) 26 | target = (target - mean) / (var + 1.0e-6) ** 0.5 27 | 28 | # Calculate MSE loss 29 | loss = (pred - target) ** 2 30 | loss = loss.mean(dim=-1) # Per patch loss 31 | loss = (loss * mask).sum() / mask.sum() # Mean of masked patches 32 | 33 | return loss 34 | 35 | 36 | """ 37 | Modified from: 38 | https://github.com/vturrisi/solo-learn/blob/main/solo/losses/simclr.py 39 | https://github.com/vturrisi/solo-learn/blob/main/solo/utils/misc.py 40 | """ 41 | 42 | 43 | def info_nce_loss(z: torch.Tensor, temperature: float = 0.1) -> torch.Tensor: 44 | """Computes SimCLR's loss given batch of projected features z 45 | from different views, a positive boolean mask of all positives and 46 | a negative boolean mask of all negatives. 47 | 48 | Args: 49 | z (torch.Tensor): (2*B) x D tensor containing features from the views. 50 | 51 | Return: 52 | torch.Tensor: SimCLR loss. 53 | """ 54 | 55 | z = F.normalize(z, dim=-1) 56 | gathered_z = gather(z) 57 | 58 | sim = torch.exp(torch.einsum("if, jf -> ij", z, gathered_z) / temperature) 59 | 60 | indexes = torch.arange(z.size(0) // 2, device=z.device).repeat(2) 61 | gathered_indexes = gather(indexes) 62 | 63 | indexes = indexes.unsqueeze(0) 64 | gathered_indexes = gathered_indexes.unsqueeze(0) 65 | 66 | # positives 67 | pos_mask = indexes.t() == gathered_indexes 68 | pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0) 69 | 70 | # negatives 71 | neg_mask = indexes.t() != gathered_indexes 72 | 73 | pos = torch.sum(sim * pos_mask, 1) 74 | neg = torch.sum(sim * neg_mask, 1) 75 | loss = -(torch.mean(torch.log(pos / (pos + neg)))) 76 | return loss 77 | 78 | 79 | def get_rank(): 80 | if dist.is_available() and dist.is_initialized(): 81 | return dist.get_rank() 82 | return 0 83 | 84 | 85 | class GatherLayer(torch.autograd.Function): 86 | """ 87 | Gathers tensors from all process and supports backward propagation 88 | for the gradients across processes. 
89 | """ 90 | 91 | @staticmethod 92 | def forward(ctx, x): 93 | if dist.is_available() and dist.is_initialized(): 94 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 95 | dist.all_gather(output, x) 96 | else: 97 | output = [x] 98 | return tuple(output) 99 | 100 | @staticmethod 101 | def backward(ctx, *grads): 102 | if dist.is_available() and dist.is_initialized(): 103 | all_gradients = torch.stack(grads) 104 | dist.all_reduce(all_gradients) 105 | grad_out = all_gradients[get_rank()] 106 | else: 107 | grad_out = grads[0] 108 | return grad_out 109 | 110 | 111 | def gather(X, dim=0): 112 | """Gathers tensors from all processes, supporting backward propagation.""" 113 | return torch.cat(GatherLayer.apply(X), dim=dim) 114 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Tuple 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn as nn 7 | from einops import rearrange 8 | from timm.optim.optim_factory import param_groups_weight_decay 9 | from torch.optim import SGD, Adam, AdamW 10 | from torch.optim.lr_scheduler import LambdaLR 11 | from torchvision.utils import make_grid, save_image 12 | from transformers.optimization import get_cosine_schedule_with_warmup 13 | 14 | from src.loss import info_nce_loss, masked_mse_loss 15 | from src.network.decoder import VitDecoder 16 | from src.network.encoder import build_encoder 17 | from src.network.pos_embed import get_1d_sincos_pos_embed 18 | 19 | 20 | class CANModel(pl.LightningModule): 21 | def __init__( 22 | self, 23 | img_size: int = 224, 24 | encoder_name: str = "vit_base_patch16", 25 | decoder_embed_dim: int = 512, 26 | decoder_depth: int = 8, 27 | decoder_num_heads: int = 16, 28 | decoder_embed_unmasked_tokens: bool = True, 29 | projector_hidden_dim: int = 4096, 30 | projector_out_dim: int = 128, 31 | noise_embed_in_dim: int = 768, 32 | noise_embed_hidden_dim: int = 768, 33 | mask_ratio: float = 0.5, 34 | norm_pixel_loss: bool = True, 35 | temperature: float = 0.1, 36 | noise_std_max: float = 0.05, 37 | weight_contrast: float = 0.03, 38 | weight_recon: float = 0.67, 39 | weight_denoise: float = 0.3, 40 | lr: float = 2.5e-4, 41 | optimizer: str = "adamw", 42 | betas: Tuple[float, float] = (0.9, 0.95), 43 | weight_decay: float = 0.05, 44 | momentum: float = 0.9, 45 | scheduler: str = "cosine", 46 | warmup_epochs: int = 0, 47 | channel_last: bool = False, 48 | ): 49 | """Contrastive Masked Autoencoder and Noise Prediction Pretraining Model 50 | 51 | Args: 52 | img_size: Size of input image 53 | encoder_name: Name of encoder network 54 | decoder_embed_dim: Embed dim of decoder 55 | decoder_depth: Number of transformer blocks in the decoder 56 | decoder_num_heads: Number of attention heads in the decoder 57 | decoder_embed_unmasked_tokens: Apply decoder embedding layer on both masked and unmasked tokens. 
58 | Else only apply to masked tokens 59 | projector_hidden_dim: Hidden dim of projector 60 | projector_out_dim: Output dim of projector 61 | noise_embed_in_dim: Dim of noise level sinusoidal embedding 62 | noise_embed_hidden_dim: Hidden dim of noising embed MLP 63 | mask_ratio: Ratio of input image patches to mask 64 | norm_pixel_loss: Calculate loss using normalized pixel value targets 65 | temperature: Temperature for contrastive loss 66 | noise_std_max: Maximum noise standard deviation 67 | weight_contrast: Weight for contrastive loss 68 | weight_recon: Weight for patch reconstruction loss 69 | weight_denoise: Weight for denoising loss 70 | lr: Learning rate (should be scaled with batch size. i.e. lr = base_lr*batch_size/256) 71 | optimizer: Name of optimizer (adam | adamw | sgd) 72 | betas: Adam beta parameters 73 | weight_decay: Optimizer weight decay 74 | momentum: SGD momentum parameter 75 | scheduler: Name of learning rate scheduler [cosine, none] 76 | warmup_epochs: Number of warmup epochs 77 | channel_last: Change to channel last memory format for possible training speed up 78 | """ 79 | super().__init__() 80 | self.save_hyperparameters() 81 | self.img_size = img_size 82 | self.encoder_name = encoder_name 83 | self.decoder_embed_dim = decoder_embed_dim 84 | self.decoder_depth = decoder_depth 85 | self.decoder_num_heads = decoder_num_heads 86 | self.decoder_embed_unmasked_tokens = decoder_embed_unmasked_tokens 87 | self.projector_hidden_dim = projector_hidden_dim 88 | self.projector_out_dim = projector_out_dim 89 | self.noise_embed_in_dim = noise_embed_in_dim 90 | self.noise_embed_hidden_dim = noise_embed_hidden_dim 91 | self.mask_ratio = mask_ratio 92 | self.norm_pixel_loss = norm_pixel_loss 93 | self.temperature = temperature 94 | self.noise_std_max = noise_std_max 95 | self.weight_contrast = weight_contrast 96 | self.weight_recon = weight_recon 97 | self.weight_denoise = weight_denoise 98 | self.lr = lr 99 | self.optimizer = optimizer 100 | self.betas = betas 101 | self.weight_decay = weight_decay 102 | self.momentum = momentum 103 | self.scheduler = scheduler 104 | self.warmup_epochs = warmup_epochs 105 | self.channel_last = channel_last 106 | 107 | # Initialize networks 108 | self.encoder, self.patch_size = build_encoder( 109 | encoder_name, img_size=self.img_size 110 | ) 111 | self.decoder = VitDecoder( 112 | patch_size=self.patch_size, 113 | num_patches=self.encoder.patch_embed.num_patches, 114 | in_dim=self.encoder.embed_dim, 115 | embed_dim=self.decoder_embed_dim, 116 | depth=self.decoder_depth, 117 | num_heads=self.decoder_num_heads, 118 | embed_unmasked_tokens=self.decoder_embed_unmasked_tokens, 119 | ) 120 | self.projector = nn.Sequential( 121 | nn.Linear(self.encoder.embed_dim, self.projector_hidden_dim), 122 | nn.BatchNorm1d(self.projector_hidden_dim), 123 | nn.ReLU(), 124 | nn.Linear(self.projector_hidden_dim, self.projector_hidden_dim), 125 | nn.BatchNorm1d(self.projector_hidden_dim), 126 | nn.ReLU(), 127 | nn.Linear(self.projector_hidden_dim, self.projector_out_dim), 128 | ) 129 | # Based on updated openreview version (as of Nov 17, 2022), the MLP is two layers 130 | # without BN (maybe?) 
and input and hidden dims the same as the encoder embedding 131 | self.noise_embed = nn.Sequential( 132 | nn.Linear(self.noise_embed_in_dim, self.noise_embed_hidden_dim), 133 | nn.ReLU(), 134 | nn.Linear( 135 | self.noise_embed_hidden_dim, 136 | self.encoder.embed_dim 137 | if self.decoder_embed_unmasked_tokens 138 | else self.decoder_embed_dim, 139 | ), 140 | ) 141 | 142 | # Change to channel last memory format 143 | # https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html 144 | if self.channel_last: 145 | self = self.to(memory_format=torch.channels_last) 146 | 147 | def patchify(self, x: torch.Tensor): 148 | """Rearrange image into patches 149 | 150 | Args: 151 | x: Tensor of size (b, 3, h, w) 152 | 153 | Return: 154 | x: Tensor of size (b, h*w, patch_size^2 * 3) 155 | """ 156 | assert x.shape[2] == x.shape[3] and x.shape[2] % self.patch_size == 0 157 | 158 | return rearrange( 159 | x, 160 | "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", 161 | p1=self.patch_size, 162 | p2=self.patch_size, 163 | ) 164 | 165 | def unpatchify(self, x: torch.Tensor): 166 | """Rearrange patches back to an image 167 | 168 | Args: 169 | x: Tensor of size (b, h*w, patch_size^2 * 3) 170 | 171 | Return: 172 | x: Tensor of size (b, 3, h, w) 173 | """ 174 | h = w = int(x.shape[1] ** 0.5) 175 | return rearrange( 176 | x, 177 | " b (h w) (p1 p2 c) -> b c (h p1) (w p2)", 178 | p1=self.patch_size, 179 | p2=self.patch_size, 180 | h=h, 181 | w=w, 182 | ) 183 | 184 | def log_samples(self, inp: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor): 185 | """Log sample images""" 186 | # Only log up to 16 images 187 | inp, pred, mask = inp[:16], pred[:16], mask[:16] 188 | 189 | # Patchify the input image 190 | inp = self.patchify(inp) 191 | 192 | # Merge original and predicted patches 193 | pred = pred * mask[:, :, None] 194 | inp = inp * (1 - mask[:, :, None]) 195 | res = self.unpatchify(inp) + self.unpatchify(pred) 196 | 197 | # Log result 198 | if "CSVLogger" in str(self.logger.__class__): 199 | path = os.path.join( 200 | self.logger.log_dir, # type:ignore 201 | "samples", 202 | ) 203 | if not os.path.exists(path): 204 | os.makedirs(path) 205 | filename = os.path.join(path, str(self.current_epoch) + "ep.png") 206 | save_image(res, filename, nrow=4, normalize=True) 207 | elif "WandbLogger" in str(self.logger.__class__): 208 | grid = make_grid(res, nrow=4, normalize=True) 209 | self.logger.log_image(key="sample", images=[grid]) # type:ignore 210 | 211 | @torch.no_grad() 212 | def add_noise(self, x: torch.Tensor): 213 | """Add noise to input image 214 | 215 | Args: 216 | x: Tensor of size (b, c, h, w) 217 | 218 | Return: 219 | x_noise: x tensor with added Gaussian noise of size (b, c, h, w) 220 | noise: Noise tensor of size (b, c, h, w) 221 | std: Noise standard deviation (noise level) tensor of size (b,) 222 | """ 223 | # Sample std uniformly from [0, self.noise_std_max] 224 | std = torch.rand(x.size(0), device=x.device) * self.noise_std_max 225 | 226 | # Sample noise 227 | noise = torch.randn_like(x) * std[:, None, None, None] 228 | 229 | # Add noise to x 230 | x_noise = x + noise 231 | 232 | return x_noise, noise, std 233 | 234 | def shared_step( 235 | self, 236 | x: Tuple[torch.Tensor, torch.Tensor], 237 | mode: str = "train", 238 | batch_idx: Optional[int] = None, 239 | ): 240 | x1, x2 = x 241 | 242 | if self.channel_last: 243 | x1 = x1.to(memory_format=torch.channels_last) # type:ignore 244 | x2 = x2.to(memory_format=torch.channels_last) # type:ignore 245 | 246 | # Add noise to views 247 | x1_noise, noise1, 
std1 = self.add_noise(x1) 248 | x2_noise, noise2, std2 = self.add_noise(x2) 249 | 250 | # Mask and extract features 251 | z1, mask1, idx_unshuffle1 = self.encoder(x1_noise, self.mask_ratio) 252 | z2, mask2, idx_unshuffle2 = self.encoder(x2_noise, self.mask_ratio) 253 | 254 | # Pass mean encoder features through projector 255 | u1 = self.projector(torch.mean(z1[:, 1:, :], dim=1)) # Skip cls token 256 | u2 = self.projector(torch.mean(z2[:, 1:, :], dim=1)) 257 | 258 | # Generate noise level embedding 259 | p1 = self.noise_embed( 260 | get_1d_sincos_pos_embed(std1, dim=self.noise_embed_in_dim) 261 | ) 262 | p2 = self.noise_embed( 263 | get_1d_sincos_pos_embed(std2, dim=self.noise_embed_in_dim) 264 | ) 265 | 266 | # Predict masked patches and noise 267 | x1_pred = self.decoder(z1, idx_unshuffle1, p1) 268 | x2_pred = self.decoder(z2, idx_unshuffle2, p2) 269 | 270 | # Contrastive loss 271 | loss_contrast = info_nce_loss(torch.cat([u1, u2]), temperature=self.temperature) 272 | 273 | # Patch reconstruction loss 274 | loss_recon = ( 275 | masked_mse_loss(x1_pred, self.patchify(x1), mask1, self.norm_pixel_loss) 276 | + masked_mse_loss(x2_pred, self.patchify(x2), mask2, self.norm_pixel_loss) 277 | ) / 2 278 | 279 | # Denoising loss 280 | loss_denoise = ( 281 | masked_mse_loss( 282 | x1_pred, self.patchify(noise1), 1 - mask1, self.norm_pixel_loss 283 | ) 284 | + masked_mse_loss( 285 | x2_pred, self.patchify(noise2), 1 - mask2, self.norm_pixel_loss 286 | ) 287 | ) / 2 288 | 289 | # Combined loss 290 | loss = ( 291 | self.weight_contrast * loss_contrast 292 | + self.weight_recon * loss_recon 293 | + self.weight_denoise * loss_denoise 294 | ) 295 | 296 | # Log 297 | self.log(f"{mode}_loss", loss) 298 | self.log(f"{mode}_loss_contrast", loss_contrast) 299 | self.log(f"{mode}_loss_recon", loss_recon) 300 | self.log(f"{mode}_loss_denoise", loss_denoise) 301 | if mode == "val" and batch_idx == 0: 302 | self.log_samples(x1, x1_pred, mask1) 303 | 304 | return {"loss": loss} 305 | 306 | def training_step(self, x, _): 307 | self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True) 308 | return self.shared_step(x, mode="train") 309 | 310 | def validation_step(self, x, batch_idx): 311 | return self.shared_step(x, mode="val", batch_idx=batch_idx) 312 | 313 | def configure_optimizers(self): 314 | """Initialize optimizer and learning rate schedule""" 315 | # Set weight decay to 0 for bias and norm layers (following MAE) 316 | params = param_groups_weight_decay( 317 | self.encoder, self.weight_decay 318 | ) + param_groups_weight_decay(self.decoder, self.weight_decay) 319 | 320 | # Optimizer 321 | if self.optimizer == "adam": 322 | optimizer = Adam( 323 | params, 324 | lr=self.lr, 325 | betas=self.betas, 326 | weight_decay=self.weight_decay, 327 | ) 328 | elif self.optimizer == "adamw": 329 | optimizer = AdamW( 330 | params, 331 | lr=self.lr, 332 | betas=self.betas, 333 | weight_decay=self.weight_decay, 334 | ) 335 | elif self.optimizer == "sgd": 336 | optimizer = SGD( 337 | params, 338 | lr=self.lr, 339 | momentum=self.momentum, 340 | weight_decay=self.weight_decay, 341 | ) 342 | else: 343 | raise ValueError( 344 | f"{self.optimizer} is not an available optimizer. 
Should be one of ['adam', 'adamw', 'sgd']" 345 | ) 346 | 347 | # Learning rate schedule 348 | if self.scheduler == "cosine": 349 | epoch_steps = ( 350 | self.trainer.estimated_stepping_batches 351 | // self.trainer.max_epochs # type:ignore 352 | ) 353 | scheduler = get_cosine_schedule_with_warmup( 354 | optimizer, 355 | num_training_steps=self.trainer.estimated_stepping_batches, # type:ignore 356 | num_warmup_steps=epoch_steps * self.warmup_epochs, 357 | ) 358 | elif self.scheduler == "none": 359 | scheduler = LambdaLR(optimizer, lambda _: 1) 360 | else: 361 | raise ValueError( 362 | f"{self.scheduler} is not an available optimizer. Should be one of ['cosine', 'none']" 363 | ) 364 | 365 | return { 366 | "optimizer": optimizer, 367 | "lr_scheduler": { 368 | "scheduler": scheduler, 369 | "interval": "step", 370 | }, 371 | } 372 | -------------------------------------------------------------------------------- /src/network/decoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | from einops import repeat 7 | from timm.models.vision_transformer import Block 8 | 9 | from src.network.pos_embed import get_2d_sincos_pos_embed 10 | 11 | 12 | class VitDecoder(nn.Module): 13 | def __init__( 14 | self, 15 | patch_size: int = 16, 16 | num_patches: int = 196, 17 | in_channels: int = 3, 18 | depth: int = 8, 19 | embed_dim: int = 512, 20 | in_dim: int = 768, 21 | num_heads: int = 16, 22 | mlp_ratio: int = 4, 23 | norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type:ignore 24 | act_layer: nn.Module = nn.GELU, # type:ignore 25 | embed_unmasked_tokens: bool = True, 26 | ): 27 | super().__init__() 28 | self.embed_unmasked_tokens = embed_unmasked_tokens 29 | 30 | # Projection from encoder to decoder dim 31 | self.embed = nn.Linear(in_dim, embed_dim, bias=True) 32 | 33 | # Mask token 34 | self.mask_token = nn.Parameter( 35 | torch.zeros(1, 1, in_dim if embed_unmasked_tokens else embed_dim) 36 | ) 37 | 38 | # Sin-cos position embedding 39 | self.pos_embed = nn.Parameter( 40 | torch.zeros((1, num_patches + 1, embed_dim)), requires_grad=False 41 | ) 42 | 43 | self.blocks = nn.Sequential( 44 | *[ 45 | Block( 46 | dim=embed_dim, 47 | num_heads=num_heads, 48 | mlp_ratio=mlp_ratio, 49 | qkv_bias=True, 50 | norm_layer=norm_layer, # type:ignore 51 | act_layer=act_layer, # type:ignore 52 | ) 53 | for _ in range(depth) 54 | ] 55 | ) 56 | self.norm = norm_layer(embed_dim) 57 | self.head = nn.Linear(embed_dim, patch_size**2 * in_channels, bias=True) 58 | 59 | self.init_weights(num_patches) 60 | 61 | def init_weights(self, num_patches: int): 62 | # Initialize to sin-cos position embedding 63 | pos_embed = get_2d_sincos_pos_embed( 64 | self.pos_embed.shape[-1], 65 | int(num_patches**0.5), 66 | cls_token=True, 67 | ) 68 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 69 | 70 | # Mask token 71 | torch.nn.init.normal_(self.mask_token, std=0.02) 72 | 73 | # All other weights 74 | self.apply(self._init_weights) 75 | 76 | def _init_weights(self, m: nn.Module): 77 | if isinstance(m, nn.Linear): 78 | torch.nn.init.xavier_uniform_(m.weight) 79 | if isinstance(m, nn.Linear) and m.bias is not None: 80 | nn.init.constant_(m.bias, 0) 81 | elif isinstance(m, nn.LayerNorm): 82 | nn.init.constant_(m.bias, 0) 83 | nn.init.constant_(m.weight, 1.0) 84 | 85 | def forward( 86 | self, 87 | x: torch.Tensor, 88 | idx_unshuffle: torch.Tensor, 89 | p: 
Optional[torch.Tensor] = None, 90 | ): 91 | if not self.embed_unmasked_tokens: 92 | # Project only masked tokens to decoder embed size 93 | x = self.embed(x) 94 | 95 | # Append mask tokens to input 96 | L = idx_unshuffle.shape[1] 97 | B, L_unmasked, D = x.shape 98 | mask_tokens = self.mask_token.repeat(B, L + 1 - L_unmasked, 1) 99 | temp = torch.concat([x[:, 1:, :], mask_tokens], dim=1) # Skip cls token 100 | 101 | # Unshuffle tokens 102 | temp = torch.gather( 103 | temp, dim=1, index=repeat(idx_unshuffle, "b l -> b l d", d=D) 104 | ) 105 | 106 | # Add noise level embedding 107 | if p is not None: 108 | temp = temp + p[:, None, :] 109 | 110 | # Prepend cls token 111 | x = torch.cat([x[:, :1, :], temp], dim=1) 112 | 113 | if self.embed_unmasked_tokens: 114 | # Project masked and unmasked tokens to decoder embed size 115 | x = self.embed(x) 116 | 117 | # Add pos embed 118 | x = x + self.pos_embed 119 | 120 | # Apply transformer layers 121 | x = self.blocks(x) 122 | 123 | # Predict pixel values 124 | x = self.head(self.norm(x)) 125 | 126 | return x[:, 1:, :] # Don't return cls token 127 | 128 | 129 | def dec512d8b(patch_size: int, num_patches: int, in_dim: int, **kwargs): 130 | return VitDecoder( 131 | patch_size=patch_size, 132 | num_patches=num_patches, 133 | in_dim=in_dim, 134 | embed_dim=512, 135 | depth=8, 136 | num_heads=16, 137 | **kwargs, 138 | ) 139 | 140 | 141 | MODEL_DICT = {"dec512d8b": dec512d8b} 142 | 143 | 144 | def build_decoder(model: str, **kwargs): 145 | try: 146 | model_fn = MODEL_DICT[model] 147 | except: 148 | raise ValueError( 149 | f"{model} is not an available decoder. Should be one of {[k for k in MODEL_DICT.keys()]}" 150 | ) 151 | 152 | return model_fn(**kwargs) 153 | -------------------------------------------------------------------------------- /src/network/encoder.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import timm.models.vision_transformer as vision_transformer 4 | import torch 5 | import torch.nn as nn 6 | from einops import repeat 7 | 8 | from src.network.pos_embed import get_2d_sincos_pos_embed 9 | 10 | 11 | class VisionTransformer(vision_transformer.VisionTransformer): 12 | """Vision transformer for masked image modeling. 
13 | Uses fixed sin-cos position embeddings 14 | """ 15 | 16 | def __init__(self, **kwargs): 17 | super(VisionTransformer, self).__init__(**kwargs) 18 | assert self.num_prefix_tokens == 1 # Must have cls token 19 | 20 | # Re-initialize with fixed sin-cos position embedding 21 | self.pos_embed = nn.Parameter( 22 | torch.zeros(self.pos_embed.shape), requires_grad=False 23 | ) 24 | self.init_pos_embed() 25 | 26 | def init_pos_embed(self): 27 | """Initialize sin-cos position embeddings""" 28 | pos_embed = get_2d_sincos_pos_embed( 29 | self.pos_embed.shape[-1], 30 | int(self.patch_embed.num_patches**0.5), 31 | cls_token=True, 32 | ) 33 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 34 | 35 | def random_masking(self, x: torch.Tensor, mask_ratio: float): 36 | """Randomly mask mask_ratio patches of an image 37 | 38 | Args: 39 | x: Tensor of shape B x L x D 40 | mask_ratio: Ratio of patches to mask 41 | 42 | Return: 43 | x_masked: Tensor of non-masked patches 44 | mask: Tensor of size B x L where the positions of masked 45 | patches are marked by 1 and else 0 46 | idx_unshuffle: Tensor of size B x L with the sorting order 47 | to unshuffle patches back to the original order 48 | """ 49 | B, L, D = x.shape 50 | 51 | # Number of patches to keep 52 | num_keep = int(L * (1 - mask_ratio)) 53 | 54 | # Sort array of random noise 55 | noise = torch.rand((B, L), device=x.device) 56 | idx_shuffle = torch.argsort(noise, dim=1) 57 | idx_unshuffle = torch.argsort(idx_shuffle, dim=1) # Undo shuffling 58 | 59 | # Keep indices of n_keep smallest values 60 | idx_keep = idx_shuffle[:, :num_keep] 61 | x_masked = torch.gather(x, dim=1, index=repeat(idx_keep, "b l -> b l d", d=D)) 62 | 63 | # Generate binary mask 64 | mask = torch.ones((B, L), device=x.device) 65 | mask[:, :num_keep] = 0 66 | mask = torch.gather(mask, dim=1, index=idx_unshuffle) 67 | 68 | return x_masked, mask, idx_unshuffle 69 | 70 | def forward(self, x: torch.Tensor, mask_ratio: float = 0.75): 71 | # Patch embed image 72 | x = self.patch_embed(x) 73 | 74 | # Add pos embed skipping cls token 75 | x = x + self.pos_embed[:, 1:, :] 76 | 77 | # Mask the image 78 | x, mask, idx_unshuffle = self.random_masking(x, mask_ratio) 79 | 80 | # Append the cls token 81 | cls_token = self.cls_token + self.pos_embed[:, 0, :] 82 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1) 83 | 84 | # Apply transformer layers 85 | x = self.norm_pre(x) 86 | x = self.blocks(x) 87 | x = self.norm(x) 88 | 89 | return x, mask, idx_unshuffle 90 | 91 | 92 | def build_encoder(model: str, **kwargs): 93 | try: 94 | model_fn, patch_size = MODEL_DICT[model] 95 | except: 96 | raise ValueError( 97 | f"{model} is not an available encoder. 
Should be one of {[k for k in MODEL_DICT.keys()]}" 98 | ) 99 | 100 | return model_fn(**kwargs), patch_size 101 | 102 | 103 | def vit_tiny_patch16(**kwargs): 104 | return VisionTransformer( 105 | patch_size=16, 106 | embed_dim=192, 107 | depth=12, 108 | num_heads=3, 109 | mlp_ratio=4, 110 | qkv_bias=True, 111 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 112 | weight_init="jax", 113 | **kwargs, 114 | ) 115 | 116 | 117 | def vit_small_patch16(**kwargs): 118 | return VisionTransformer( 119 | patch_size=16, 120 | embed_dim=384, 121 | depth=12, 122 | num_heads=6, 123 | mlp_ratio=4, 124 | qkv_bias=True, 125 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 126 | weight_init="jax", 127 | **kwargs, 128 | ) 129 | 130 | 131 | def vit_base_patch16(**kwargs): 132 | return VisionTransformer( 133 | patch_size=16, 134 | embed_dim=768, 135 | depth=12, 136 | num_heads=12, 137 | mlp_ratio=4, 138 | qkv_bias=True, 139 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 140 | weight_init="jax", 141 | **kwargs, 142 | ) 143 | 144 | 145 | def vit_large_patch16(**kwargs): 146 | return VisionTransformer( 147 | patch_size=16, 148 | embed_dim=1024, 149 | depth=24, 150 | num_heads=16, 151 | mlp_ratio=4, 152 | qkv_bias=True, 153 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 154 | weight_init="jax", 155 | **kwargs, 156 | ) 157 | 158 | 159 | def vit_huge_patch14(**kwargs): 160 | return VisionTransformer( 161 | patch_size=14, 162 | embed_dim=1280, 163 | depth=32, 164 | num_heads=16, 165 | mlp_ratio=4, 166 | qkv_bias=True, 167 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 168 | weight_init="jax", 169 | **kwargs, 170 | ) 171 | 172 | 173 | MODEL_DICT = { 174 | "vit_tiny_patch16": (vit_tiny_patch16, 16), 175 | "vit_small_patch16": (vit_small_patch16, 16), 176 | "vit_base_patch16": (vit_base_patch16, 16), 177 | "vit_large_patch16": (vit_large_patch16, 16), 178 | "vit_huge_patch14": (vit_huge_patch14, 14), 179 | } 180 | -------------------------------------------------------------------------------- /src/network/pos_embed.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def get_1d_sincos_pos_embed(x: torch.Tensor, dim: int): 8 | """From: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py""" 9 | half_dim = dim // 2 10 | emb = math.log(10000) / (half_dim - 1) 11 | emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb) 12 | emb = x[:, None] * emb[None, :] 13 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 14 | return emb 15 | 16 | 17 | """From: https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20""" 18 | 19 | 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert 
embed_dim % 2 == 0
40 | 
41 |     # use half of dimensions to encode grid_h
42 |     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
43 |     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
44 | 
45 |     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
46 | 
47 |     return emb
48 | 
49 | 
50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
51 |     """
52 |     embed_dim: output dimension for each position
53 |     pos: a list of positions to be encoded: size (M,)
54 |     out: (M, D)
55 |     """
56 |     assert embed_dim % 2 == 0
57 |     omega = np.arange(embed_dim // 2, dtype=np.float64)
58 |     omega /= embed_dim / 2.0
59 |     omega = 1.0 / 10000**omega  # (D/2,)
60 | 
61 |     pos = pos.reshape(-1)  # (M,)
62 |     out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product
63 | 
64 |     emb_sin = np.sin(out)  # (M, D/2)
65 |     emb_cos = np.cos(out)  # (M, D/2)
66 | 
67 |     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
68 |     return emb
69 | 
--------------------------------------------------------------------------------
/src/pl_utils.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | from typing import Any, Optional
3 | 
4 | from pytorch_lightning.cli import LightningArgumentParser
5 | from pytorch_lightning.loggers import Logger
6 | from pytorch_lightning.loggers.csv_logs import CSVLogger
7 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
8 | from pytorch_lightning.loggers.wandb import WandbLogger
9 | 
10 | 
11 | class MyLightningArgumentParser(LightningArgumentParser):
12 |     def __init__(self, *args: Any, **kwargs: Any) -> None:
13 |         super().__init__(*args, **kwargs)
14 |         self.add_logger_args()
15 | 
16 |     def add_logger_args(self) -> None:
17 |         # Shared args
18 |         self.add_argument(
19 |             "--logger_type",
20 |             type=str,
21 |             help="Name of logger",
22 |             default="csv",
23 |             choices=["csv", "wandb"],
24 |         )
25 |         self.add_argument(
26 |             "--save_path",
27 |             type=str,
28 |             help="Save path of outputs",
29 |             default="output/",
30 |         )
31 |         self.add_argument(
32 |             "--name", type=str, help="Name of experiment", default="default"
33 |         )
34 | 
35 |         # Wandb args
36 |         self.add_argument(
37 |             "--project", type=str, help="Name of wandb project", default="default"
38 |         )
39 | 
40 | 
41 | def init_logger(args: Namespace) -> Optional[Logger]:
42 |     """Initialize logger from arguments
43 | 
44 |     Args:
45 |         args: parsed argument namespace
46 | 
47 |     Returns:
48 |         logger: initialized logger object
49 |     """
50 |     if args.logger_type == "wandb":
51 |         return WandbLogger(
52 |             project=args.project,
53 |             name=args.name,
54 |             save_dir=args.save_path,
55 |         )
56 |     elif args.logger_type == "csv":
57 |         return CSVLogger(name=args.name, save_dir=args.save_path)
58 |     else:
59 |         raise ValueError(
60 |             f"{args.logger_type} is not an available logger.
Should be one of ['csv', 'wandb']"
61 |         )
62 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning.callbacks import ModelCheckpoint
3 | 
4 | from src.data import DuelViewDataModule
5 | from src.model import CANModel
6 | from src.pl_utils import MyLightningArgumentParser, init_logger
7 | 
8 | model_class = CANModel
9 | dm_class = DuelViewDataModule
10 | 
11 | # Parse arguments
12 | parser = MyLightningArgumentParser()
13 | parser.add_lightning_class_args(pl.Trainer, None)  # type:ignore
14 | parser.add_lightning_class_args(dm_class, "data")
15 | parser.add_lightning_class_args(model_class, "model")
16 | parser.link_arguments("data.crop_size", "model.img_size")
17 | args = parser.parse_args()
18 | 
19 | # Setup trainer
20 | logger = init_logger(args)
21 | checkpoint_callback = ModelCheckpoint(
22 |     filename="best-{epoch}-{val_loss:.4f}",
23 |     monitor="val_loss",
24 |     mode="min",
25 |     save_last=True,
26 | )
27 | dm = dm_class(**args["data"])
28 | model = model_class(**args["model"])
29 | 
30 | trainer = pl.Trainer.from_argparse_args(
31 |     args, logger=logger, callbacks=[checkpoint_callback]
32 | )
33 | 
34 | # Train
35 | trainer.tune(model, dm)
36 | trainer.fit(model, dm)
37 | 
--------------------------------------------------------------------------------
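# Example: selecting the Weights & Biases logger defined in src/pl_utils.py instead of the
# default CSV logger (the project/run names below are placeholders; remaining options elided):
#   python train.py --logger_type wandb --project can --name vit_base_run --save_path output/ ...
# The ModelCheckpoint callback in train.py keeps the best checkpoint (lowest val_loss) as well
# as the last checkpoint of the run.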