├── .gitignore
├── LICENSE.txt
├── README.md
├── config
│   ├── data
│   │   ├── cars_mae.yaml
│   │   ├── cub_mae.yaml
│   │   ├── stl10_linear_probe.yaml
│   │   └── stl10_mae.yaml
│   ├── linear_probe.yaml
│   └── mae.yaml
├── datamodule.py
├── lightning_readout.py
├── linear_probe.py
├── mae.py
├── requirements.txt
├── samples
│   ├── bird-samples.png
│   ├── bird-training-curves.png
│   ├── birds-training.gif
│   └── car-samples.png
├── train.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Logs
2 | *logs*/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 | cover/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | .pybuilder/
79 | target/
80 |
81 | # Jupyter Notebook
82 | .ipynb_checkpoints
83 |
84 | # IPython
85 | profile_default/
86 | ipython_config.py
87 |
88 | # pyenv
89 | # For a library or package, you might want to ignore these files since the code is
90 | # intended to run in multiple environments; otherwise, check them in:
91 | # .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # poetry
101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102 | # This is especially recommended for binary packages to ensure reproducibility, and is more
103 | # commonly ignored for libraries.
104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105 | #poetry.lock
106 |
107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
108 | __pypackages__/
109 |
110 | # Celery stuff
111 | celerybeat-schedule
112 | celerybeat.pid
113 |
114 | # SageMath parsed files
115 | *.sage.py
116 |
117 | # Environments
118 | .env
119 | .venv
120 | env/
121 | venv/
122 | ENV/
123 | env.bak/
124 | venv.bak/
125 |
126 | # Spyder project settings
127 | .spyderproject
128 | .spyproject
129 |
130 | # Rope project settings
131 | .ropeproject
132 |
133 | # mkdocs documentation
134 | /site
135 |
136 | # mypy
137 | .mypy_cache/
138 | .dmypy.json
139 | dmypy.json
140 |
141 | # Pyre type checker
142 | .pyre/
143 |
144 | # pytype static type analyzer
145 | .pytype/
146 |
147 | # Cython debug symbols
148 | cython_debug/
149 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Connor Anderson
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masked Autoencoders in PyTorch
2 |
10 | A simple, unofficial implementation of MAE ([Masked Autoencoders are Scalable Vision Learners](https://arxiv.org/abs/2111.06377)) using [pytorch-lightning](https://www.pytorchlightning.ai/). A PyTorch implementation by the authors can be found [here](https://github.com/facebookresearch/mae).
11 |
12 | Currently implements training on [CUB](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html), [StanfordCars](http://ai.stanford.edu/~jkrause/cars/car_dataset.html), and [STL-10](https://cs.stanford.edu/~acoates/stl10/), but is easily extensible to any other image dataset.
13 |
14 | ## Updates
15 |
16 | ### September 1, 2023
17 |
18 | - Updated for compatibility with PyTorch 2.0 and PyTorch Lightning 2.0. This likely breaks backward compatibility; a release was created for the old version of the code.
19 | - Refactored parts of the training code for conciseness and efficiency.
20 | - Added additional features, including the option to save some validation reconstructions during training. **Note**: saving reconstructions during distributed training is currently unreliable; it can freeze at the end of the validation epoch.
21 | - Retrained CUB and Cars models with new code and a stronger decoder.
22 |
23 | ### February 4, 2022
24 |
25 | - Fixed a bug in the code for generating mask indices. Retrained and updated the reconstruction figures (see below). They aren't quite as pretty now, but they make more sense.
26 |
27 | ## Setup
28 |
29 | ```bash
30 | # Clone the repository
31 | git clone https://github.com/catalys1/mae-pytorch.git
32 | cd mae-pytorch
33 |
34 | # Install required libraries (preferably inside a virtual environment)
35 | pip install -r requirements.txt
36 |
37 | # Set up .env for path to data
38 | echo "DATADIR=/path/to/data" > .env
39 | ```
40 |
41 | ## Usage
42 |
43 | ### MAE training
44 |
45 | Training options are provided through configuration files, handled by [LightningCLI](https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html). See `config/` for examples.
46 |
47 | Train an MAE model on the CUB dataset:
48 | ```bash
49 | python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml
50 | ```
51 |
52 | Using multiple GPUs:
53 | ```bash
54 | python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --trainer.devices 8
55 | ```
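
Any setting in the config files can also be overridden from the command line. For example (illustrative; `keep` is the fraction of visible patches defined in `config/mae.yaml`):
```bash
python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --model.init_args.keep 0.5
```
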
56 | ### Linear probing (currently only supported on STL-10)
57 | Evaluate the learned representations using a linear probe. First, pretrain the model on the 100,000 samples of the 'unlabeled' split.
58 | ```bash
59 | python train.py fit -c config/mae.yaml -c config/data/stl10_mae.yaml
60 | ```
61 | Now, discard the decoder and attach a linear probe to the frozen encoder. The classifier is trained on the labeled 'train' split, with 30% of the 'test' split held out for validation and the remaining 70% used for the final evaluation (see `STL10LinearProbeDataModule`). To run it, provide the path to the pretrained model checkpoint in the command below.
62 |
63 | ```bash
64 | python linear_probe.py -c config/linear_probe.yaml -c config/data/stl10_linear_probe.yaml --model.init_args.ckpt_path /path/to/checkpoint.ckpt
65 | ```
66 |
67 | This yields 77.96% accuracy on the test data.
68 |
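For reference, `MAE_linear_probe` (defined in `mae.py`) keeps the pretrained encoder frozen and trains only a linear head on mean-pooled patch features; its forward pass is essentially:

```python
def forward(self, x):
    x = self.mae.embed(x) + self.mae.pos_encoder   # patchify and add position encoding
    self.feature_extractor.eval()
    with torch.no_grad():                          # the encoder stays frozen
        x = self.feature_extractor(x)
    x = x.mean(dim=1)                              # average-pool over the patch dimension
    return self.classifier(x)                      # linear readout to the 10 STL-10 classes
```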
69 |
70 | ### Fine-tuning
71 |
72 | Not yet implemented.
73 |
74 | ## Implementation
75 |
76 | The default model uses ViT-Base for the encoder, and a small ViT (`depth=6`, `width=384`) for the decoder. This is smaller than the model used in the paper.
77 |
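The same backbone can be constructed directly from `mae.py` (a minimal sketch mirroring the defaults in `config/mae.yaml`):

```python
from mae import MaskedAutoencoder

model = MaskedAutoencoder(
    image_size=(224, 224),
    patch_size=16,
    keep=0.25,        # the encoder sees 25% of the patches; the rest are masked
    enc_width=768,    # ViT-Base encoder
    enc_depth=12,
    dec_width=0.5,    # interpreted as a fraction of enc_width -> 384
    dec_depth=6,
)
```
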
78 | ## Dependencies
79 |
80 | - Configuration and training are handled completely by [pytorch-lightning](https://pytorchlightning.ai).
81 | - The MAE model uses the VisionTransformer from [timm](https://github.com/rwightman/pytorch-image-models).
82 | - Interface to FGVC datasets through [fgvcdata](https://github.com/catalys1/fgvc-data-pytorch).
83 | - Configurable environment variables through [python-dotenv](https://pypi.org/project/python-dotenv/).
84 |
85 | ## Results
86 |
87 | Image reconstructions of CUB validation set images after training with the following command:
88 | ```bash
89 | python train.py fit -c config/mae.yaml -c config/data/cub_mae.yaml --data.init_args.batch_size 256 --data.init_args.num_workers 12
90 | ```
91 |
92 | ![CUB validation reconstructions](samples/bird-samples.png)
93 |
94 | Image reconstructions of Cars validation set images after training with the following command:
95 | ```bash
96 | python train.py fit -c config/mae.yaml -c config/data/cars_mae.yaml --data.init_args.batch_size 256 --data.init_args.num_workers 16
97 | ```
98 |
99 | ![Cars validation reconstructions](samples/car-samples.png)
100 |
101 | ### Hyperparameters
102 |
103 | | Param | Setting |
104 | | -- | -- |
105 | | GPUs | 1xA100 |
106 | | Batch size | 256 |
107 | | Learning rate | 1.5e-4 |
108 | | LR schedule | Cosine decay |
109 | | Warmup | 10% of steps |
110 | | Training steps | 78,000 |
111 |
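The learning rate above is the base rate for a 256-image batch; `MAE.configure_optimizers` scales it linearly with the effective batch size before building the OneCycle (cosine) schedule with 10% warmup. A minimal sketch of that rule, using the single-GPU setting from the table:

```python
# effective learning rate, as computed in MAE.configure_optimizers
base_lr = 1.5e-4
base_batch_size = 256
devices, nodes, batch_size = 1, 1, 256                           # settings from the table
lr = base_lr * devices * nodes * batch_size / base_batch_size    # -> 1.5e-4 here
```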
112 |
113 | ### Training
114 |
115 | Training and validation loss curves for CUB.
116 |
117 | ![CUB training and validation loss curves](samples/bird-training-curves.png)
118 |
119 | Validation image reconstructions over the course of training.
120 |
121 | ![Validation reconstructions over the course of training](samples/birds-training.gif)
122 |
--------------------------------------------------------------------------------
/config/data/cars_mae.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.StanfordCarsDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/StanfordCars
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
--------------------------------------------------------------------------------
/config/data/cub_mae.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.CubDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/CUB_200_2011/
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
8 |
--------------------------------------------------------------------------------
/config/data/stl10_linear_probe.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.STL10LinearProbeDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/STL10
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
8 |
--------------------------------------------------------------------------------
/config/data/stl10_mae.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   class_path: datamodule.STL10PretrainDataModule
3 |   init_args:
4 |     data_dir: ${oc.env:DATADIR}/STL10
5 |     batch_size: 64
6 |     num_workers: 8
7 |     num_samples: 100000
8 |
--------------------------------------------------------------------------------
/config/linear_probe.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 123
2 |
3 | trainer:
4 |   devices: 1
5 |   strategy: auto
6 |   max_epochs: 100
7 |   precision: 16-mixed
8 |   callbacks:
9 |     - class_path: pytorch_lightning.callbacks.ModelCheckpoint
10 |       init_args:
11 |         filename: latest
12 |         every_n_epochs: 1
13 |         save_on_train_epoch_end: True
14 |     - class_path: pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor
15 |       init_args:
16 |         logging_interval: step
17 |
18 | model:
19 |   class_path: mae.MAE_linear_probe
20 |   init_args:
21 |     ckpt_path: last.ckpt
22 |
23 |
24 |
--------------------------------------------------------------------------------
/config/mae.yaml:
--------------------------------------------------------------------------------
1 | seed_everything: 123
2 | trainer:
3 |   devices: 1
4 |   strategy: auto
5 |   max_epochs: 200
6 |   #limit_val_batches: 0.0
7 |   precision: 16-mixed
8 |   callbacks:
9 |     - class_path: pytorch_lightning.callbacks.ModelCheckpoint
10 |       init_args:
11 |         filename: "{epoch}"
12 |         save_last: True
13 |         every_n_epochs: 10
14 |         save_top_k: -1
15 |         save_on_train_epoch_end: True
16 |     - class_path: pytorch_lightning.callbacks.ModelCheckpoint
17 |       init_args:
18 |         filename: latest
19 |         every_n_epochs: 1
20 |         save_on_train_epoch_end: True
21 |     - class_path: pytorch_lightning.callbacks.lr_monitor.LearningRateMonitor
22 |       init_args:
23 |         logging_interval: step
24 | model:
25 |   class_path: mae.MAE
26 |   init_args:
27 |     image_size:
28 |       - 224
29 |       - 224
30 |     patch_size: 16
31 |     keep: 0.25
32 |     enc_width: 768
33 |     dec_width: 0.5
34 |     enc_depth: 12
35 |     dec_depth: 6
36 |     lr: 0.00015
37 |     save_imgs_every: 1
38 |     num_save_imgs: 36
39 |
--------------------------------------------------------------------------------
/datamodule.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import fgvcdata
4 | from pytorch_lightning import LightningDataModule
5 | from torch.utils.data import DataLoader, random_split
6 | from torchvision.datasets import CIFAR10, CIFAR100, STL10
7 | from torchvision.transforms import transforms
8 |
9 |
10 | __all__ = [
11 |     'CubDataModule',
12 |     'DogsDataModule',
13 |     'StanfordCarsDataModule',
14 |     'AircraftDataModule',
15 |     'Cifar10DataModule',
16 |     'Cifar100DataModule',
17 |     'STL10PretrainDataModule', 'STL10LinearProbeDataModule',
18 | ]
19 |
20 |
21 | to_tensor = transforms.Compose([
22 | transforms.ToTensor(),
23 | transforms.Normalize((0.5,)*3, (0.5,)*3)
24 | ])
25 |
26 |
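# DataRepeater wraps a dataset so that one "epoch" has a fixed length (`size`),
# cycling through the underlying samples as needed. This is what the `num_samples`
# config option feeds into, decoupling epoch length from dataset size.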
27 | class DataRepeater(object):
28 | def __init__(self, dataset, size=None):
29 | self.data = dataset
30 | self.size = size
31 | if self.size is None:
32 | self.size = len(dataset)
33 |
34 | def __len__(self):
35 | return self.size
36 |
37 | def __getitem__(self, idx):
38 | i = idx % len(self.data)
39 | return self.data[i]
40 |
41 |
42 | class _BaseDataModule(LightningDataModule):
43 | def __init__(
44 | self,
45 | data_dir: str,
46 | batch_size: int = 64,
47 | num_workers: int = 4,
48 | pin_memory: bool = True,
49 | size: int = 224,
50 | augment: bool = True,
51 | num_samples: Optional[int] = None,
52 | ):
53 | super().__init__()
54 |
55 | self.augment = augment
56 | self.data_dir = data_dir
57 | self.batch_size = batch_size
58 | self.num_workers = num_workers
59 | self.pin_memory = pin_memory
60 | if isinstance(size, int):
61 | self.size = (size, size)
62 | else:
63 | self.size = size
64 | self.num_samples = num_samples
65 |
66 | def setup(self, stage=None):
67 | pass
68 |
69 | def train_dataloader(self):
70 | return DataLoader(
71 | dataset = self.data_train,
72 | batch_size = self.batch_size,
73 | num_workers = self.num_workers,
74 | pin_memory = self.pin_memory,
75 | shuffle = True,
76 | drop_last = True
77 | )
78 |
79 | def val_dataloader(self):
80 | return DataLoader(
81 | dataset = self.data_val,
82 | batch_size = self.batch_size,
83 | num_workers = self.num_workers,
84 | pin_memory = self.pin_memory,
85 | shuffle = False,
86 | drop_last = False
87 | )
88 |
89 |
90 | class _FGVCDataModule(_BaseDataModule):
91 | def transforms(self, crop_scale=(0.2, 1), val=False):
92 | if not val:
93 | tform = transforms.Compose([
94 | transforms.RandomResizedCrop(self.size, scale=crop_scale),
95 | transforms.RandomHorizontalFlip(0.5),
96 | transforms.ColorJitter(0.25, 0.25, 0.25),
97 | to_tensor,
98 | ])
99 | else:
100 | tform = transforms.Compose([
101 | transforms.Resize([int(round(8 * x / 7)) for x in self.size]),
102 | transforms.CenterCrop(self.size),
103 | to_tensor,
104 | ])
105 | return tform
106 |
107 | def setup(self, stage=None):
108 | self.data_train = self.dataclass(
109 | root = f'{self.data_dir}/train',
110 | transform = self.transforms(val=not self.augment)
111 | )
112 | if self.num_samples is not None:
113 | self.data_train = DataRepeater(self.data_train, self.num_samples)
114 |
115 | self.data_val = self.dataclass(
116 | root = f'{self.data_dir}/val',
117 | transform = self.transforms(val=True)
118 | )
119 |
120 |
121 | class CubDataModule(_FGVCDataModule):
122 | dataclass = fgvcdata.CUB
123 | num_class = 200
124 |
125 | class DogsDataModule(_FGVCDataModule):
126 | dataclass = fgvcdata.StanfordDogs
127 | num_class = 120
128 |
129 |
130 | class StanfordCarsDataModule(_FGVCDataModule):
131 | dataclass = fgvcdata.StanfordCars
132 | num_class = 196
133 |
134 | def transforms(self, crop_scale=(0.25, 1), val=False):
135 | return super().transforms(crop_scale, val)
136 |
137 |
138 | class AircraftDataModule(_FGVCDataModule):
139 | dataclass = fgvcdata.Aircraft
140 | num_class = 100
141 |
142 | def transforms(self, crop_scale=(0.25, 1), val=False):
143 | return super().transforms(crop_scale, val)
144 |
145 |
146 | class _CifarDataModule(_BaseDataModule):
147 | def prepare_data(self):
148 | self.dataclass(self.data_dir, download=True)
149 |
150 | def transforms(self, val=False):
151 | if not val:
152 | tform = transforms.Compose([
153 | transforms.Resize(self.size),
154 | transforms.Pad(self.size[0] // 8, padding_mode='reflect'),
155 | transforms.RandomAffine((-10, 10), (0, 1/8), (1, 1.2)),
156 | transforms.CenterCrop(self.size),
157 | transforms.RandomHorizontalFlip(0.5),
158 |                 to_tensor,
159 | ])
160 | else:
161 | tform = transforms.Compose([
162 | transforms.Resize(self.size),
163 |                 to_tensor,
164 | ])
165 | return tform
166 |
167 | def setup(self, stage=None):
168 | self.data_train = self.dataclass(
169 | root = self.data_dir,
170 | train = True,
171 | transform = self.transforms(not self.augment)
172 | )
173 | self.data_val = self.dataclass(
174 | root = self.data_dir,
175 | train = False,
176 | transform = self.transforms(True)
177 | )
178 |
179 |
180 | class Cifar10DataModule(_CifarDataModule):
181 | num_class = 10
182 | dataclass = CIFAR10
183 |
184 |
185 | class Cifar100DataModule(_CifarDataModule):
186 | num_class = 100
187 | dataclass = CIFAR100
188 |
189 |
190 | class STL10PretrainDataModule(_BaseDataModule):
191 | num_classes = 10
192 | def prepare_data(self):
193 | pass
194 |
195 | def transforms(self, val=False):
196 | if not val:
197 | tform = transforms.Compose([
198 | transforms.Resize(self.size),
199 | transforms.Pad(self.size[0] // 8, padding_mode='reflect'),
200 | transforms.RandomAffine((-10, 10), (0, 1/8), (1, 1.2)),
201 | transforms.CenterCrop(self.size),
202 | transforms.RandomHorizontalFlip(0.5),
203 | transforms.ToTensor(),
204 | transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713))
205 | ])
206 | else:
207 | tform = transforms.Compose([
208 | transforms.Resize(self.size),
209 | transforms.ToTensor(),
210 | transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713))
211 | ])
212 | return tform
213 |
214 | def setup(self, stage=None):
215 | self.data_train = STL10(
216 | root = self.data_dir,
217 | split = 'unlabeled',
218 | transform = self.transforms(not self.augment)
219 | )
220 | self.data_val = STL10(
221 | root = self.data_dir,
222 | split = 'train',
223 | transform = self.transforms(True)
224 | )
225 |
226 | class STL10LinearProbeDataModule(_BaseDataModule):
227 | num_classes = 10
228 | def prepare_data(self):
229 | pass
230 |
231 | def transforms(self):
232 | tform = transforms.Compose([
233 | transforms.Resize(self.size),
234 | transforms.ToTensor(),
235 | transforms.Normalize((0.4467, 0.4398, 0.4066), (0.2603, 0.2566, 0.2713))
236 | ])
237 | return tform
238 |
239 | def setup(self, stage=None):
240 | self.data_train = STL10(
241 | root = self.data_dir,
242 | split = 'train',
243 | transform = self.transforms()
244 | )
245 |
246 | data_val_test = STL10(
247 | root = self.data_dir,
248 | split = 'test',
249 | transform = self.transforms()
250 | )
251 | test_size = int(0.7 * len(data_val_test))
252 | val_size = len(data_val_test) - test_size
253 | self.data_val, self.data_test = random_split(data_val_test, [val_size, test_size])
254 |
255 | def test_dataloader(self):
256 | return DataLoader(
257 | dataset = self.data_test,
258 | batch_size = self.batch_size,
259 | num_workers = self.num_workers,
260 | pin_memory = self.pin_memory,
261 | shuffle = False,
262 | drop_last = True
263 | )
264 |
265 |
--------------------------------------------------------------------------------
/lightning_readout.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from pytorch_lightning import Trainer
4 |
5 | from datamodule import STL10LinearProbeDataModule
6 | from mae import MAE_linear_probe
7 |
8 | os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
9 |
10 | # Build the linear-probe model: a frozen pretrained MAE encoder plus a trainable
11 | # linear classifier (see MAE_linear_probe in mae.py).
12 | model = MAE_linear_probe(ckpt_path='last.ckpt')
13 |
14 | # Train the probe on the STL-10 labeled splits.
15 | stl10 = STL10LinearProbeDataModule(
16 |     data_dir='/scratch-shared/matt1/data', batch_size=64, num_workers=17,
17 |     pin_memory=True, size=224, augment=True, num_samples=None,
18 | )
19 | trainer = Trainer()
20 | trainer.fit(model, datamodule=stl10)
24 |
--------------------------------------------------------------------------------
/linear_probe.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.cli import LightningCLI
2 | import torch
3 | from dotenv import load_dotenv
4 |
5 |
6 | if __name__ == '__main__':
7 | load_dotenv('.env')
8 | torch.set_float32_matmul_precision('medium')
9 | cli = LightningCLI(parser_kwargs={'parser_mode': 'omegaconf'}, run=False)
10 | cli.trainer.fit(cli.model, cli.datamodule)
11 | cli.trainer.test(cli.model, cli.datamodule)
12 |
13 |
14 |
--------------------------------------------------------------------------------
/mae.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Any, Tuple, Union
4 |
5 | import pytorch_lightning as pl
6 | from pytorch_lightning.utilities.types import STEP_OUTPUT
7 | import timm
8 | import torch
9 | from torch import distributed
10 | import torchvision
11 | import torch.nn as nn
12 |
13 |
14 | ##############################################################################
15 | # Masked Autoencoder (MAE)
16 | ##############################################################################
17 |
18 | # Encoder:
19 | # Use the timm VisionTransformer, but all we need is the "blocks" and the final "norm" submodules
20 | # Add a fixed positional encoding at the beginning (sin-cos, original transformer style)
21 | # Add a linear projection on the output to match the decoder dimension
22 |
23 | # Decoder:
24 | # Use the timm VisionTransformer, as in the encoder
25 | # Position embeddings are added to the decoder input (sin-cos); note that they are different than
26 | # the encoder's, because the dimension is different
27 | # There is a shared, learnable [MASK] token that is used at every masked position
28 | # A classification token can be included, but it should work similarly without (using average pooling,
29 | # according to the paper); we don't include a classification token here
30 |
31 | # The loss is MSE computed only on the masked patches, as in the paper
32 |
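# Usage sketch (illustrative): a forward pass returns per-patch pixel predictions
# of shape (batch, num_patches, 3 * patch_size**2), which can be reshaped back into
# an image with `tokens_as_image`:
#
#   mae = MaskedAutoencoder(image_size=(224, 224), patch_size=16, keep=0.25)
#   recon = mae(torch.randn(2, 3, 224, 224))   # -> (2, 196, 768)
#   imgs = mae.tokens_as_image(recon)          # -> (2, 3, 224, 224)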
33 |
34 | class ViTBlocks(torch.nn.Module):
35 |     '''The main processing blocks of ViT. Excludes things like patch embedding and classification
36 | layer.
37 |
38 | Args:
39 | width: size of the feature dimension.
40 | depth: number of blocks in the network.
41 | end_norm: whether to end with LayerNorm or not.
42 | '''
43 | def __init__(
44 | self,
45 | width: int = 768,
46 | depth: int = 12,
47 | end_norm: bool = True,
48 | ):
49 | super().__init__()
50 |
51 | # transformer blocks from ViT
52 | ViT = timm.models.vision_transformer.VisionTransformer
53 | vit = ViT(embed_dim=width, depth=depth)
54 | self.layers = vit.blocks
55 | if end_norm:
56 | # final normalization
57 | self.layers.add_module('norm', vit.norm)
58 |
59 | def forward(self, x: torch.Tensor):
60 | return self.layers(x)
61 |
62 |
63 | class MaskedAutoencoder(torch.nn.Module):
64 | '''Masked Autoencoder for visual representation learning.
65 |
66 | Args:
67 | image_size: (height, width) of the input images.
68 | patch_size: side length of a patch.
69 | keep: percentage of tokens to process in the encoder. (1 - keep) is the percentage of masked tokens.
70 | enc_width: width (feature dimension) of the encoder.
71 | dec_width: width (feature dimension) of the decoder. If a float, it is interpreted as a percentage
72 | of enc_width.
73 | enc_depth: depth (number of blocks) of the encoder
74 | dec_depth: depth (number of blocks) of the decoder
75 | '''
76 | def __init__(
77 | self,
78 | image_size: Tuple[int, int] = (224, 224),
79 | patch_size: int = 16,
80 | keep: float = 0.25,
81 | enc_width: int = 768,
82 | dec_width: Union[int, float] = 0.25,
83 | enc_depth: int = 12,
84 | dec_depth: int = 4,
85 | ):
86 | super().__init__()
87 |
88 | assert image_size[0] % patch_size == 0 and image_size[1] % patch_size == 0
89 |
90 | self.image_size = image_size
91 | self.patch_size = patch_size
92 | self.keep = keep
93 | self.n = (image_size[0] * image_size[1]) // patch_size**2 # number of patches
94 |
95 | if isinstance(dec_width, float) and dec_width > 0 and dec_width < 1:
96 | dec_width = int(dec_width * enc_width)
97 | else:
98 | dec_width = int(dec_width)
99 | self.enc_width = enc_width
100 | self.dec_width = dec_width
101 | self.enc_depth = enc_depth
102 | self.dec_depth = dec_depth
103 |
104 | # linear patch embedding
105 | self.embed_conv = torch.nn.Conv2d(3, enc_width, patch_size, patch_size)
106 |
107 | # mask token and position encoding
108 | self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, dec_width, requires_grad=True))
109 | self.register_buffer('pos_encoder', self.pos_encoding(self.n, enc_width).requires_grad_(False))
110 | self.register_buffer('pos_decoder', self.pos_encoding(self.n, dec_width).requires_grad_(False))
111 |
112 | # encoder
113 | self.encoder = ViTBlocks(width=enc_width, depth=enc_depth)
114 |
115 | # linear projection from enc_width to dec_width
116 | self.project = torch.nn.Linear(enc_width, dec_width)
117 |
118 | # decoder
119 | self.decoder = ViTBlocks(width=dec_width, depth=dec_depth, end_norm=False)
120 |
121 | # linear projection to pixel dimensions
122 | self.pixel_project = torch.nn.Linear(dec_width, 3 * patch_size**2)
123 |
124 | self.freeze_mask = False # set to True to reuse the same mask multiple times
125 |
126 | @property
127 | def freeze_mask(self):
128 | '''When True, the previously computed mask will be used on new inputs, instead of creating a new one.'''
129 | return self._freeze_mask
130 |
131 | @freeze_mask.setter
132 | def freeze_mask(self, val: bool):
133 | self._freeze_mask = val
134 |
135 | @staticmethod
136 | def pos_encoding(n: int, d: int, k: int=10000):
137 | '''Create sine-cosine positional embeddings.
138 |
139 | Args:
140 | n: the number of embedding vectors, corresponding to the number of tokens (patches) in the image.
141 | d: the dimension of the embeddings
142 | k: value that determines the maximum frequency (10,000 by default)
143 |
144 | Returns:
145 | (n, d) tensor of position encoding vectors
146 | '''
147 | x = torch.meshgrid(
148 | torch.arange(n, dtype=torch.float32),
149 | torch.arange(d, dtype=torch.float32),
150 | indexing='ij'
151 | )
152 | pos = torch.zeros_like(x[0])
153 | pos[:, ::2] = x[0][:, ::2].div(torch.pow(k, x[1][:, ::2].div(d // 2))).sin_()
154 | pos[:, 1::2] = x[0][:,1::2].div(torch.pow(k, x[1][:,1::2].div(d // 2))).cos_()
155 | return pos
156 |
157 | @staticmethod
158 | def generate_mask_index(bs: int, n_tok: int, device: str='cpu'):
159 | '''Create a randomly permuted token-index tensor for determining which tokens to mask.
160 |
161 | Args:
162 | bs: batch size
163 | n_tok: number of tokens per image
164 | device: the device where the tensors should be created
165 |
166 | Returns:
167 |             (bs, n_tok) tensor of token indices, randomly permuted per image
169 | '''
170 | idx = torch.rand(bs, n_tok, device=device).argsort(dim=1)
171 | return idx
172 |
173 | @staticmethod
174 | def select_tokens(x: torch.Tensor, idx: torch.Tensor):
175 | '''Return the tokens from `x` corresponding to the indices in `idx`.
176 | '''
177 | idx = idx.unsqueeze(-1).expand(-1, -1, x.shape[-1])
178 | return x.gather(dim=1, index=idx)
179 |
180 | def image_as_tokens(self, x: torch.Tensor):
181 | '''Reshape an image of shape (b, c, h, w) to a set of vectorized patches
182 | of shape (b, h*w/p^2, c*p^2). In other words, the set of non-overlapping
183 | patches of size (3, p, p) in the image are turned into vectors (tokens);
184 | dimension 1 of the output indexes each patch.
185 | '''
186 | b, c, h, w = x.shape
187 | p = self.patch_size
188 | x = x.reshape(b, c, h // p, p, w // p, p).permute(0, 2, 4, 1, 3, 5)
189 | x = x.reshape(b, (h * w) // p**2, c * p * p)
190 | return x
191 |
192 | def tokens_as_image(self, x: torch.Tensor):
193 | '''Reshape a set of token vectors into an image. This is the reverse operation
194 | of `image_as_tokens`.
195 | '''
196 | b = x.shape[0]
197 | im, p = self.image_size, self.patch_size
198 | hh, ww = im[0] // p, im[1] // p
199 | x = x.reshape(b, hh, ww, 3, p, p).permute(0, 3, 1, 4, 2, 5)
200 | x = x.reshape(b, 3, p * hh, p * ww)
201 | return x
202 |
203 | def masked_image(self, x: torch.Tensor):
204 | '''Return a copy of the image batch, with the masked patches set to 0. Used
205 | for visualization.
206 | '''
207 | x = self.image_as_tokens(x).clone()
208 | bidx = torch.arange(x.shape[0], device=x.device)[:, None]
209 | x[bidx, self.idx[:, int(self.keep * self.n):]] = 0
210 | return self.tokens_as_image(x)
211 |
212 | def embed(self, x: torch.Tensor):
213 | return self.embed_conv(x).flatten(2).transpose(1, 2)
214 |
215 | def mask_input(self, x: torch.Tensor):
216 | '''Mask the image patches uniformly at random, as described in the paper: the patch tokens are
217 | randomly permuted (per image), and the first N are returned, where N corresponds to percentage
218 | of patches kept (not masked).
219 |
220 |         Returns the masked (truncated) tokens. The mask indices are saved as `self.idx`.
221 | '''
222 | # create a new mask if self.freeze_mask is False, or if no mask has been created yet
223 | if not hasattr(self, 'idx') or not self.freeze_mask:
224 | self.idx = self.generate_mask_index(x.shape[0], x.shape[1], x.device)
225 |
226 | k = int(self.keep * self.n)
227 | x = self.select_tokens(x, self.idx[:, :k])
228 | return x
229 |
230 | def forward_features(self, x: torch.Tensor):
231 | x = self.embed(x)
232 | x = x + self.pos_encoder
233 | x = self.mask_input(x)
234 | x = self.encoder(x)
235 |
236 | return x
237 |
238 | def forward(self, x: torch.Tensor):
239 | x = self.forward_features(x)
240 | x = self.project(x)
241 |
242 | k = self.n - x.shape[1] # number of masked tokens
243 | mask_toks = self.mask_token.expand(x.shape[0], k, -1)
244 | x = torch.cat([x, mask_toks], 1)
245 | x = self.select_tokens(x, self.idx.argsort(1))
246 | x = x + self.pos_decoder
247 | x = self.decoder(x)
248 | x = self.pixel_project(x)
249 |
250 | return x
251 |
252 |
253 | class MAE(pl.LightningModule):
254 | '''Masked Autoencoder LightningModule.
255 |
256 | Args:
257 | image_size: (height, width) of the input images.
258 | patch_size: size of the image patches.
259 | keep: percentage of tokens to keep. (1 - keep) is the percentage of masked tokens.
260 | enc_width: width of the encoder features.
261 | dec_width: width of the decoder features.
262 | enc_depth: depth of the encoder.
263 | dec_depth: depth of the decoder.
264 | lr: learning rate
265 | save_imgs_every: save some reconstructions every nth epoch.
266 |             num_save_imgs: number of reconstructed images to save.
267 | '''
268 | def __init__(
269 | self,
270 | image_size: Tuple[int, int] = (224, 224),
271 | patch_size: int = 16,
272 | keep: float = 0.25,
273 | enc_width: int = 768,
274 | dec_width: Union[int, float] = 0.5,
275 | enc_depth: int = 12,
276 | dec_depth: int = 6,
277 | lr: float = 1.5e-4,
278 | base_batch_size: int = 256,
279 | normalize_for_loss: bool = False,
280 | save_imgs_every: int = 1,
281 | num_save_imgs: int = 36,
282 | ):
283 | super().__init__()
284 |
285 | self.mae = MaskedAutoencoder(
286 | image_size=image_size,
287 | patch_size=patch_size,
288 | keep=keep,
289 | enc_width=enc_width,
290 | enc_depth=enc_depth,
291 | dec_width=dec_width,
292 | dec_depth=dec_depth,
293 | )
294 |
295 | self.keep = keep
296 | self.n = self.mae.n
297 | self.lr = lr
298 | self.base_batch_size = base_batch_size
299 | self.normalize_for_loss = normalize_for_loss
300 | self.save_imgs_every = save_imgs_every
301 | self.num_save_imgs = num_save_imgs
302 |
303 | self.saved_imgs_list = []
304 |
305 | def on_train_batch_end(self, *args, **kwargs):
306 | if self.trainer.global_step == 2 and self.trainer.is_global_zero:
307 | # print GPU memory usage once at beginning of training
308 | avail, total = torch.cuda.mem_get_info()
309 | mem_used = 100 * (1 - (avail / total))
310 | gb = 1024**3
311 | self.print(f'GPU memory used: {(total-avail)/gb:.2f} of {total/gb:.2f} GB ({mem_used:.2f}%)')
312 | if self.trainer.num_nodes > 1 or self.trainer.num_devices > 1:
313 | distributed.barrier()
314 |
315 | def training_step(self, batch: Any, batch_idx: int, *args, **kwargs):
316 | x, _ = batch
317 | pred = self.mae(x)
318 | loss = self.masked_mse_loss(x, pred)
319 | self.log('train/loss', loss, prog_bar=True)
320 | return {'loss': loss}
321 |
322 | def validation_step(self, batch: Any, batch_idx: int, *args, **kwargs):
323 | x, _ = batch
324 | pred = self.mae(x)
325 | loss = self.masked_mse_loss(x, pred)
326 | self.log('val/loss', loss, prog_bar=True, sync_dist=True)
327 |
328 | if self.save_imgs_every:
329 | p = int(self.save_imgs_every)
330 | if self.trainer.current_epoch % p == 0:
331 | nb = self.trainer.num_val_batches[0]
332 | ns = self.num_save_imgs
333 | per_batch = math.ceil(ns / nb)
334 | self.saved_imgs_list.append(pred[:per_batch])
335 |
336 | return {'loss': loss}
337 |
338 | def on_validation_epoch_end(self):
339 | if self.save_imgs_every:
340 | if self.trainer.is_global_zero:
341 | imgs = torch.cat(self.saved_imgs_list, 0)
342 | self.saved_imgs_list.clear()
343 | self.save_imgs(imgs[:self.num_save_imgs])
344 | if self.trainer.num_nodes > 1 or self.trainer.num_devices > 1:
345 | distributed.barrier()
346 |
347 | # @pl.utilities.rank_zero_only
348 | def save_imgs(self, imgs: torch.Tensor):
349 | with torch.no_grad():
350 | r = int(imgs.shape[0]**0.5)
351 | imgs = self.mae.tokens_as_image(imgs.detach())
352 | imgs = imgs.add_(1).mul_(127.5).clamp_(0, 255).byte()
353 | imgs = torchvision.utils.make_grid(imgs, r).cpu()
354 | epoch = self.trainer.current_epoch
355 | dir = os.path.join(self.trainer.log_dir, 'imgs')
356 | os.makedirs(dir, exist_ok=True)
357 | torchvision.io.write_png(imgs, os.path.join(dir, f'epoch_{epoch}_imgs.png'))
358 |
359 | def configure_optimizers(self):
360 | total_steps = self.trainer.estimated_stepping_batches
361 | devices, nodes = self.trainer.num_devices, self.trainer.num_nodes
362 | batch_size = self.trainer.train_dataloader.batch_size
363 | lr_scale = devices * nodes * batch_size / self.base_batch_size
364 | lr = self.lr * lr_scale
365 |
366 | optim = torch.optim.AdamW(self.parameters(), lr=lr, betas=(.9, .95), weight_decay=0.05)
367 | schedule = torch.optim.lr_scheduler.OneCycleLR(
368 | optim,
369 | max_lr=lr,
370 | total_steps=total_steps,
371 | pct_start=0.1,
372 | cycle_momentum=False,
373 | )
374 | return {
375 | 'optimizer': optim,
376 | 'lr_scheduler': {'scheduler': schedule, 'interval': 'step'}
377 | }
378 |
379 | def masked_mse_loss(self, img: torch.Tensor, recon: torch.Tensor):
380 | # turn the image into patch-vectors for comparison to model output
381 | x = self.mae.image_as_tokens(img)
382 | if self.normalize_for_loss:
383 | std, mean = torch.std_mean(x, dim=-1, keepdim=True)
384 | x = x.sub(mean).div(std + 1e-5)
385 | # only compute on the mask token outputs, which is everything after the first (n * keep)
386 | idx = self.mae.idx[:, int(self.keep * self.n):]
387 | x = self.mae.select_tokens(x, idx)
388 | y = self.mae.select_tokens(recon, idx)
389 | return torch.nn.functional.mse_loss(x, y)
390 |
391 | class MAE_linear_probe(pl.LightningModule):
392 | '''Frozen MAE encoder with trainable linear readout to class labels
393 | https://lightning.ai/docs/pytorch/stable/advanced/transfer_learning.html
394 |
395 | '''
396 | def __init__(
397 | self,
398 | ckpt_path: str,
399 | ):
400 | super().__init__()
401 | mae_module = MAE()
402 | mae_module.load_state_dict(torch.load(ckpt_path)['state_dict'])
403 | self.mae = mae_module.mae
404 |
405 | self.feature_extractor = self.mae.encoder
406 |
407 | self.classifier = torch.nn.Linear(self.mae.enc_width, 10)
408 | self.classifier.weight.data.normal_(mean=0.0, std=0.01)
409 | self.classifier.bias.data.zero_()
410 |
411 | def forward(self, x):
412 | x = self.mae.embed(x)
413 | x = x + self.mae.pos_encoder
414 | self.feature_extractor.eval()
415 | with torch.no_grad():
416 | x = self.feature_extractor(x)
417 | x = x.mean(dim=1) # average pool over the patch dimension
418 | x = self.classifier(x)
419 | return x
420 |
421 | def configure_optimizers(self):
422 | optimizer = torch.optim.AdamW(self.parameters(), lr=5e-4)
423 | return optimizer
424 |
425 | def training_step(self, batch: Any, batch_idx: int, *args, **kwargs):
426 | x, labels = batch
427 | pred = self.forward(x)
428 | loss = self.loss_fn(pred, labels)
429 | self.log('train/loss', loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True)
430 | return {'loss': loss}
431 |
432 | def validation_step(self, batch: Any, batch_idx: int, *args, **kwargs):
433 | x, labels = batch
434 | pred = self.forward(x)
435 | loss = self.loss_fn(pred, labels)
436 | _, predicted = torch.max(pred, 1)
437 | correct = (predicted == labels).sum().item()
438 | self.log('val/loss', loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True)
439 | self.log('val/acc', correct / len(labels), prog_bar=True, on_step=False, sync_dist=True, on_epoch=True)
440 |
441 | def test_step(self, batch: Any, batch_idx: int, *args, **kwargs):
442 | x, labels = batch
443 | pred = self.forward(x)
444 | loss = self.loss_fn(pred, labels)
445 | _, predicted = torch.max(pred, 1)
446 | correct = (predicted == labels).sum().item()
447 | self.log('test/loss', loss, prog_bar=True, sync_dist=True, on_step=False, on_epoch=True)
448 | self.log('test/acc', correct / len(labels), prog_bar=True, on_step=False, sync_dist=True, on_epoch=True)
449 |
450 | def loss_fn(self, x, y):
451 | fn = torch.nn.CrossEntropyLoss()
452 | return fn(x, y)
453 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.0.0
2 | torchvision>=0.15.0
3 | pytorch-lightning>=2.0.2
4 | jsonargparse[signatures]>=4.21.1
5 | omegaconf>=2.1.1
6 | fgvcdata>=0.1.0
7 | timm>=0.9.2
8 | python-dotenv>=0.17.3
9 | pillow>=9.4.0
--------------------------------------------------------------------------------
/samples/bird-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/bird-samples.png
--------------------------------------------------------------------------------
/samples/bird-training-curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/bird-training-curves.png
--------------------------------------------------------------------------------
/samples/birds-training.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/birds-training.gif
--------------------------------------------------------------------------------
/samples/car-samples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/catalys1/mae-pytorch/1c6248f042fd6eccacb281ca5be47799de7c7d7d/samples/car-samples.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from pytorch_lightning.cli import LightningCLI
2 | import torch
3 | from dotenv import load_dotenv
4 |
5 | if __name__ == '__main__':
6 |
7 | load_dotenv('.env')
8 | torch.set_float32_matmul_precision('medium')
9 | cli = LightningCLI(parser_kwargs={'parser_mode': 'omegaconf'})
10 |
--------------------------------------------------------------------------------
/visualize.py:
--------------------------------------------------------------------------------
1 | import dotenv
2 | from omegaconf import OmegaConf
3 | from pathlib import Path
4 | from pytorch_lightning import seed_everything
5 | import random
6 | import torch
7 | import torchvision
8 |
9 | import datamodule
10 | import mae
11 |
12 | if __name__ == '__main__':
13 | import argparse
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument('logdir', help='Path to logging directory for a trained model')
17 | parser.add_argument('-n', '--num_imgs', type=int, default=8, help='Number of images to display')
18 | parser.add_argument('-s', '--seed', type=int, default=987654321, help='Random seed')
19 | parser.add_argument('-d', '--device', type=str, default='cpu', help='Device type: "cpu" or "cuda"')
20 |
21 | args = parser.parse_args()
22 |
23 | dotenv.load_dotenv('.env')
24 | seed_everything(args.seed)
25 | device = args.device
26 |
27 | root = Path(args.logdir)
28 | config = OmegaConf.load(root.joinpath('config.yaml'))
29 |
30 | ### data setup
31 | print('Data... ', end='')
32 | dm_class = config.data.class_path.rsplit('.', 1)[-1]
33 | data_dir = config.data.init_args.data_dir
34 | dm = getattr(datamodule, dm_class)(data_dir)
35 | dm.setup()
36 | data = dm.data_val
37 | print(data)
38 |
39 | ### model setup
40 | print('Model... ', end='')
41 | ckpt_path = root.joinpath('checkpoints', 'last.ckpt')
42 | model = mae.MAE.load_from_checkpoint(ckpt_path, map_location='cpu')
43 | model = model.mae.to(device)
44 | print(model.__class__.__name__)
45 |
46 | ### get model predictions
47 | print('Getting predictions...', end='')
48 | img_indices = random.choices(range(len(data)), k=args.num_imgs)
49 | imgs = torch.stack([data[i][0] for i in img_indices], 0).to(device)
50 |
51 | preds = model.tokens_as_image(model(imgs))
52 | masked = model.masked_image(imgs)
53 | print('done')
54 |
55 | ### create visualization
56 | viz = torchvision.utils.make_grid(
57 | torch.cat([imgs, masked, preds], 0).clamp(-1, 1),
58 | nrow=args.num_imgs,
59 | normalize=True,
60 | value_range=(-1, 1),
61 | )
62 | viz = viz.mul_(255).byte()
63 |
64 | torchvision.io.write_png(viz, str(root.joinpath('samples.png')))
--------------------------------------------------------------------------------