├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── can.png
├── requirements.txt
├── scripts
│   └── extract_encoder_weights.py
├── src
│   ├── data.py
│   ├── loss.py
│   ├── model.py
│   ├── network
│   │   ├── decoder.py
│   │   ├── encoder.py
│   │   └── pos_embed.py
│   └── pl_utils.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | output/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # pdm
108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109 | #pdm.lock
110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111 | # in version control.
112 | # https://pdm.fming.dev/#use-with-ide
113 | .pdm.toml
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ben Conrad
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CAN: Contrastive Masked Autoencoders and Noise Prediction Pretraining
2 |
3 | PyTorch reimplementation of ["A simple, efficient and scalable contrastive masked autoencoder for learning visual representations"](https://arxiv.org/abs/2210.16870).
4 |
5 | ![CAN method overview](assets/can.png)
6 |
7 |
8 |
9 |
10 | ## Requirements
11 | - Python 3.8+
12 | - `pip install -r requirements.txt`
13 |
14 | ## Usage
15 | To pretrain a ViT-b/16 network, run:
16 | ```
17 | python train.py --accelerator gpu --devices 1 --precision 16 --data.root path/to/data/ \
18 |     --max_epochs 1000 --data.batch_size 256 --model.encoder_name vit_base_patch16 \
19 |     --model.mask_ratio 0.5 --model.weight_contrast 0.03 --model.weight_recon 0.67 \
20 |     --model.weight_denoise 0.3
21 | ```
22 | - Run `python train.py --help` for descriptions of all options.
23 | - `--model.encoder_name` can be one of `vit_tiny_patch16, vit_small_patch16, vit_base_patch16, vit_large_patch16, vit_huge_patch14`.
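
The model and data module can also be constructed directly in Python rather than through the CLI. A minimal sketch (the data path and trainer settings are placeholders mirroring the command above):

```python
import pytorch_lightning as pl

from src.data import DuelViewDataModule
from src.model import CANModel

# Placeholder path and hyperparameters, mirroring the CLI example above
dm = DuelViewDataModule(root="path/to/data/", batch_size=256)
model = CANModel(
    encoder_name="vit_base_patch16",
    mask_ratio=0.5,
    weight_contrast=0.03,
    weight_recon=0.67,
    weight_denoise=0.3,
)
trainer = pl.Trainer(accelerator="gpu", devices=1, precision=16, max_epochs=1000)
trainer.fit(model, dm)
```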
24 |
25 | ### Using a Pretrained Model
26 | Encoder weights can be extracted from a pretraining checkpoint file by running:
27 | ```
28 | python scripts/extract_encoder_weights.py -c path/to/checkpoint/file
29 | ```
30 | You can then initialize a ViT model with these weights as follows:
31 | ```python
32 | import torch
33 | from timm.models.vision_transformer import VisionTransformer
34 |
35 | weights = torch.load("path/to/weights/file")
36 |
37 | # Assuming weights are for a ViT-b/16 model
38 | model = VisionTransformer(
39 | patch_size=16,
40 | embed_dim=768,
41 | depth=12,
42 | num_heads=12,
43 | )
44 | model.load_state_dict(weights)
45 | ```
46 | - __Note__: `VisionTransformer` arguments should match those used during pretraining (e.g. ViT-b/16, ViT-l/16).
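
For downstream use the loaded model behaves like any other `timm` ViT. A minimal fine-tuning sketch (the 10-class head and dummy batch are placeholder assumptions, not part of this repo):

```python
import torch
from timm.models.vision_transformer import VisionTransformer

# Same ViT-b/16 configuration as above, but with a new (randomly initialized) head
model = VisionTransformer(
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=10,  # placeholder number of downstream classes
)

weights = torch.load("path/to/weights/file", map_location="cpu")
# Drop the pretraining classifier head so it does not clash with the new 10-class head
weights = {k: v for k, v in weights.items() if not k.startswith("head.")}
missing, _ = model.load_state_dict(weights, strict=False)
print(f"Newly initialized parameters: {missing}")  # expected: head.weight, head.bias

x = torch.randn(2, 3, 224, 224)       # dummy batch
logits = model(x)                     # (2, 10) class logits
features = model.forward_features(x)  # (2, 197, 768) token features
```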
47 |
48 | ## Citation
49 | ```bibtex
50 | @article{mishra2022simple,
51 | title={A simple, efficient and scalable contrastive masked autoencoder for learning visual representations},
52 | author={Mishra, Shlok and Robinson, Joshua and Chang, Huiwen and Jacobs, David and Sarna, Aaron and Maschinot, Aaron and Krishnan, Dilip},
53 | journal={arXiv preprint arXiv:2210.16870},
54 | year={2022}
55 | }
56 | ```
57 |
--------------------------------------------------------------------------------
/assets/can.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bwconrad/can/f75905ca8f388a04e74b3e3373fa75fece7d801a/assets/can.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops==0.6.0
2 | numpy==1.23.4
3 | Pillow==9.3.0
4 | pytorch_lightning[extra]==1.8.1
5 | timm==0.6.11
6 | torch==1.13.0
7 | torchvision==0.14.0
8 | transformers==4.21.0
9 |
--------------------------------------------------------------------------------
/scripts/extract_encoder_weights.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to extract the encoder's state_dict from a checkpoint file
3 | """
4 |
5 | from argparse import ArgumentParser
6 |
7 | import torch
8 |
9 | if __name__ == "__main__":
10 | parser = ArgumentParser()
11 | parser.add_argument("--checkpoint", "-c", type=str, required=True)
12 | parser.add_argument("--output", "-o", type=str, default="weights.pt")
13 | parser.add_argument("--prefix", "-p", type=str, default="encoder")
14 |
15 | args = parser.parse_args()
16 |
17 | checkpoint = torch.load(args.checkpoint, map_location="cpu")
18 | checkpoint = checkpoint["state_dict"]
19 |
20 | newmodel = {}
21 | for k, v in checkpoint.items():
22 | if not k.startswith(args.prefix):
23 | continue
24 |
25 | k = k.replace(args.prefix + ".", "")
26 | newmodel[k] = v
27 |
28 | with open(args.output, "wb") as f:
29 | torch.save(newmodel, f)
30 |
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from glob import glob
3 | from typing import Callable, Tuple
4 |
5 | import pytorch_lightning as pl
6 | import torch
7 | import torch.utils.data as data
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from torchvision.transforms import (ColorJitter, Compose, GaussianBlur,
11 | Normalize, RandomApply, RandomGrayscale,
12 | RandomHorizontalFlip, RandomResizedCrop,
13 | ToTensor)
14 |
15 |
16 | class DuelViewDataModule(pl.LightningDataModule):
17 | def __init__(
18 | self,
19 | root: str,
20 | batch_size: int = 256,
21 | workers: int = 4,
22 | num_val_samples: int = 1000,
23 | crop_size: int = 224,
24 | min_scale: float = 0.08,
25 | max_scale: float = 1.0,
26 | brightness: float = 0.8,
27 | contrast: float = 0.8,
28 | saturation: float = 0.8,
29 | hue: float = 0.2,
30 | color_jitter_prob: float = 0.8,
31 | gray_scale_prob: float = 0.2,
32 | flip_prob: float = 0.5,
33 | gaussian_prob: float = 0.5,
34 | mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
35 | std: Tuple[float, float, float] = (0.228, 0.224, 0.225),
36 | ):
37 | """Duel view data module
38 |
39 | Args:
40 | root: Path to image directory
41 | batch_size: Number of batch samples
42 | workers: Number of data workers
43 | num_val_samples: Number of samples to leave out for a validation set
44 | crop_size: Size of image crop
45 | min_scale: Minimum crop scale
46 | max_scale: Maximum crop scale
47 | brightness: Brightness intensity
48 |             contrast: Contrast intensity
49 | saturation: Saturation intensity
50 | hue: Hue intensity
51 | color_jitter_prob: Probability of applying color jitter
52 | gray_scale_prob: Probability of converting to grayscale
53 | flip_prob: Probability of applying horizontal flip
54 | gaussian_prob: Probability of applying Gaussian blurring
55 | mean: Image normalization channel means
56 | std: Image normalization channel standard deviations
57 | """
58 | super().__init__()
59 | self.save_hyperparameters()
60 | self.root = root
61 | self.batch_size = batch_size
62 | self.workers = workers
63 | self.num_val_samples = num_val_samples
64 | self.crop_size = crop_size
65 | self.min_scale = min_scale
66 | self.max_scale = max_scale
67 | self.brightness = brightness
68 | self.contrast = contrast
69 | self.saturation = saturation
70 | self.hue = hue
71 | self.color_jitter_prob = color_jitter_prob
72 | self.gray_scale_prob = gray_scale_prob
73 | self.flip_prob = flip_prob
74 | self.gaussian_prob = gaussian_prob
75 | self.mean = mean
76 | self.std = std
77 |
78 | self.transforms = MultiViewTransform(
79 | Transforms(
80 | crop_size=self.crop_size,
81 | min_scale=self.min_scale,
82 | max_scale=self.max_scale,
83 | brightness=self.brightness,
84 | contrast=self.contrast,
85 | saturation=self.saturation,
86 | hue=self.hue,
87 | color_jitter_prob=self.color_jitter_prob,
88 | gray_scale_prob=self.gray_scale_prob,
89 | gaussian_prob=self.gaussian_prob,
90 | flip_prob=self.flip_prob,
91 | mean=self.mean,
92 | std=self.std,
93 | ),
94 | n_views=2,
95 | )
96 |
97 | def setup(self, stage: str = "fit"):
98 | if stage == "fit":
99 | dataset = SimpleDataset(self.root, self.transforms)
100 |
101 | # Randomly take num_val_samples images for a validation set
102 | self.train_dataset, self.val_dataset = data.random_split(
103 | dataset,
104 | [len(dataset) - self.num_val_samples, self.num_val_samples],
105 | generator=torch.Generator().manual_seed(42), # Fixed seed
106 | )
107 |
108 | def train_dataloader(self):
109 | return DataLoader(
110 | self.train_dataset,
111 | batch_size=self.batch_size,
112 | shuffle=True,
113 | num_workers=self.workers,
114 | pin_memory=True,
115 | drop_last=True,
116 | persistent_workers=True,
117 | )
118 |
119 | def val_dataloader(self):
120 | return DataLoader(
121 | self.val_dataset,
122 | batch_size=self.batch_size,
123 | shuffle=False,
124 | num_workers=self.workers,
125 | pin_memory=True,
126 | drop_last=False,
127 | persistent_workers=True,
128 | )
129 |
130 |
131 | class SimpleDataset(data.Dataset):
132 | def __init__(self, root: str, transforms: Callable):
133 | """Image dataset from nested directory
134 |
135 | Args:
136 | root: Path to directory
137 | transforms: Image augmentations
138 | """
139 | super().__init__()
140 | self.root = root
141 | self.paths = [
142 | f for f in glob(f"{root}/**/*", recursive=True) if os.path.isfile(f)
143 | ]
144 | self.transforms = transforms
145 |
146 | print(f"Loaded {len(self.paths)} images from {root}")
147 |
148 | def __getitem__(self, index: int):
149 | img = Image.open(self.paths[index]).convert("RGB")
150 | img = self.transforms(img)
151 | return img
152 |
153 | def __len__(self):
154 | return len(self.paths)
155 |
156 |
157 | class MultiViewTransform:
158 | def __init__(self, transforms: Callable, n_views: int = 2):
159 | """Wrapper class to apply transforms multiple times on an image
160 |
161 | Args:
162 | transforms: Image augmentation pipeline
163 | n_views: Number of augmented views to return
164 | """
165 | self.transforms = transforms
166 | self.n_views = n_views
167 |
168 | def __call__(self, img: Image.Image):
169 | return [self.transforms(img) for _ in range(self.n_views)]
170 |
171 |
172 | class Transforms:
173 | def __init__(
174 | self,
175 | crop_size: int = 224,
176 | min_scale: float = 0.08,
177 | max_scale: float = 1.0,
178 | brightness: float = 0.8,
179 | contrast: float = 0.8,
180 | saturation: float = 0.8,
181 | hue: float = 0.2,
182 | color_jitter_prob: float = 0.8,
183 | gray_scale_prob: float = 0.2,
184 | gaussian_prob: float = 0.5,
185 | flip_prob: float = 0.5,
186 | mean: Tuple[float, float, float] = (0.485, 0.456, 0.406),
187 | std: Tuple[float, float, float] = (0.228, 0.224, 0.225),
188 | ):
189 | """Augmentation pipeline for contrastive learning
190 |
191 | Args:
192 | crop_size: Size of image crop
193 | min_scale: Minimum crop scale
194 | max_scale: Maximum crop scale
195 | brightness: Brightness intensity
196 |             contrast: Contrast intensity
197 | saturation: Saturation intensity
198 | hue: Hue intensity
199 | color_jitter_prob: Probability of applying color jitter
200 | gray_scale_prob: Probability of converting to grayscale
201 |             gaussian_prob: Probability of applying Gaussian blurring
202 | flip_prob: Probability of applying horizontal flip
203 | mean: Image normalization means
204 | std: Image normalization standard deviations
205 | """
206 | super().__init__()
207 |
208 | self.transforms = Compose(
209 | [
210 | RandomResizedCrop(size=crop_size, scale=(min_scale, max_scale)),
211 | RandomApply(
212 | [
213 | ColorJitter(
214 | brightness=brightness, # type:ignore
215 | contrast=contrast, # type:ignore
216 | saturation=saturation, # type:ignore
217 | hue=hue, # type:ignore
218 | )
219 | ],
220 | p=color_jitter_prob,
221 | ),
222 | RandomGrayscale(p=gray_scale_prob),
223 | RandomApply([GaussianBlur(kernel_size=23)], p=gaussian_prob),
224 | RandomHorizontalFlip(p=flip_prob),
225 | ToTensor(),
226 | Normalize(mean=mean, std=std),
227 | ]
228 | )
229 |
230 | def __call__(self, img: Image.Image):
231 | return self.transforms(img)
232 |
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import torch.nn.functional as F
4 |
5 |
6 | def masked_mse_loss(
7 | pred: torch.Tensor,
8 | target: torch.Tensor,
9 | mask: torch.Tensor,
10 | normalize_targets: bool = False,
11 | ):
12 | """MSE loss on masked patches
13 |
14 | Args:
15 |         pred: B x num_patches x D tensor of predicted patches
16 |         target: B x num_patches x D tensor of target patch values
17 |         mask: B x num_patches binary mask with masked patches marked with 1
18 |         normalize_targets: Normalize target patches to zero mean and unit variance
19 | Return:
20 | loss: Masked mean square error loss
21 | """
22 | # Normalize target pixel values
23 | if normalize_targets:
24 | mean = target.mean(dim=-1, keepdim=True)
25 | var = target.var(dim=-1, keepdim=True)
26 | target = (target - mean) / (var + 1.0e-6) ** 0.5
27 |
28 | # Calculate MSE loss
29 | loss = (pred - target) ** 2
30 | loss = loss.mean(dim=-1) # Per patch loss
31 | loss = (loss * mask).sum() / mask.sum() # Mean of masked patches
32 |
33 | return loss
34 |
35 |
36 | """
37 | Modified from:
38 | https://github.com/vturrisi/solo-learn/blob/main/solo/losses/simclr.py
39 | https://github.com/vturrisi/solo-learn/blob/main/solo/utils/misc.py
40 | """
41 |
42 |
43 | def info_nce_loss(z: torch.Tensor, temperature: float = 0.1) -> torch.Tensor:
44 |     """Computes SimCLR's InfoNCE loss given a batch of projected features z from
45 |     two augmented views; positive/negative masks are built internally per sample.
46 |
47 |     Args:
48 |         z (torch.Tensor): (2*B) x D tensor containing features from the two views.
49 |         temperature (float): Softmax temperature for the similarity scores.
50 |
51 |     Return:
52 |         torch.Tensor: SimCLR loss.
53 | """
54 |
55 | z = F.normalize(z, dim=-1)
56 | gathered_z = gather(z)
57 |
58 | sim = torch.exp(torch.einsum("if, jf -> ij", z, gathered_z) / temperature)
59 |
60 | indexes = torch.arange(z.size(0) // 2, device=z.device).repeat(2)
61 | gathered_indexes = gather(indexes)
62 |
63 | indexes = indexes.unsqueeze(0)
64 | gathered_indexes = gathered_indexes.unsqueeze(0)
65 |
66 | # positives
67 | pos_mask = indexes.t() == gathered_indexes
68 | pos_mask[:, z.size(0) * get_rank() :].fill_diagonal_(0)
69 |
70 | # negatives
71 | neg_mask = indexes.t() != gathered_indexes
72 |
73 | pos = torch.sum(sim * pos_mask, 1)
74 | neg = torch.sum(sim * neg_mask, 1)
75 | loss = -(torch.mean(torch.log(pos / (pos + neg))))
76 | return loss
77 |
78 |
79 | def get_rank():
80 | if dist.is_available() and dist.is_initialized():
81 | return dist.get_rank()
82 | return 0
83 |
84 |
85 | class GatherLayer(torch.autograd.Function):
86 | """
87 | Gathers tensors from all process and supports backward propagation
88 | for the gradients across processes.
89 | """
90 |
91 | @staticmethod
92 | def forward(ctx, x):
93 | if dist.is_available() and dist.is_initialized():
94 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
95 | dist.all_gather(output, x)
96 | else:
97 | output = [x]
98 | return tuple(output)
99 |
100 | @staticmethod
101 | def backward(ctx, *grads):
102 | if dist.is_available() and dist.is_initialized():
103 | all_gradients = torch.stack(grads)
104 | dist.all_reduce(all_gradients)
105 | grad_out = all_gradients[get_rank()]
106 | else:
107 | grad_out = grads[0]
108 | return grad_out
109 |
110 |
111 | def gather(X, dim=0):
112 | """Gathers tensors from all processes, supporting backward propagation."""
113 | return torch.cat(GatherLayer.apply(X), dim=dim)
114 |
--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional, Tuple
3 |
4 | import pytorch_lightning as pl
5 | import torch
6 | import torch.nn as nn
7 | from einops import rearrange
8 | from timm.optim.optim_factory import param_groups_weight_decay
9 | from torch.optim import SGD, Adam, AdamW
10 | from torch.optim.lr_scheduler import LambdaLR
11 | from torchvision.utils import make_grid, save_image
12 | from transformers.optimization import get_cosine_schedule_with_warmup
13 |
14 | from src.loss import info_nce_loss, masked_mse_loss
15 | from src.network.decoder import VitDecoder
16 | from src.network.encoder import build_encoder
17 | from src.network.pos_embed import get_1d_sincos_pos_embed
18 |
19 |
20 | class CANModel(pl.LightningModule):
21 | def __init__(
22 | self,
23 | img_size: int = 224,
24 | encoder_name: str = "vit_base_patch16",
25 | decoder_embed_dim: int = 512,
26 | decoder_depth: int = 8,
27 | decoder_num_heads: int = 16,
28 | decoder_embed_unmasked_tokens: bool = True,
29 | projector_hidden_dim: int = 4096,
30 | projector_out_dim: int = 128,
31 | noise_embed_in_dim: int = 768,
32 | noise_embed_hidden_dim: int = 768,
33 | mask_ratio: float = 0.5,
34 | norm_pixel_loss: bool = True,
35 | temperature: float = 0.1,
36 | noise_std_max: float = 0.05,
37 | weight_contrast: float = 0.03,
38 | weight_recon: float = 0.67,
39 | weight_denoise: float = 0.3,
40 | lr: float = 2.5e-4,
41 | optimizer: str = "adamw",
42 | betas: Tuple[float, float] = (0.9, 0.95),
43 | weight_decay: float = 0.05,
44 | momentum: float = 0.9,
45 | scheduler: str = "cosine",
46 | warmup_epochs: int = 0,
47 | channel_last: bool = False,
48 | ):
49 | """Contrastive Masked Autoencoder and Noise Prediction Pretraining Model
50 |
51 | Args:
52 | img_size: Size of input image
53 | encoder_name: Name of encoder network
54 | decoder_embed_dim: Embed dim of decoder
55 | decoder_depth: Number of transformer blocks in the decoder
56 | decoder_num_heads: Number of attention heads in the decoder
57 | decoder_embed_unmasked_tokens: Apply decoder embedding layer on both masked and unmasked tokens.
58 |                 Else only apply it to the unmasked tokens (mask tokens are created at decoder dim)
59 | projector_hidden_dim: Hidden dim of projector
60 | projector_out_dim: Output dim of projector
61 | noise_embed_in_dim: Dim of noise level sinusoidal embedding
62 | noise_embed_hidden_dim: Hidden dim of noising embed MLP
63 | mask_ratio: Ratio of input image patches to mask
64 | norm_pixel_loss: Calculate loss using normalized pixel value targets
65 | temperature: Temperature for contrastive loss
66 | noise_std_max: Maximum noise standard deviation
67 | weight_contrast: Weight for contrastive loss
68 | weight_recon: Weight for patch reconstruction loss
69 | weight_denoise: Weight for denoising loss
70 |             lr: Learning rate (should be scaled with the batch size, i.e. lr = base_lr*batch_size/256)
71 | optimizer: Name of optimizer (adam | adamw | sgd)
72 | betas: Adam beta parameters
73 | weight_decay: Optimizer weight decay
74 | momentum: SGD momentum parameter
75 | scheduler: Name of learning rate scheduler [cosine, none]
76 | warmup_epochs: Number of warmup epochs
77 | channel_last: Change to channel last memory format for possible training speed up
78 | """
79 | super().__init__()
80 | self.save_hyperparameters()
81 | self.img_size = img_size
82 | self.encoder_name = encoder_name
83 | self.decoder_embed_dim = decoder_embed_dim
84 | self.decoder_depth = decoder_depth
85 | self.decoder_num_heads = decoder_num_heads
86 | self.decoder_embed_unmasked_tokens = decoder_embed_unmasked_tokens
87 | self.projector_hidden_dim = projector_hidden_dim
88 | self.projector_out_dim = projector_out_dim
89 | self.noise_embed_in_dim = noise_embed_in_dim
90 | self.noise_embed_hidden_dim = noise_embed_hidden_dim
91 | self.mask_ratio = mask_ratio
92 | self.norm_pixel_loss = norm_pixel_loss
93 | self.temperature = temperature
94 | self.noise_std_max = noise_std_max
95 | self.weight_contrast = weight_contrast
96 | self.weight_recon = weight_recon
97 | self.weight_denoise = weight_denoise
98 | self.lr = lr
99 | self.optimizer = optimizer
100 | self.betas = betas
101 | self.weight_decay = weight_decay
102 | self.momentum = momentum
103 | self.scheduler = scheduler
104 | self.warmup_epochs = warmup_epochs
105 | self.channel_last = channel_last
106 |
107 | # Initialize networks
108 | self.encoder, self.patch_size = build_encoder(
109 | encoder_name, img_size=self.img_size
110 | )
111 | self.decoder = VitDecoder(
112 | patch_size=self.patch_size,
113 | num_patches=self.encoder.patch_embed.num_patches,
114 | in_dim=self.encoder.embed_dim,
115 | embed_dim=self.decoder_embed_dim,
116 | depth=self.decoder_depth,
117 | num_heads=self.decoder_num_heads,
118 | embed_unmasked_tokens=self.decoder_embed_unmasked_tokens,
119 | )
120 | self.projector = nn.Sequential(
121 | nn.Linear(self.encoder.embed_dim, self.projector_hidden_dim),
122 | nn.BatchNorm1d(self.projector_hidden_dim),
123 | nn.ReLU(),
124 | nn.Linear(self.projector_hidden_dim, self.projector_hidden_dim),
125 | nn.BatchNorm1d(self.projector_hidden_dim),
126 | nn.ReLU(),
127 | nn.Linear(self.projector_hidden_dim, self.projector_out_dim),
128 | )
129 | # Based on updated openreview version (as of Nov 17, 2022), the MLP is two layers
130 | # without BN (maybe?) and input and hidden dims the same as the encoder embedding
131 | self.noise_embed = nn.Sequential(
132 | nn.Linear(self.noise_embed_in_dim, self.noise_embed_hidden_dim),
133 | nn.ReLU(),
134 | nn.Linear(
135 | self.noise_embed_hidden_dim,
136 | self.encoder.embed_dim
137 | if self.decoder_embed_unmasked_tokens
138 | else self.decoder_embed_dim,
139 | ),
140 | )
141 |
142 | # Change to channel last memory format
143 | # https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html
144 | if self.channel_last:
145 | self = self.to(memory_format=torch.channels_last)
146 |
147 | def patchify(self, x: torch.Tensor):
148 | """Rearrange image into patches
149 |
150 | Args:
151 | x: Tensor of size (b, 3, h, w)
152 |
153 | Return:
154 | x: Tensor of size (b, h*w, patch_size^2 * 3)
155 | """
156 | assert x.shape[2] == x.shape[3] and x.shape[2] % self.patch_size == 0
157 |
158 | return rearrange(
159 | x,
160 | "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
161 | p1=self.patch_size,
162 | p2=self.patch_size,
163 | )
164 |
165 | def unpatchify(self, x: torch.Tensor):
166 | """Rearrange patches back to an image
167 |
168 | Args:
169 | x: Tensor of size (b, h*w, patch_size^2 * 3)
170 |
171 | Return:
172 | x: Tensor of size (b, 3, h, w)
173 | """
174 | h = w = int(x.shape[1] ** 0.5)
175 | return rearrange(
176 | x,
177 | " b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
178 | p1=self.patch_size,
179 | p2=self.patch_size,
180 | h=h,
181 | w=w,
182 | )
183 |
184 | def log_samples(self, inp: torch.Tensor, pred: torch.Tensor, mask: torch.Tensor):
185 | """Log sample images"""
186 | # Only log up to 16 images
187 | inp, pred, mask = inp[:16], pred[:16], mask[:16]
188 |
189 | # Patchify the input image
190 | inp = self.patchify(inp)
191 |
192 | # Merge original and predicted patches
193 | pred = pred * mask[:, :, None]
194 | inp = inp * (1 - mask[:, :, None])
195 | res = self.unpatchify(inp) + self.unpatchify(pred)
196 |
197 | # Log result
198 | if "CSVLogger" in str(self.logger.__class__):
199 | path = os.path.join(
200 | self.logger.log_dir, # type:ignore
201 | "samples",
202 | )
203 | if not os.path.exists(path):
204 | os.makedirs(path)
205 | filename = os.path.join(path, str(self.current_epoch) + "ep.png")
206 | save_image(res, filename, nrow=4, normalize=True)
207 | elif "WandbLogger" in str(self.logger.__class__):
208 | grid = make_grid(res, nrow=4, normalize=True)
209 | self.logger.log_image(key="sample", images=[grid]) # type:ignore
210 |
211 | @torch.no_grad()
212 | def add_noise(self, x: torch.Tensor):
213 | """Add noise to input image
214 |
215 | Args:
216 | x: Tensor of size (b, c, h, w)
217 |
218 | Return:
219 | x_noise: x tensor with added Gaussian noise of size (b, c, h, w)
220 | noise: Noise tensor of size (b, c, h, w)
221 | std: Noise standard deviation (noise level) tensor of size (b,)
222 | """
223 | # Sample std uniformly from [0, self.noise_std_max]
224 | std = torch.rand(x.size(0), device=x.device) * self.noise_std_max
225 |
226 | # Sample noise
227 | noise = torch.randn_like(x) * std[:, None, None, None]
228 |
229 | # Add noise to x
230 | x_noise = x + noise
231 |
232 | return x_noise, noise, std
233 |
234 | def shared_step(
235 | self,
236 | x: Tuple[torch.Tensor, torch.Tensor],
237 | mode: str = "train",
238 | batch_idx: Optional[int] = None,
239 | ):
240 | x1, x2 = x
241 |
242 | if self.channel_last:
243 | x1 = x1.to(memory_format=torch.channels_last) # type:ignore
244 | x2 = x2.to(memory_format=torch.channels_last) # type:ignore
245 |
246 | # Add noise to views
247 | x1_noise, noise1, std1 = self.add_noise(x1)
248 | x2_noise, noise2, std2 = self.add_noise(x2)
249 |
250 | # Mask and extract features
251 | z1, mask1, idx_unshuffle1 = self.encoder(x1_noise, self.mask_ratio)
252 | z2, mask2, idx_unshuffle2 = self.encoder(x2_noise, self.mask_ratio)
253 |
254 | # Pass mean encoder features through projector
255 | u1 = self.projector(torch.mean(z1[:, 1:, :], dim=1)) # Skip cls token
256 | u2 = self.projector(torch.mean(z2[:, 1:, :], dim=1))
257 |
258 | # Generate noise level embedding
259 | p1 = self.noise_embed(
260 | get_1d_sincos_pos_embed(std1, dim=self.noise_embed_in_dim)
261 | )
262 | p2 = self.noise_embed(
263 | get_1d_sincos_pos_embed(std2, dim=self.noise_embed_in_dim)
264 | )
265 |
266 | # Predict masked patches and noise
267 | x1_pred = self.decoder(z1, idx_unshuffle1, p1)
268 | x2_pred = self.decoder(z2, idx_unshuffle2, p2)
269 |
270 | # Contrastive loss
271 | loss_contrast = info_nce_loss(torch.cat([u1, u2]), temperature=self.temperature)
272 |
273 | # Patch reconstruction loss
274 | loss_recon = (
275 | masked_mse_loss(x1_pred, self.patchify(x1), mask1, self.norm_pixel_loss)
276 | + masked_mse_loss(x2_pred, self.patchify(x2), mask2, self.norm_pixel_loss)
277 | ) / 2
278 |
279 | # Denoising loss
280 | loss_denoise = (
281 | masked_mse_loss(
282 | x1_pred, self.patchify(noise1), 1 - mask1, self.norm_pixel_loss
283 | )
284 | + masked_mse_loss(
285 | x2_pred, self.patchify(noise2), 1 - mask2, self.norm_pixel_loss
286 | )
287 | ) / 2
288 |
289 | # Combined loss
290 | loss = (
291 | self.weight_contrast * loss_contrast
292 | + self.weight_recon * loss_recon
293 | + self.weight_denoise * loss_denoise
294 | )
295 |
296 | # Log
297 | self.log(f"{mode}_loss", loss)
298 | self.log(f"{mode}_loss_contrast", loss_contrast)
299 | self.log(f"{mode}_loss_recon", loss_recon)
300 | self.log(f"{mode}_loss_denoise", loss_denoise)
301 | if mode == "val" and batch_idx == 0:
302 | self.log_samples(x1, x1_pred, mask1)
303 |
304 | return {"loss": loss}
305 |
306 | def training_step(self, x, _):
307 | self.log("lr", self.trainer.optimizers[0].param_groups[0]["lr"], prog_bar=True)
308 | return self.shared_step(x, mode="train")
309 |
310 | def validation_step(self, x, batch_idx):
311 | return self.shared_step(x, mode="val", batch_idx=batch_idx)
312 |
313 | def configure_optimizers(self):
314 | """Initialize optimizer and learning rate schedule"""
315 | # Set weight decay to 0 for bias and norm layers (following MAE)
316 |         params = param_groups_weight_decay(self.encoder, self.weight_decay)
317 |         for module in (self.decoder, self.projector, self.noise_embed):
318 |             params += param_groups_weight_decay(module, self.weight_decay)
319 |
320 | # Optimizer
321 | if self.optimizer == "adam":
322 | optimizer = Adam(
323 | params,
324 | lr=self.lr,
325 | betas=self.betas,
326 | weight_decay=self.weight_decay,
327 | )
328 | elif self.optimizer == "adamw":
329 | optimizer = AdamW(
330 | params,
331 | lr=self.lr,
332 | betas=self.betas,
333 | weight_decay=self.weight_decay,
334 | )
335 | elif self.optimizer == "sgd":
336 | optimizer = SGD(
337 | params,
338 | lr=self.lr,
339 | momentum=self.momentum,
340 | weight_decay=self.weight_decay,
341 | )
342 | else:
343 | raise ValueError(
344 | f"{self.optimizer} is not an available optimizer. Should be one of ['adam', 'adamw', 'sgd']"
345 | )
346 |
347 | # Learning rate schedule
348 | if self.scheduler == "cosine":
349 | epoch_steps = (
350 | self.trainer.estimated_stepping_batches
351 | // self.trainer.max_epochs # type:ignore
352 | )
353 | scheduler = get_cosine_schedule_with_warmup(
354 | optimizer,
355 | num_training_steps=self.trainer.estimated_stepping_batches, # type:ignore
356 | num_warmup_steps=epoch_steps * self.warmup_epochs,
357 | )
358 | elif self.scheduler == "none":
359 | scheduler = LambdaLR(optimizer, lambda _: 1)
360 | else:
361 | raise ValueError(
362 |                 f"{self.scheduler} is not an available scheduler. Should be one of ['cosine', 'none']"
363 | )
364 |
365 | return {
366 | "optimizer": optimizer,
367 | "lr_scheduler": {
368 | "scheduler": scheduler,
369 | "interval": "step",
370 | },
371 | }
372 |
--------------------------------------------------------------------------------
/src/network/decoder.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | from einops import repeat
7 | from timm.models.vision_transformer import Block
8 |
9 | from src.network.pos_embed import get_2d_sincos_pos_embed
10 |
11 |
12 | class VitDecoder(nn.Module):
13 | def __init__(
14 | self,
15 | patch_size: int = 16,
16 | num_patches: int = 196,
17 | in_channels: int = 3,
18 | depth: int = 8,
19 | embed_dim: int = 512,
20 | in_dim: int = 768,
21 | num_heads: int = 16,
22 | mlp_ratio: int = 4,
23 | norm_layer: nn.Module = partial(nn.LayerNorm, eps=1e-6), # type:ignore
24 | act_layer: nn.Module = nn.GELU, # type:ignore
25 | embed_unmasked_tokens: bool = True,
26 | ):
27 | super().__init__()
28 | self.embed_unmasked_tokens = embed_unmasked_tokens
29 |
30 | # Projection from encoder to decoder dim
31 | self.embed = nn.Linear(in_dim, embed_dim, bias=True)
32 |
33 | # Mask token
34 | self.mask_token = nn.Parameter(
35 | torch.zeros(1, 1, in_dim if embed_unmasked_tokens else embed_dim)
36 | )
37 |
38 | # Sin-cos position embedding
39 | self.pos_embed = nn.Parameter(
40 | torch.zeros((1, num_patches + 1, embed_dim)), requires_grad=False
41 | )
42 |
43 | self.blocks = nn.Sequential(
44 | *[
45 | Block(
46 | dim=embed_dim,
47 | num_heads=num_heads,
48 | mlp_ratio=mlp_ratio,
49 | qkv_bias=True,
50 | norm_layer=norm_layer, # type:ignore
51 | act_layer=act_layer, # type:ignore
52 | )
53 | for _ in range(depth)
54 | ]
55 | )
56 | self.norm = norm_layer(embed_dim)
57 | self.head = nn.Linear(embed_dim, patch_size**2 * in_channels, bias=True)
58 |
59 | self.init_weights(num_patches)
60 |
61 | def init_weights(self, num_patches: int):
62 | # Initialize to sin-cos position embedding
63 | pos_embed = get_2d_sincos_pos_embed(
64 | self.pos_embed.shape[-1],
65 | int(num_patches**0.5),
66 | cls_token=True,
67 | )
68 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
69 |
70 | # Mask token
71 | torch.nn.init.normal_(self.mask_token, std=0.02)
72 |
73 | # All other weights
74 | self.apply(self._init_weights)
75 |
76 | def _init_weights(self, m: nn.Module):
77 | if isinstance(m, nn.Linear):
78 | torch.nn.init.xavier_uniform_(m.weight)
79 | if isinstance(m, nn.Linear) and m.bias is not None:
80 | nn.init.constant_(m.bias, 0)
81 | elif isinstance(m, nn.LayerNorm):
82 | nn.init.constant_(m.bias, 0)
83 | nn.init.constant_(m.weight, 1.0)
84 |
85 | def forward(
86 | self,
87 | x: torch.Tensor,
88 | idx_unshuffle: torch.Tensor,
89 | p: Optional[torch.Tensor] = None,
90 | ):
91 | if not self.embed_unmasked_tokens:
92 |             # Project only the unmasked (visible) tokens to decoder embed size
93 | x = self.embed(x)
94 |
95 | # Append mask tokens to input
96 | L = idx_unshuffle.shape[1]
97 | B, L_unmasked, D = x.shape
98 | mask_tokens = self.mask_token.repeat(B, L + 1 - L_unmasked, 1)
99 | temp = torch.concat([x[:, 1:, :], mask_tokens], dim=1) # Skip cls token
100 |
101 | # Unshuffle tokens
102 | temp = torch.gather(
103 | temp, dim=1, index=repeat(idx_unshuffle, "b l -> b l d", d=D)
104 | )
105 |
106 | # Add noise level embedding
107 | if p is not None:
108 | temp = temp + p[:, None, :]
109 |
110 | # Prepend cls token
111 | x = torch.cat([x[:, :1, :], temp], dim=1)
112 |
113 | if self.embed_unmasked_tokens:
114 | # Project masked and unmasked tokens to decoder embed size
115 | x = self.embed(x)
116 |
117 | # Add pos embed
118 | x = x + self.pos_embed
119 |
120 | # Apply transformer layers
121 | x = self.blocks(x)
122 |
123 | # Predict pixel values
124 | x = self.head(self.norm(x))
125 |
126 | return x[:, 1:, :] # Don't return cls token
127 |
128 |
129 | def dec512d8b(patch_size: int, num_patches: int, in_dim: int, **kwargs):
130 | return VitDecoder(
131 | patch_size=patch_size,
132 | num_patches=num_patches,
133 | in_dim=in_dim,
134 | embed_dim=512,
135 | depth=8,
136 | num_heads=16,
137 | **kwargs,
138 | )
139 |
140 |
141 | MODEL_DICT = {"dec512d8b": dec512d8b}
142 |
143 |
144 | def build_decoder(model: str, **kwargs):
145 | try:
146 | model_fn = MODEL_DICT[model]
147 |     except KeyError:
148 | raise ValueError(
149 | f"{model} is not an available decoder. Should be one of {[k for k in MODEL_DICT.keys()]}"
150 | )
151 |
152 | return model_fn(**kwargs)
153 |
--------------------------------------------------------------------------------
/src/network/encoder.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import timm.models.vision_transformer as vision_transformer
4 | import torch
5 | import torch.nn as nn
6 | from einops import repeat
7 |
8 | from src.network.pos_embed import get_2d_sincos_pos_embed
9 |
10 |
11 | class VisionTransformer(vision_transformer.VisionTransformer):
12 | """Vision transformer for masked image modeling.
13 | Uses fixed sin-cos position embeddings
14 | """
15 |
16 | def __init__(self, **kwargs):
17 | super(VisionTransformer, self).__init__(**kwargs)
18 | assert self.num_prefix_tokens == 1 # Must have cls token
19 |
20 | # Re-initialize with fixed sin-cos position embedding
21 | self.pos_embed = nn.Parameter(
22 | torch.zeros(self.pos_embed.shape), requires_grad=False
23 | )
24 | self.init_pos_embed()
25 |
26 | def init_pos_embed(self):
27 | """Initialize sin-cos position embeddings"""
28 | pos_embed = get_2d_sincos_pos_embed(
29 | self.pos_embed.shape[-1],
30 | int(self.patch_embed.num_patches**0.5),
31 | cls_token=True,
32 | )
33 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
34 |
35 | def random_masking(self, x: torch.Tensor, mask_ratio: float):
36 | """Randomly mask mask_ratio patches of an image
37 |
38 | Args:
39 | x: Tensor of shape B x L x D
40 | mask_ratio: Ratio of patches to mask
41 |
42 | Return:
43 | x_masked: Tensor of non-masked patches
44 | mask: Tensor of size B x L where the positions of masked
45 | patches are marked by 1 and else 0
46 | idx_unshuffle: Tensor of size B x L with the sorting order
47 | to unshuffle patches back to the original order
48 | """
49 | B, L, D = x.shape
50 |
51 | # Number of patches to keep
52 | num_keep = int(L * (1 - mask_ratio))
53 |
54 | # Sort array of random noise
55 | noise = torch.rand((B, L), device=x.device)
56 | idx_shuffle = torch.argsort(noise, dim=1)
57 | idx_unshuffle = torch.argsort(idx_shuffle, dim=1) # Undo shuffling
58 |
59 | # Keep indices of n_keep smallest values
60 | idx_keep = idx_shuffle[:, :num_keep]
61 | x_masked = torch.gather(x, dim=1, index=repeat(idx_keep, "b l -> b l d", d=D))
62 |
63 | # Generate binary mask
64 | mask = torch.ones((B, L), device=x.device)
65 | mask[:, :num_keep] = 0
66 | mask = torch.gather(mask, dim=1, index=idx_unshuffle)
67 |
68 | return x_masked, mask, idx_unshuffle
69 |
70 | def forward(self, x: torch.Tensor, mask_ratio: float = 0.75):
71 | # Patch embed image
72 | x = self.patch_embed(x)
73 |
74 | # Add pos embed skipping cls token
75 | x = x + self.pos_embed[:, 1:, :]
76 |
77 | # Mask the image
78 | x, mask, idx_unshuffle = self.random_masking(x, mask_ratio)
79 |
80 | # Append the cls token
81 | cls_token = self.cls_token + self.pos_embed[:, 0, :]
82 | x = torch.cat((cls_token.expand(x.shape[0], -1, -1), x), dim=1)
83 |
84 | # Apply transformer layers
85 | x = self.norm_pre(x)
86 | x = self.blocks(x)
87 | x = self.norm(x)
88 |
89 | return x, mask, idx_unshuffle
90 |
91 |
92 | def build_encoder(model: str, **kwargs):
93 | try:
94 | model_fn, patch_size = MODEL_DICT[model]
95 |     except KeyError:
96 | raise ValueError(
97 | f"{model} is not an available encoder. Should be one of {[k for k in MODEL_DICT.keys()]}"
98 | )
99 |
100 | return model_fn(**kwargs), patch_size
101 |
102 |
103 | def vit_tiny_patch16(**kwargs):
104 | return VisionTransformer(
105 | patch_size=16,
106 | embed_dim=192,
107 | depth=12,
108 | num_heads=3,
109 | mlp_ratio=4,
110 | qkv_bias=True,
111 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
112 | weight_init="jax",
113 | **kwargs,
114 | )
115 |
116 |
117 | def vit_small_patch16(**kwargs):
118 | return VisionTransformer(
119 | patch_size=16,
120 | embed_dim=384,
121 | depth=12,
122 | num_heads=6,
123 | mlp_ratio=4,
124 | qkv_bias=True,
125 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
126 | weight_init="jax",
127 | **kwargs,
128 | )
129 |
130 |
131 | def vit_base_patch16(**kwargs):
132 | return VisionTransformer(
133 | patch_size=16,
134 | embed_dim=768,
135 | depth=12,
136 | num_heads=12,
137 | mlp_ratio=4,
138 | qkv_bias=True,
139 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
140 | weight_init="jax",
141 | **kwargs,
142 | )
143 |
144 |
145 | def vit_large_patch16(**kwargs):
146 | return VisionTransformer(
147 | patch_size=16,
148 | embed_dim=1024,
149 | depth=24,
150 | num_heads=16,
151 | mlp_ratio=4,
152 | qkv_bias=True,
153 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
154 | weight_init="jax",
155 | **kwargs,
156 | )
157 |
158 |
159 | def vit_huge_patch14(**kwargs):
160 | return VisionTransformer(
161 | patch_size=14,
162 | embed_dim=1280,
163 | depth=32,
164 | num_heads=16,
165 | mlp_ratio=4,
166 | qkv_bias=True,
167 | norm_layer=partial(nn.LayerNorm, eps=1e-6),
168 | weight_init="jax",
169 | **kwargs,
170 | )
171 |
172 |
173 | MODEL_DICT = {
174 | "vit_tiny_patch16": (vit_tiny_patch16, 16),
175 | "vit_small_patch16": (vit_small_patch16, 16),
176 | "vit_base_patch16": (vit_base_patch16, 16),
177 | "vit_large_patch16": (vit_large_patch16, 16),
178 | "vit_huge_patch14": (vit_huge_patch14, 14),
179 | }
180 |
--------------------------------------------------------------------------------
/src/network/pos_embed.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def get_1d_sincos_pos_embed(x: torch.Tensor, dim: int):
8 | """From: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/main/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py"""
9 | half_dim = dim // 2
10 | emb = math.log(10000) / (half_dim - 1)
11 | emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
12 | emb = x[:, None] * emb[None, :]
13 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
14 | return emb
15 |
16 |
17 | """From: https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20"""
18 |
19 |
20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21 | """
22 | grid_size: int of the grid height and width
23 | return:
24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25 | """
26 | grid_h = np.arange(grid_size, dtype=np.float32)
27 | grid_w = np.arange(grid_size, dtype=np.float32)
28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
29 | grid = np.stack(grid, axis=0)
30 |
31 | grid = grid.reshape([2, 1, grid_size, grid_size])
32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33 | if cls_token:
34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35 | return pos_embed
36 |
37 |
38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
39 | assert embed_dim % 2 == 0
40 |
41 | # use half of dimensions to encode grid_h
42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
44 |
45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
46 |
47 | return emb
48 |
49 |
50 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
51 | """
52 | embed_dim: output dimension for each position
53 | pos: a list of positions to be encoded: size (M,)
54 | out: (M, D)
55 | """
56 | assert embed_dim % 2 == 0
57 |     omega = np.arange(embed_dim // 2, dtype=float)  # np.float was removed in NumPy 1.24
58 | omega /= embed_dim / 2.0
59 | omega = 1.0 / 10000**omega # (D/2,)
60 |
61 | pos = pos.reshape(-1) # (M,)
62 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
63 |
64 | emb_sin = np.sin(out) # (M, D/2)
65 | emb_cos = np.cos(out) # (M, D/2)
66 |
67 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
68 | return emb
69 |
--------------------------------------------------------------------------------
/src/pl_utils.py:
--------------------------------------------------------------------------------
1 | from argparse import Namespace
2 | from typing import Any, Optional
3 |
4 | from pytorch_lightning.cli import LightningArgumentParser
5 | from pytorch_lightning.loggers import Logger
6 | from pytorch_lightning.loggers.csv_logs import CSVLogger
7 | from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
8 | from pytorch_lightning.loggers.wandb import WandbLogger
9 |
10 |
11 | class MyLightningArgumentParser(LightningArgumentParser):
12 | def __init__(self, *args: Any, **kwargs: Any) -> None:
13 | super().__init__(*args, **kwargs)
14 | self.add_logger_args()
15 |
16 | def add_logger_args(self) -> None:
17 | # Shared args
18 | self.add_argument(
19 | "--logger_type",
20 | type=str,
21 | help="Name of logger",
22 | default="csv",
23 | choices=["csv", "wandb"],
24 | )
25 | self.add_argument(
26 | "--save_path",
27 | type=str,
28 | help="Save path of outputs",
29 | default="output/",
30 | )
31 | self.add_argument(
32 | "--name", type=str, help="Name of experiment", default="default"
33 | )
34 |
35 | # Wandb args
36 | self.add_argument(
37 | "--project", type=str, help="Name of wandb project", default="default"
38 | )
39 |
40 |
41 | def init_logger(args: Namespace) -> Optional[Logger]:
42 | """Initialize logger from arguments
43 |
44 | Args:
45 | args: parsed argument namespace
46 |
47 | Returns:
48 | logger: initialized logger object
49 | """
50 | if args.logger_type == "wandb":
51 | return WandbLogger(
52 | project=args.project,
53 | name=args.name,
54 | save_dir=args.save_path,
55 | )
56 | elif args.logger_type == "csv":
57 | return CSVLogger(name=args.name, save_dir=args.save_path)
58 | else:
59 |         raise ValueError(
60 |             f"{args.logger_type} is not an available logger. Should be one of ['csv', 'wandb']"
61 |         )
62 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning.callbacks import ModelCheckpoint
3 |
4 | from src.data import DuelViewDataModule
5 | from src.model import CANModel
6 | from src.pl_utils import MyLightningArgumentParser, init_logger
7 |
8 | model_class = CANModel
9 | dm_class = DuelViewDataModule
10 |
11 | # Parse arguments
12 | parser = MyLightningArgumentParser()
13 | parser.add_lightning_class_args(pl.Trainer, None) # type:ignore
14 | parser.add_lightning_class_args(dm_class, "data")
15 | parser.add_lightning_class_args(model_class, "model")
16 | parser.link_arguments("data.crop_size", "model.img_size")
17 | args = parser.parse_args()
18 |
19 | # Setup trainer
20 | logger = init_logger(args)
21 | checkpoint_callback = ModelCheckpoint(
22 | filename="best-{epoch}-{val_loss:.4f}",
23 | monitor="val_loss",
24 | mode="min",
25 | save_last=True,
26 | )
27 | dm = dm_class(**args["data"])
28 | model = model_class(**args["model"])
29 |
30 | trainer = pl.Trainer.from_argparse_args(
31 | args, logger=logger, callbacks=[checkpoint_callback]
32 | )
33 |
34 | # Train
35 | trainer.tune(model, dm)
36 | trainer.fit(model, dm)
37 |
--------------------------------------------------------------------------------