├── .gitignore ├── LICENSE.txt ├── README.md ├── config ├── data │ ├── cars_mae.yaml │ ├── cub_mae.yaml │ ├── stl10_linear_probe.yaml │ └── stl10_mae.yaml ├── linear_probe.yaml └── mae.yaml ├── datamodule.py ├── lightning_readout.py ├── linear_probe.py ├── mae.py ├── requirements.txt ├── samples ├── bird-samples.png ├── bird-training-curves.png ├── birds-training.gif └── car-samples.png ├── train.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Logs 2 | *logs*/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Connor Anderson 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # Masked Autoencoders in PyTorch
9 |
10 | A simple, unofficial implementation of MAE ([Masked Autoencoders are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)) using [pytorch-lightning](https://www.pytorchlightning.ai/). A PyTorch implementation by the authors can be found [here](https://github.com/facebookresearch/mae).
11 |
12 | Currently implements training on [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [StanfordCars](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), and [STL-10](https://cs.stanford.edu/~acoates/stl10/), but is easily extensible to any other image dataset.
13 |
14 | ## Updates
15 |
16 | ### September 1, 2023
17 |
18 | - Updated for compatibility with PyTorch 2.0 and PyTorch Lightning 2.0. This probably breaks backwards compatibility. Created a release for the old version of the code.
19 | - Modified parts of the training code to be more concise and efficient.
20 | - Added additional features, including the option to save some validation reconstructions during training. **Note**: saving reconstructions during distributed training currently doesn't work; training freezes at the end of the validation epoch.
21 | - Retrained the CUB and Cars models with the new code and a stronger decoder.
22 |
23 | ### February 4, 2022
24 |
25 | - Fixed a bug in the code for generating mask indices. Retrained and updated the reconstruction figures (see below). They aren't quite as pretty now, but they make more sense.
26 |
27 | ## Setup
28 |
29 | ```bash
30 | # Clone the repository
31 | git clone https://github.com/catalys1/mae-pytorch.git
32 | cd mae-pytorch
33 |
34 | # Install required libraries (preferably inside a virtual environment)
35 | pip install -r requirements.txt
36 |
37 | # Set up .env with the path to the data
38 | echo "DATADIR=/path/to/data" > .env
39 | ```
40 |
41 | ## Usage
42 |
43 | ### MAE training
44 |
45 | Training options are provided through configuration files, handled by [LightningCLI](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html). See `config/` for examples.
46 |
47 | Train an MAE model on the CUB dataset:
48 | ```bash
49 | python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml
50 | ```
51 |
52 | Using multiple GPUs:
53 | ```bash
54 | python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --trainer.devices 8
55 | ```
56 | ### Linear probing (currently only supported on STL-10)
57 | Evaluate the learned representations using a linear probe. First, pretrain the model on the 100,000 samples of the 'unlabeled' split:
58 | ```bash
59 | python train.py fit -c config/mae.yaml -c config/data/stl10_mae.yaml
60 | ```
61 | Next, a linear probe is appended to the output of the frozen encoder and the decoder is discarded. The appended classifier is then trained on 4,000 labeled samples of the 'train' split (another 1,000 are used for validation during training) and evaluated on the 'test' split. To do so, provide the path to the pretrained model checkpoint in the command below.
62 |
63 | ```bash
64 | python linear_probe.py -c config/linear_probe.yaml -c config/data/stl10_linear_probe.yaml --model.init_args.ckpt_path <path/to/pretrained/checkpoint.ckpt>
65 | ```
66 |
67 | This yields 77.96% accuracy on the test data.
68 |
69 |
70 | ### Fine-tuning
71 |
72 | Not yet implemented.
73 |
74 | ## Implementation
75 |
76 | The default model uses ViT-Base for the encoder and a small ViT (`depth=6`, `width=384`) for the decoder, which is smaller than the model used in the paper.
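For reference, here is a minimal sketch of how those defaults map onto the `MAE` LightningModule defined in `mae.py`; in normal use the same values are supplied by `config/mae.yaml` through LightningCLI rather than constructed by hand.

```python
from mae import MAE

# Mirrors the defaults in config/mae.yaml: a ViT-Base encoder (width 768, depth 12),
# a decoder at half the encoder width (384) with depth 6, and 25% of patches kept
# by the encoder (i.e. 75% masked), as in the paper.
model = MAE(
    image_size=(224, 224),
    patch_size=16,
    keep=0.25,        # fraction of patch tokens passed to the encoder
    enc_width=768,
    enc_depth=12,
    dec_width=0.5,    # a float is interpreted as a fraction of enc_width -> 384
    dec_depth=6,
)
```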
77 |
78 | ## Dependencies
79 |
80 | - Configuration and training are handled completely by [pytorch-lightning](https://pytorchlightning.ai).
81 | - The MAE model uses the VisionTransformer from [timm](https://github.com/rwightman/pytorch-image-models).
82 | - Interface to FGVC datasets through [fgvcdata](https://github.com/catalys1/fgvc-data-pytorch).
83 | - Configurable environment variables through [python-dotenv](https://pypi.org/project/python-dotenv/).
84 |
85 | ## Results
86 |
87 | Image reconstructions of CUB validation set images after training with the following command:
88 | ```bash
89 | python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --data.init_args.batch_size 256 --data.init_args.num_workers 12
90 | ```
91 |
92 | ![Bird Reconstructions](samples/bird-samples.png)
93 |
94 | Image reconstructions of Cars validation set images after training with the following command:
95 | ```bash
96 | python train.py fit -c config/mae.yaml -c config/data/cars_mae.yaml --data.init_args.batch_size 256 --data.init_args.num_workers 16
97 | ```
98 |
99 | ![Cars Reconstructions](samples/car-samples.png)
100 |
101 | ### Hyperparameters
102 |
103 | | Param | Setting |
104 | | -- | -- |
105 | | GPUs | 1xA100 |
106 | | Batch size | 256 |
107 | | Learning rate | 1.5e-4 |
108 | | LR schedule | Cosine decay |
109 | | Warmup | 10% of steps |
110 | | Training steps | 78,000 |
111 |
112 |
113 | ### Training
114 |
115 | Training and validation loss curves for CUB.
116 |
117 | ![CUB training curves](samples/bird-training-curves.png)
118 |
119 | Validation image reconstructions over the course of training.
120 |
121 | ![CUB training progress](samples/birds-training.gif)
122 |
--------------------------------------------------------------------------------
/config/data/cars_mae.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.StanfordCarsDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/StanfordCars
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
--------------------------------------------------------------------------------
/config/data/cub_mae.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.CubDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/CUB_200_2011/
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
8 |
--------------------------------------------------------------------------------
/config/data/stl10_linear_probe.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.STL10LinearProbeDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/STL10
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
8 |
--------------------------------------------------------------------------------
/config/data/stl10_mae.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.STL10PretrainDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/STL10
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
8 |
--------------------------------------------------------------------------------
/config/linear_probe.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 123
2 |
3 | trainer:
4 |   devices: 1
5 |   strategy: auto
6 |   max_epochs: 100
7 |   precision: 16-mixed
8 |   callbacks:
9 |     - class_path:
pytorch_lightning.callbacks.ModelCheckpoint 10 | init_args: 11 | filename: latest 12 | every_n_epochs: 1 13 | save_on_train_epoch_end: True 14 | - class_path: pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor 15 | init_args: 16 | logging_interval: step 17 | 18 | model: 19 | class_path: mae.MAE_linear_probe 20 | init_args: 21 | ckpt_path: last.ckpt 22 | 23 | 24 | -------------------------------------------------------------------------------- /config/mae.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 123 2 | trainer: 3 | devices: 1 4 | strategy: auto 5 | max_epochs: 200 6 | #limit_val_batches: 0.0 7 | precision: 16-mixed 8 | callbacks: 9 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 10 | init_args: 11 | filename: "{epoch}" 12 | save_last: True 13 | every_n_epochs: 10 14 | save_top_k: -1 15 | save_on_train_epoch_end: True 16 | - class_path: pytorch_lightning.callbacks.ModelCheckpoint 17 | init_args: 18 | filename: latest 19 | every_n_epochs: 1 20 | save_on_train_epoch_end: True 21 | - class_path: pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor 22 | init_args: 23 | logging_interval: step 24 | model: 25 | class_path: mae.MAE 26 | init_args: 27 | image_size: 28 | - 224 29 | - 224 30 | patch_size: 16 31 | keep: 0.25 32 | enc_width: 768 33 | dec_width: 0.5 34 | enc_depth: 12 35 | dec_depth: 6 36 | lr: 0.00015 37 | save_imgs_every: 1 38 | num_save_imgs: 36 39 | -------------------------------------------------------------------------------- /datamodule.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import fgvcdata 4 | from pytorch_lightning import LightningDataModule 5 | from torch.utils.data import DataLoader, random_split 6 | from torchvision.datasets import CIFAR10, CIFAR100, STL10 7 | from torchvision.transforms import transforms 8 | 9 | 10 | __all__ = [ 11 | 'CubDataModule', 12 | 'DogsDataModule', 13 | 'StanfordCarsDataModule', 14 | 'AircraftDataModule', 15 | 'Cifar10DataModule', 16 | 'Cifar100DataModule', 17 | 'STL10DataModule' 18 | ] 19 | 20 | 21 | to_tensor = transforms.Compose([ 22 | transforms.ToTensor(), 23 | transforms.Normalize((0.5,)*3, (0.5,)*3) 24 | ]) 25 | 26 | 27 | class DataRepeater(object): 28 | def __init__(self, dataset, size=None): 29 | self.data = dataset 30 | self.size = size 31 | if self.size is None: 32 | self.size = len(dataset) 33 | 34 | def __len__(self): 35 | return self.size 36 | 37 | def __getitem__(self, idx): 38 | i = idx % len(self.data) 39 | return self.data[i] 40 | 41 | 42 | class _BaseDataModule(LightningDataModule): 43 | def __init__( 44 | self, 45 | data_dir: str, 46 | batch_size: int = 64, 47 | num_workers: int = 4, 48 | pin_memory: bool = True, 49 | size: int = 224, 50 | augment: bool = True, 51 | num_samples: Optional[int] = None, 52 | ): 53 | super().__init__() 54 | 55 | self.augment = augment 56 | self.data_dir = data_dir 57 | self.batch_size = batch_size 58 | self.num_workers = num_workers 59 | self.pin_memory = pin_memory 60 | if isinstance(size, int): 61 | self.size = (size, size) 62 | else: 63 | self.size = size 64 | self.num_samples = num_samples 65 | 66 | def setup(self, stage=None): 67 | pass 68 | 69 | def train_dataloader(self): 70 | return DataLoader( 71 | dataset = self.data_train, 72 | batch_size = self.batch_size, 73 | num_workers = self.num_workers, 74 | pin_memory = self.pin_memory, 75 | shuffle = True, 76 | drop_last = True 77 | ) 78 | 79 | def val_dataloader(self): 
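        # Deterministic validation loader: shuffle=False and drop_last=False, so every
        # validation sample is seen exactly once per epoch.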
80 | return DataLoader( 81 | dataset = self.data_val, 82 | batch_size = self.batch_size, 83 | num_workers = self.num_workers, 84 | pin_memory = self.pin_memory, 85 | shuffle = False, 86 | drop_last = False 87 | ) 88 | 89 | 90 | class _FGVCDataModule(_BaseDataModule): 91 | def transforms(self, crop_scale=(0.2, 1), val=False): 92 | if not val: 93 | tform = transforms.Compose([ 94 | transforms.RandomResizedCrop(self.size, scale=crop_scale), 95 | transforms.RandomHorizontalFlip(0.5), 96 | transforms.ColorJitter(0.25, 0.25, 0.25), 97 | to_tensor, 98 | ]) 99 | else: 100 | tform = transforms.Compose([ 101 | transforms.Resize([int(round(8 * x / 7)) for x in self.size]), 102 | transforms.CenterCrop(self.size), 103 | to_tensor, 104 | ]) 105 | return tform 106 | 107 | def setup(self, stage=None): 108 | self.data_train = self.dataclass( 109 | root = f'{self.data_dir}/train', 110 | transform = self.transforms(val=not self.augment) 111 | ) 112 | if self.num_samples is not None: 113 | self.data_train = DataRepeater(self.data_train, self.num_samples) 114 | 115 | self.data_val = self.dataclass( 116 | root = f'{self.data_dir}/val', 117 | transform = self.transforms(val=True) 118 | ) 119 | 120 | 121 | class CubDataModule(_FGVCDataModule): 122 | dataclass = fgvcdata.CUB 123 | num_class = 200 124 | 125 | class DogsDataModule(_FGVCDataModule): 126 | dataclass = fgvcdata.StanfordDogs 127 | num_class = 120 128 | 129 | 130 | class StanfordCarsDataModule(_FGVCDataModule): 131 | dataclass = fgvcdata.StanfordCars 132 | num_class = 196 133 | 134 | def transforms(self, crop_scale=(0.25, 1), val=False): 135 | return super().transforms(crop_scale, val) 136 | 137 | 138 | class AircraftDataModule(_FGVCDataModule): 139 | dataclass = fgvcdata.Aircraft 140 | num_class = 100 141 | 142 | def transforms(self, crop_scale=(0.25, 1), val=False): 143 | return super().transforms(crop_scale, val) 144 | 145 | 146 | class _CifarDataModule(_BaseDataModule): 147 | def prepare_data(self): 148 | self.dataclass(self.data_dir, download=True) 149 | 150 | def transforms(self, val=False): 151 | if not val: 152 | tform = transforms.Compose([ 153 | transforms.Resize(self.size), 154 | transforms.Pad(self.size[0] // 8, padding_mode='reflect'), 155 | transforms.RandomAffine((-10, 10), (0, 1/8), (1, 1.2)), 156 | transforms.CenterCrop(self.size), 157 | transforms.RandomHorizontalFlip(0.5), 158 | to_tensor(self.normalize), 159 | ]) 160 | else: 161 | tform = transforms.Compose([ 162 | transforms.Resize(self.size), 163 | to_tensor(self.normalize), 164 | ]) 165 | return tform 166 | 167 | def setup(self, stage=None): 168 | self.data_train = self.dataclass( 169 | root = self.data_dir, 170 | train = True, 171 | transform = self.transforms(not self.augment) 172 | ) 173 | self.data_val = self.dataclass( 174 | root = self.data_dir, 175 | train = False, 176 | transform = self.transforms(True) 177 | ) 178 | 179 | 180 | class Cifar10DataModule(_CifarDataModule): 181 | num_class = 10 182 | dataclass = CIFAR10 183 | 184 | 185 | class Cifar100DataModule(_CifarDataModule): 186 | num_class = 100 187 | dataclass = CIFAR100 188 | 189 | 190 | class STL10PretrainDataModule(_BaseDataModule): 191 | num_classes = 10 192 | def prepare_data(self): 193 | pass 194 | 195 | def transforms(self, val=False): 196 | if not val: 197 | tform = transforms.Compose([ 198 | transforms.Resize(self.size), 199 | transforms.Pad(self.size[0] // 8, padding_mode='reflect'), 200 | transforms.RandomAffine((-10, 10), (0, 1/8), (1, 1.2)), 201 | transforms.CenterCrop(self.size), 202 | 
                transforms.RandomHorizontalFlip(0.5),
203 |                 transforms.ToTensor(),
204 |                 transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713))
205 |             ])
206 |         else:
207 |             tform = transforms.Compose([
208 |                 transforms.Resize(self.size),
209 |                 transforms.ToTensor(),
210 |                 transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713))
211 |             ])
212 |         return tform
213 |
214 |     def setup(self, stage=None):
215 |         self.data_train = STL10(
216 |             root = self.data_dir,
217 |             split = 'unlabeled',
218 |             transform = self.transforms(not self.augment)
219 |         )
220 |         self.data_val = STL10(
221 |             root = self.data_dir,
222 |             split = 'train',
223 |             transform = self.transforms(True)
224 |         )
225 |
226 | class STL10LinearProbeDataModule(_BaseDataModule):
227 |     num_classes = 10
228 |     def prepare_data(self):
229 |         pass
230 |
231 |     def transforms(self):
232 |         tform = transforms.Compose([
233 |             transforms.Resize(self.size),
234 |             transforms.ToTensor(),
235 |             transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713))
236 |         ])
237 |         return tform
238 |
239 |     def setup(self, stage=None):
240 |         self.data_train = STL10(
241 |             root = self.data_dir,
242 |             split = 'train',
243 |             transform = self.transforms()
244 |         )
245 |
246 |         data_val_test = STL10(
247 |             root = self.data_dir,
248 |             split = 'test',
249 |             transform = self.transforms()
250 |         )
251 |         test_size = int(0.7 * len(data_val_test))
252 |         val_size = len(data_val_test) - test_size
253 |         self.data_val, self.data_test = random_split(data_val_test, [val_size, test_size])
254 |
255 |     def test_dataloader(self):
256 |         return DataLoader(
257 |             dataset = self.data_test,
258 |             batch_size = self.batch_size,
259 |             num_workers = self.num_workers,
260 |             pin_memory = self.pin_memory,
261 |             shuffle = False,
262 |             drop_last = True
263 |         )
264 |
265 |
--------------------------------------------------------------------------------
/lightning_readout.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from pytorch_lightning import Trainer
4 |
5 | from datamodule import STL10LinearProbeDataModule
6 | from mae import MAE_linear_probe
7 |
8 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
9 | # os.environ['TORCH_USE_CUDA_DSA'] = "1"
10 |
11 | # Build the frozen-encoder linear-probe model from a pretrained MAE checkpoint
12 | model = MAE_linear_probe(ckpt_path='last.ckpt')
13 |
14 | # STL-10 datamodule with the linear-probe splits
15 | stl10 = STL10LinearProbeDataModule(data_dir='/scratch-shared/matt1/data', batch_size=64, num_workers=17, pin_memory=True, size=224, augment=True, num_samples=None)
16 | trainer = Trainer()
17 | trainer.fit(model, datamodule=stl10)
--------------------------------------------------------------------------------
/linear_probe.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.cli import LightningCLI
2 | from dotenv import load_dotenv
3 | import torch
4 |
5 | if __name__ == '__main__':
6 |     load_dotenv('.env')
7 |     torch.set_float32_matmul_precision('medium')
8 |     cli = LightningCLI(parser_kwargs={'parser_mode': 'omegaconf'}, run=False)
9 |     cli.trainer.fit(cli.model, cli.datamodule)
10 |     cli.trainer.test(cli.model, cli.datamodule)
11 |
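# Example invocation (from the README's linear-probing section); the checkpoint path
# below is a placeholder for an MAE checkpoint produced by train.py:
#
#   python linear_probe.py -c config/linear_probe.yaml -c config/data/stl10_linear_probe.yaml \
#       --model.init_args.ckpt_path <path/to/mae/checkpoint.ckpt>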
-------------------------------------------------------------------------------- /mae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import Any, Tuple, Union 4 | 5 | import pytorch_lightning as pl 6 | from pytorch_lightning.utilities.types import STEP_OUTPUT 7 | import timm 8 | import torch 9 | from torch import distributed 10 | import torchvision 11 | import torch.nn as nn 12 | 13 | 14 | ############################################################################## 15 | # Masked Autoencoder (MAE) 16 | ############################################################################## 17 | 18 | # Encoder: 19 | # Use the timm VisionTransformer, but all we need is the "blocks" and the final "norm" submodules 20 | # Add a fixed positional encoding at the beginning (sin-cos, original transformer style) 21 | # Add a linear projection on the output to match the decoder dimension 22 | 23 | # Decoder: 24 | # Use the timm VisionTransformer, as in the encoder 25 | # Position embeddings are added to the decoder input (sin-cos); note that they are different than 26 | # the encoder's, because the dimension is different 27 | # There is a shared, learnable [MASK] token that is used at every masked position 28 | # A classification token can be included, but it should work similarly without (using average pooling, 29 | # according to the paper); we don't include a classification token here 30 | 31 | # The loss is MSE computed only on the masked patches, as in the paper 32 | 33 | 34 | class ViTBlocks(torch.nn.Module): 35 | '''The main processing blocks of ViT. Excludes things like patch embedding and classificaton 36 | layer. 37 | 38 | Args: 39 | width: size of the feature dimension. 40 | depth: number of blocks in the network. 41 | end_norm: whether to end with LayerNorm or not. 42 | ''' 43 | def __init__( 44 | self, 45 | width: int = 768, 46 | depth: int = 12, 47 | end_norm: bool = True, 48 | ): 49 | super().__init__() 50 | 51 | # transformer blocks from ViT 52 | ViT = timm.models.vision_transformer.VisionTransformer 53 | vit = ViT(embed_dim=width, depth=depth) 54 | self.layers = vit.blocks 55 | if end_norm: 56 | # final normalization 57 | self.layers.add_module('norm', vit.norm) 58 | 59 | def forward(self, x: torch.Tensor): 60 | return self.layers(x) 61 | 62 | 63 | class MaskedAutoencoder(torch.nn.Module): 64 | '''Masked Autoencoder for visual representation learning. 65 | 66 | Args: 67 | image_size: (height, width) of the input images. 68 | patch_size: side length of a patch. 69 | keep: percentage of tokens to process in the encoder. (1 - keep) is the percentage of masked tokens. 70 | enc_width: width (feature dimension) of the encoder. 71 | dec_width: width (feature dimension) of the decoder. If a float, it is interpreted as a percentage 72 | of enc_width. 
73 | enc_depth: depth (number of blocks) of the encoder 74 | dec_depth: depth (number of blocks) of the decoder 75 | ''' 76 | def __init__( 77 | self, 78 | image_size: Tuple[int, int] = (224, 224), 79 | patch_size: int = 16, 80 | keep: float = 0.25, 81 | enc_width: int = 768, 82 | dec_width: Union[int, float] = 0.25, 83 | enc_depth: int = 12, 84 | dec_depth: int = 4, 85 | ): 86 | super().__init__() 87 | 88 | assert image_size[0] % patch_size == 0 and image_size[1] % patch_size == 0 89 | 90 | self.image_size = image_size 91 | self.patch_size = patch_size 92 | self.keep = keep 93 | self.n = (image_size[0] * image_size[1]) // patch_size**2 # number of patches 94 | 95 | if isinstance(dec_width, float) and dec_width > 0 and dec_width < 1: 96 | dec_width = int(dec_width * enc_width) 97 | else: 98 | dec_width = int(dec_width) 99 | self.enc_width = enc_width 100 | self.dec_width = dec_width 101 | self.enc_depth = enc_depth 102 | self.dec_depth = dec_depth 103 | 104 | # linear patch embedding 105 | self.embed_conv = torch.nn.Conv2d(3, enc_width, patch_size, patch_size) 106 | 107 | # mask token and position encoding 108 | self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, dec_width, requires_grad=True)) 109 | self.register_buffer('pos_encoder', self.pos_encoding(self.n, enc_width).requires_grad_(False)) 110 | self.register_buffer('pos_decoder', self.pos_encoding(self.n, dec_width).requires_grad_(False)) 111 | 112 | # encoder 113 | self.encoder = ViTBlocks(width=enc_width, depth=enc_depth) 114 | 115 | # linear projection from enc_width to dec_width 116 | self.project = torch.nn.Linear(enc_width, dec_width) 117 | 118 | # decoder 119 | self.decoder = ViTBlocks(width=dec_width, depth=dec_depth, end_norm=False) 120 | 121 | # linear projection to pixel dimensions 122 | self.pixel_project = torch.nn.Linear(dec_width, 3 * patch_size**2) 123 | 124 | self.freeze_mask = False # set to True to reuse the same mask multiple times 125 | 126 | @property 127 | def freeze_mask(self): 128 | '''When True, the previously computed mask will be used on new inputs, instead of creating a new one.''' 129 | return self._freeze_mask 130 | 131 | @freeze_mask.setter 132 | def freeze_mask(self, val: bool): 133 | self._freeze_mask = val 134 | 135 | @staticmethod 136 | def pos_encoding(n: int, d: int, k: int=10000): 137 | '''Create sine-cosine positional embeddings. 138 | 139 | Args: 140 | n: the number of embedding vectors, corresponding to the number of tokens (patches) in the image. 141 | d: the dimension of the embeddings 142 | k: value that determines the maximum frequency (10,000 by default) 143 | 144 | Returns: 145 | (n, d) tensor of position encoding vectors 146 | ''' 147 | x = torch.meshgrid( 148 | torch.arange(n, dtype=torch.float32), 149 | torch.arange(d, dtype=torch.float32), 150 | indexing='ij' 151 | ) 152 | pos = torch.zeros_like(x[0]) 153 | pos[:, ::2] = x[0][:, ::2].div(torch.pow(k, x[1][:, ::2].div(d // 2))).sin_() 154 | pos[:, 1::2] = x[0][:,1::2].div(torch.pow(k, x[1][:,1::2].div(d // 2))).cos_() 155 | return pos 156 | 157 | @staticmethod 158 | def generate_mask_index(bs: int, n_tok: int, device: str='cpu'): 159 | '''Create a randomly permuted token-index tensor for determining which tokens to mask. 
160 | 161 | Args: 162 | bs: batch size 163 | n_tok: number of tokens per image 164 | device: the device where the tensors should be created 165 | 166 | Returns: 167 | (bs, 1) tensor of batch indices [0, 1, ..., bs - 1]^T 168 | (bs, n_tok) tensor of token indices, randomly permuted 169 | ''' 170 | idx = torch.rand(bs, n_tok, device=device).argsort(dim=1) 171 | return idx 172 | 173 | @staticmethod 174 | def select_tokens(x: torch.Tensor, idx: torch.Tensor): 175 | '''Return the tokens from `x` corresponding to the indices in `idx`. 176 | ''' 177 | idx = idx.unsqueeze(-1).expand(-1, -1, x.shape[-1]) 178 | return x.gather(dim=1, index=idx) 179 | 180 | def image_as_tokens(self, x: torch.Tensor): 181 | '''Reshape an image of shape (b, c, h, w) to a set of vectorized patches 182 | of shape (b, h*w/p^2, c*p^2). In other words, the set of non-overlapping 183 | patches of size (3, p, p) in the image are turned into vectors (tokens); 184 | dimension 1 of the output indexes each patch. 185 | ''' 186 | b, c, h, w = x.shape 187 | p = self.patch_size 188 | x = x.reshape(b, c, h // p, p, w // p, p).permute(0, 2, 4, 1, 3, 5) 189 | x = x.reshape(b, (h * w) // p**2, c * p * p) 190 | return x 191 | 192 | def tokens_as_image(self, x: torch.Tensor): 193 | '''Reshape a set of token vectors into an image. This is the reverse operation 194 | of `image_as_tokens`. 195 | ''' 196 | b = x.shape[0] 197 | im, p = self.image_size, self.patch_size 198 | hh, ww = im[0] // p, im[1] // p 199 | x = x.reshape(b, hh, ww, 3, p, p).permute(0, 3, 1, 4, 2, 5) 200 | x = x.reshape(b, 3, p * hh, p * ww) 201 | return x 202 | 203 | def masked_image(self, x: torch.Tensor): 204 | '''Return a copy of the image batch, with the masked patches set to 0. Used 205 | for visualization. 206 | ''' 207 | x = self.image_as_tokens(x).clone() 208 | bidx = torch.arange(x.shape[0], device=x.device)[:, None] 209 | x[bidx, self.idx[:, int(self.keep * self.n):]] = 0 210 | return self.tokens_as_image(x) 211 | 212 | def embed(self, x: torch.Tensor): 213 | return self.embed_conv(x).flatten(2).transpose(1, 2) 214 | 215 | def mask_input(self, x: torch.Tensor): 216 | '''Mask the image patches uniformly at random, as described in the paper: the patch tokens are 217 | randomly permuted (per image), and the first N are returned, where N corresponds to percentage 218 | of patches kept (not masked). 219 | 220 | Returns the masked (truncated) tokens. The mask indices are saved as `self.bidx` and `self.idx`. 221 | ''' 222 | # create a new mask if self.freeze_mask is False, or if no mask has been created yet 223 | if not hasattr(self, 'idx') or not self.freeze_mask: 224 | self.idx = self.generate_mask_index(x.shape[0], x.shape[1], x.device) 225 | 226 | k = int(self.keep * self.n) 227 | x = self.select_tokens(x, self.idx[:, :k]) 228 | return x 229 | 230 | def forward_features(self, x: torch.Tensor): 231 | x = self.embed(x) 232 | x = x + self.pos_encoder 233 | x = self.mask_input(x) 234 | x = self.encoder(x) 235 | 236 | return x 237 | 238 | def forward(self, x: torch.Tensor): 239 | x = self.forward_features(x) 240 | x = self.project(x) 241 | 242 | k = self.n - x.shape[1] # number of masked tokens 243 | mask_toks = self.mask_token.expand(x.shape[0], k, -1) 244 | x = torch.cat([x, mask_toks], 1) 245 | x = self.select_tokens(x, self.idx.argsort(1)) 246 | x = x + self.pos_decoder 247 | x = self.decoder(x) 248 | x = self.pixel_project(x) 249 | 250 | return x 251 | 252 | 253 | class MAE(pl.LightningModule): 254 | '''Masked Autoencoder LightningModule. 
255 | 256 | Args: 257 | image_size: (height, width) of the input images. 258 | patch_size: size of the image patches. 259 | keep: percentage of tokens to keep. (1 - keep) is the percentage of masked tokens. 260 | enc_width: width of the encoder features. 261 | dec_width: width of the decoder features. 262 | enc_depth: depth of the encoder. 263 | dec_depth: depth of the decoder. 264 | lr: learning rate 265 | save_imgs_every: save some reconstructions every nth epoch. 266 | num_save_immgs: number of reconstructed images to save. 267 | ''' 268 | def __init__( 269 | self, 270 | image_size: Tuple[int, int] = (224, 224), 271 | patch_size: int = 16, 272 | keep: float = 0.25, 273 | enc_width: int = 768, 274 | dec_width: Union[int, float] = 0.5, 275 | enc_depth: int = 12, 276 | dec_depth: int = 6, 277 | lr: float = 1.5e-4, 278 | base_batch_size: int = 256, 279 | normalize_for_loss: bool = False, 280 | save_imgs_every: int = 1, 281 | num_save_imgs: int = 36, 282 | ): 283 | super().__init__() 284 | 285 | self.mae = MaskedAutoencoder( 286 | image_size=image_size, 287 | patch_size=patch_size, 288 | keep=keep, 289 | enc_width=enc_width, 290 | enc_depth=enc_depth, 291 | dec_width=dec_width, 292 | dec_depth=dec_depth, 293 | ) 294 | 295 | self.keep = keep 296 | self.n = self.mae.n 297 | self.lr = lr 298 | self.base_batch_size = base_batch_size 299 | self.normalize_for_loss = normalize_for_loss 300 | self.save_imgs_every = save_imgs_every 301 | self.num_save_imgs = num_save_imgs 302 | 303 | self.saved_imgs_list = [] 304 | 305 | def on_train_batch_end(self, *args, **kwargs): 306 | if self.trainer.global_step == 2 and self.trainer.is_global_zero: 307 | # print GPU memory usage once at beginning of training 308 | avail, total = torch.cuda.mem_get_info() 309 | mem_used = 100 * (1 - (avail / total)) 310 | gb = 1024**3 311 | self.print(f'GPU memory used: {(total-avail)/gb:.2f} of {total/gb:.2f} GB ({mem_used:.2f}%)') 312 | if self.trainer.num_nodes > 1 or self.trainer.num_devices > 1: 313 | distributed.barrier() 314 | 315 | def training_step(self, batch: Any, batch_idx: int, *args, **kwargs): 316 | x, _ = batch 317 | pred = self.mae(x) 318 | loss = self.masked_mse_loss(x, pred) 319 | self.log('train/loss', loss, prog_bar=True) 320 | return {'loss': loss} 321 | 322 | def validation_step(self, batch: Any, batch_idx: int, *args, **kwargs): 323 | x, _ = batch 324 | pred = self.mae(x) 325 | loss = self.masked_mse_loss(x, pred) 326 | self.log('val/loss', loss, prog_bar=True, sync_dist=True) 327 | 328 | if self.save_imgs_every: 329 | p = int(self.save_imgs_every) 330 | if self.trainer.current_epoch % p == 0: 331 | nb = self.trainer.num_val_batches[0] 332 | ns = self.num_save_imgs 333 | per_batch = math.ceil(ns / nb) 334 | self.saved_imgs_list.append(pred[:per_batch]) 335 | 336 | return {'loss': loss} 337 | 338 | def on_validation_epoch_end(self): 339 | if self.save_imgs_every: 340 | if self.trainer.is_global_zero: 341 | imgs = torch.cat(self.saved_imgs_list, 0) 342 | self.saved_imgs_list.clear() 343 | self.save_imgs(imgs[:self.num_save_imgs]) 344 | if self.trainer.num_nodes > 1 or self.trainer.num_devices > 1: 345 | distributed.barrier() 346 | 347 | # @pl.utilities.rank_zero_only 348 | def save_imgs(self, imgs: torch.Tensor): 349 | with torch.no_grad(): 350 | r = int(imgs.shape[0]**0.5) 351 | imgs = self.mae.tokens_as_image(imgs.detach()) 352 | imgs = imgs.add_(1).mul_(127.5).clamp_(0, 255).byte() 353 | imgs = torchvision.utils.make_grid(imgs, r).cpu() 354 | epoch = self.trainer.current_epoch 355 | dir = 
os.path.join(self.trainer.log_dir, 'imgs') 356 | os.makedirs(dir, exist_ok=True) 357 | torchvision.io.write_png(imgs, os.path.join(dir, f'epoch_{epoch}_imgs.png')) 358 | 359 | def configure_optimizers(self): 360 | total_steps = self.trainer.estimated_stepping_batches 361 | devices, nodes = self.trainer.num_devices, self.trainer.num_nodes 362 | batch_size = self.trainer.train_dataloader.batch_size 363 | lr_scale = devices * nodes * batch_size / self.base_batch_size 364 | lr = self.lr * lr_scale 365 | 366 | optim = torch.optim.AdamW(self.parameters(), lr=lr, betas=(.9, .95), weight_decay=0.05) 367 | schedule = torch.optim.lr_scheduler.OneCycleLR( 368 | optim, 369 | max_lr=lr, 370 | total_steps=total_steps, 371 | pct_start=0.1, 372 | cycle_momentum=False, 373 | ) 374 | return { 375 | 'optimizer': optim, 376 | 'lr_scheduler': {'scheduler': schedule, 'interval': 'step'} 377 | } 378 | 379 | def masked_mse_loss(self, img: torch.Tensor, recon: torch.Tensor): 380 | # turn the image into patch-vectors for comparison to model output 381 | x = self.mae.image_as_tokens(img) 382 | if self.normalize_for_loss: 383 | std, mean = torch.std_mean(x, dim=-1, keepdim=True) 384 | x = x.sub(mean).div(std + 1e-5) 385 | # only compute on the mask token outputs, which is everything after the first (n * keep) 386 | idx = self.mae.idx[:, int(self.keep * self.n):] 387 | x = self.mae.select_tokens(x, idx) 388 | y = self.mae.select_tokens(recon, idx) 389 | return torch.nn.functional.mse_loss(x, y) 390 | 391 | class MAE_linear_probe(pl.LightningModule): 392 | '''Frozen MAE encoder with trainable linear readout to class labels 393 | https://lightning.ai/docs/pytorch/stable/advanced/transfer_learning.html 394 | 395 | ''' 396 | def __init__( 397 | self, 398 | ckpt_path: str, 399 | ): 400 | super().__init__() 401 | mae_module = MAE() 402 | mae_module.load_state_dict(torch.load(ckpt_path)['state_dict']) 403 | self.mae = mae_module.mae 404 | 405 | self.feature_extractor = self.mae.encoder 406 | 407 | self.classifier = torch.nn.Linear(self.mae.enc_width, 10) 408 | self.classifier.weight.data.normal_(mean=0.0, std=0.01) 409 | self.classifier.bias.data.zero_() 410 | 411 | def forward(self, x): 412 | x = self.mae.embed(x) 413 | x = x + self.mae.pos_encoder 414 | self.feature_extractor.eval() 415 | with torch.no_grad(): 416 | x = self.feature_extractor(x) 417 | x = x.mean(dim=1) # average pool over the patch dimension 418 | x = self.classifier(x) 419 | return x 420 | 421 | def configure_optimizers(self): 422 | optimizer = torch.optim.AdamW(self.parameters(), lr=5e-4) 423 | return optimizer 424 | 425 | def training_step(self, batch: Any, batch_idx: int, *args, **kwargs): 426 | x, labels = batch 427 | pred = self.forward(x) 428 | loss = self.loss_fn(pred, labels) 429 | self.log('train/loss', loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True) 430 | return {'loss': loss} 431 | 432 | def validation_step(self, batch: Any, batch_idx: int, *args, **kwargs): 433 | x, labels = batch 434 | pred = self.forward(x) 435 | loss = self.loss_fn(pred, labels) 436 | _, predicted = torch.max(pred, 1) 437 | correct = (predicted == labels).sum().item() 438 | self.log('val/loss', loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True) 439 | self.log('val/acc', correct / len(labels), prog_bar=True, on_step=False, sync_dist=True, on_epoch=True) 440 | 441 | def test_step(self, batch: Any, batch_idx: int, *args, **kwargs): 442 | x, labels = batch 443 | pred = self.forward(x) 444 | loss = self.loss_fn(pred, labels) 445 | _, predicted 
= torch.max(pred, 1) 446 | correct = (predicted == labels).sum().item() 447 | self.log('test/loss', loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True) 448 | self.log('test/acc', correct / len(labels), prog_bar=True, on_step=False, sync_dist=True, on_epoch=True) 449 | 450 | def loss_fn(self, x, y): 451 | fn = torch.nn.CrossEntropyLoss() 452 | return fn(x, y) 453 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.0.0 2 | torchvision>=0.15.0 3 | pytorch-lightning>=2.0.2 4 | jsonargparse[signatures]>=4.21.1 5 | omegaconf>=2.1.1 6 | fgvcdata>=0.1.0 7 | timm>=0.9.2 8 | python-dotenv>=0.17.3 9 | pillow>=9.4.0 -------------------------------------------------------------------------------- /samples/bird-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/bird-samples.png -------------------------------------------------------------------------------- /samples/bird-training-curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/bird-training-curves.png -------------------------------------------------------------------------------- /samples/birds-training.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/birds-training.gif -------------------------------------------------------------------------------- /samples/car-samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/car-samples.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning.cli import LightningCLI 2 | import torch 3 | from dotenv import load_dotenv 4 | 5 | if __name__ == '__main__': 6 | 7 | load_dotenv('.env') 8 | torch.set_float32_matmul_precision('medium') 9 | cli = LightningCLI(parser_kwargs={'parser_mode': 'omegaconf'}) 10 | -------------------------------------------------------------------------------- /visualize.py: -------------------------------------------------------------------------------- 1 | import dotenv 2 | from omegaconf import OmegaConf 3 | from pathlib import Path 4 | from pytorch_lightning import seed_everything 5 | import random 6 | import torch 7 | import torchvision 8 | 9 | import datamodule 10 | import mae 11 | 12 | if __name__ == '__main__': 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('logdir', help='Path to logging directory for a trained model') 17 | parser.add_argument('-n', '--num_imgs', type=int, default=8, help='Number of images to display') 18 | parser.add_argument('-s', '--seed', type=int, default=987654321, help='Random seed') 19 | parser.add_argument('-d', '--device', type=str, default='cpu', help='Device type: "cpu" or "cuda"') 20 | 21 | args = parser.parse_args() 22 | 23 | dotenv.load_dotenv('.env') 24 | seed_everything(args.seed) 25 | device = args.device 26 | 27 | root = Path(args.logdir) 
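    # The log directory contains the config.yaml that LightningCLI saved during training;
    # it is used below to recover the datamodule class and the dataset path.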
28 | config = OmegaConf.load(root.joinpath('config.yaml')) 29 | 30 | ### data setup 31 | print('Data... ', end='') 32 | dm_class = config.data.class_path.rsplit('.', 1)[-1] 33 | data_dir = config.data.init_args.data_dir 34 | dm = getattr(datamodule, dm_class)(data_dir) 35 | dm.setup() 36 | data = dm.data_val 37 | print(data) 38 | 39 | ### model setup 40 | print('Model... ', end='') 41 | ckpt_path = root.joinpath('checkpoints', 'last.ckpt') 42 | model = mae.MAE.load_from_checkpoint(ckpt_path, map_location='cpu') 43 | model = model.mae.to(device) 44 | print(model.__class__.__name__) 45 | 46 | ### get model predictions 47 | print('Getting predictions...', end='') 48 | img_indices = random.choices(range(len(data)), k=args.num_imgs) 49 | imgs = torch.stack([data[i][0] for i in img_indices], 0).to(device) 50 | 51 | preds = model.tokens_as_image(model(imgs)) 52 | masked = model.masked_image(imgs) 53 | print('done') 54 | 55 | ### create visualization 56 | viz = torchvision.utils.make_grid( 57 | torch.cat([imgs, masked, preds], 0).clamp(-1, 1), 58 | nrow=args.num_imgs, 59 | normalize=True, 60 | value_range=(-1, 1), 61 | ) 62 | viz = viz.mul_(255).byte() 63 | 64 | torchvision.io.write_png(viz, str(root.joinpath('samples.png'))) --------------------------------------------------------------------------------