├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── mae ├── LICENSE ├── __init__.py ├── clean_data.py ├── config │ ├── eurosat.yaml │ ├── fmow.yaml │ ├── imagenet.yaml │ ├── naip.yaml │ ├── resisc.yaml │ ├── sentinel2.yaml │ └── xview2.yaml ├── dataloaders │ ├── __init__.py │ ├── airbus.py │ ├── airound.py │ ├── cvbrct.py │ ├── eurosat.py │ ├── fmow.py │ ├── geo.py │ ├── imagelist.py │ ├── imagenet100.py │ ├── merge.py │ ├── mlrsnet.py │ ├── naip.py │ ├── optimal.py │ ├── resic45.py │ ├── sentinel2.py │ ├── ucmerced.py │ ├── utils.py │ ├── whurs.py │ └── xview.py ├── engine_finetune.py ├── engine_pretrain.py ├── eval │ └── knn.py ├── exps │ └── encoder-decoder-vanilla.sh ├── helpers │ └── __init__.py ├── lib │ ├── fpn.py │ ├── gpt.py │ ├── scheduler.py │ ├── transformer.py │ └── transforms.py ├── main_eval.py ├── main_finetune.py ├── main_linprobe.py ├── main_pretrain.py ├── models_mae.py ├── models_vit.py ├── samplers │ └── distributed.py ├── scripts │ ├── eval_launcher.py │ ├── evalconf │ │ ├── demo-conf.yaml │ │ └── dgx-conf.yaml │ └── gen_scale_perf_plots.py ├── splits │ ├── train-aerial.txt │ ├── train-eurosat.txt │ ├── train-mlrsnet.txt │ ├── train-resisc.txt │ ├── val-aerial.txt │ ├── val-eurosat.txt │ ├── val-mlrsnet.txt │ └── val-resisc.txt ├── util │ ├── crop.py │ ├── datasets.py │ ├── dist_utils.py │ ├── lars.py │ ├── lr_decay.py │ ├── lr_sched.py │ ├── misc.py │ ├── pos_embed.py │ └── resolution_sched.py └── wandb_log.py ├── pyproject.toml └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | # lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | EURO* 113 | eurosat 114 | fmow_data 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | .vscode/ 135 | output_dir/ 136 | *.pth 137 | slurm_launch.sh 138 | wandb 139 | *~ 140 | 141 | demasking-layer*/ 142 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/pyupgrade 3 | rev: v2.32.1 4 | hooks: 5 | - id: pyupgrade 6 | args: [--py37-plus] 7 | 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.10.1 10 | hooks: 11 | - id: isort 12 | additional_dependencies: ["colorama>=0.4.3"] 13 | args: ["--profile", "black"] 14 | 15 | - repo: https://github.com/psf/black 16 | rev: 22.3.0 17 | hooks: 18 | - id: black 19 | args: [--skip-magic-trailing-comma] 20 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Scale-MAE 🛰️ 2 | 3 | ![image](https://user-images.githubusercontent.com/1455579/217665789-b46d6830-445f-4151-b7a4-a2152a81a8d1.png) 4 | 5 | 6 | This repository provides a reimplementation of the code for [Scale-MAE: A Scale-Aware Masked Autoencoder for Multiscale Geospatial Representation Learning](https://arxiv.org/abs/2212.14532) (the original code was optimized for our distributed cluster). 7 | 8 | ``` 9 | @article{reed2022scale, 10 | title={Scale-MAE: A Scale-Aware Masked Autoencoder for Multiscale Geospatial Representation Learning}, 11 | author={Reed, Colorado J and Gupta, Ritwik and Li, Shufan and Brockman, Sarah and Funk, Christopher and Clipp, Brian and Candido, Salvatore and Uyttendaele, Matt and Darrell, Trevor}, 12 | journal={arXiv preprint arXiv:2212.14532}, 13 | year={2022} 14 | } 15 | ``` 16 | 17 | * This repo is a modification on the [MAE repo](https://github.com/facebookresearch/mae). Installation and preparation follow that repo ;-). 18 | 19 | * As mentioned in the MAE repo, this repo is based on [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models), for which a [fix](https://github.com/rwightman/pytorch-image-models/issues/420#issuecomment-776459842) is needed to work with PyTorch 1.8.1+. In addition, install gdal, rasterio, and Shapely. This tends to work pretty well (but gdal is notoriously tricky): 20 | 21 | ## Installation 22 | ```bash 23 | conda create -n scalemae python=3.9 geopandas # geopandas should install gdal correctly 24 | conda activate scalemae 25 | # replace with your desired pytorch target (e.g. cuda version) 26 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia 27 | pip install -e . 28 | ``` 29 | 30 | ## Data Preparation 31 | Download the FMoW-rgb dataset as described in the [here](https://github.com/fMoW/dataset) and then make a symlink to the data directory in the root of this repo. 
For example, if you downloaded the data to `~/data/fmow-rgb`, then run: 32 | 33 | ```bash 34 | ln -s ~/data/fmow-rgb data 35 | ``` 36 | 37 | ## Pretraining ## 38 | Datasets are defined by config files in `config`. 39 | ``` 40 | # change to num of gpus you have 41 | python -m torch.distributed.launch --nproc_per_node=4 42 | main_pretrain.py 43 | ``` 44 | 45 | use `-h` to see details of all arguments. 46 | 47 | 48 | ## Pretrained Models 49 | 50 | * [**ViT Large 800 ep**](https://github.com/bair-climate-initiative/scale-mae/releases/download/base-800/scalemae-vitlarge-800.pth) 51 | 52 | 53 | 54 | ## Evaluation 55 | 56 | ### KNN Evaluation 57 | ``` 58 | python -m torch.distributed.launch --nproc_per_node=4 \ 59 | main_pretrain.py \ 60 | --resume \ 61 | --eval_only \ 62 | --eval_dataset \ 63 | --eval_train_fnames \ 64 | --eval_val_fnames 65 | ``` 66 | 67 | We support resisc (default), airound, mlrsnet, and fmow kNN evaluation. We provide all split files in `splits` folder. If `--eval_train_fnames` and `--eval_val_fnames` are specified, the content of these two txt files will be read as the train split and test split. If this is the case, the root folder of the dataset is assumed to be the parent folder of such txt files. Alternatively, one can specify `--eval_path`. If this is the case, 90% of the data is randomly selected as the training set while the 10% is selected as the test set. The dataset is assumed to have the standard structure of `ImageFolder` in `torchvision`. 68 | 69 | ### Finetuning 70 | 71 | ``` 72 | python -m torch.distributed.launch --nproc_per_node=4 \ 73 | main_linprobe.py \ 74 | --checkpoint_path 75 | ``` 76 | 77 | Use the flag `--finetune` to enable full fine-tuning instead of a linear probing. 78 | 79 | --- 80 | 81 | > Note: THIS SOFTWARE AND/OR DATA WAS DEPOSITED IN THE BAIR OPEN RESEARCH COMMONS REPOSITORY ON 2/8/23. 82 | -------------------------------------------------------------------------------- /mae/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bair-climate-initiative/scale-mae/89280d830037ff27c20459cdab03e01e633e29bb/mae/__init__.py -------------------------------------------------------------------------------- /mae/clean_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import threading 3 | from multiprocessing.pool import ThreadPool as Pool 4 | from threading import Thread 5 | 6 | import numpy as np 7 | import scandir 8 | from tqdm.cli import tqdm 9 | 10 | lock = threading.Lock() 11 | from PIL import Image 12 | 13 | Image.MAX_IMAGE_PIXELS = 1000000000 14 | directory = "data/naip" 15 | res = scandir.walk(directory) 16 | 17 | 18 | def f(item): 19 | path, _, files = item 20 | for file in files: 21 | if ".tif" in file: 22 | fpath = os.path.join(path, file) 23 | try: 24 | img = Image.open(fpath) 25 | arr = np.array(img) 26 | arr.shape 27 | except: 28 | ... 
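                # any .tif that PIL/numpy fails to open or convert is treated as corrupt;
                # its path is appended to the shared "badfile" log below, with the lock
                # serializing writes across the worker threads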
29 | # acquire the lock 30 | lock.acquire() 31 | # open file for appending 32 | with open("badfile", "a") as file: 33 | # write text to data 34 | file.write(fpath + "\n") 35 | # release the lock 36 | lock.release() 37 | print(fpath) 38 | 39 | 40 | with open("badfile", "w") as file: 41 | # write text to data 42 | file.write(">>>>>SOF>>>>>>\n") 43 | pool = Pool(100) 44 | pool.map(f, res) 45 | -------------------------------------------------------------------------------- /mae/config/eurosat.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: eurosat 3 | # size: 224 4 | oversample: 1.0 5 | img_dir: 'EURO_TRAIN' 6 | mean: [0.368, 0.381, 0.3436] 7 | std: [0.2035, 0.1854, 0.1849] 8 | # actual mean/std, but imagenet numbers work a bit better 9 | # mean: [0.3916, 0.3936, 0.3658] 10 | # std: [0.2775, 0.2713, 0.2751] 11 | vis_factor: 1.0 12 | base_resolution: 2.5 13 | -------------------------------------------------------------------------------- /mae/config/fmow.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: fmow 3 | # size: 224 4 | oversample: 1.0 5 | img_dir: 'fdata/train' 6 | mean: [0.485, 0.456, 0.406] 7 | std: [0.229, 0.224, 0.225] 8 | # actual mean/std, but imagenet numbers work a bit better 9 | # mean: [0.3916, 0.3936, 0.3658] 10 | # std: [0.2775, 0.2713, 0.2751] 11 | vis_factor: 1.0 12 | base_resolution: 2.5 13 | -------------------------------------------------------------------------------- /mae/config/imagenet.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: ImageNet 3 | size: 224 4 | img_dir: 'ilsvrc' 5 | mean: [0.485, 0.456, 0.406] 6 | std: [0.229, 0.224, 0.225] 7 | vis_factor: 1.0 8 | 9 | sampler: 10 | shuffle: true 11 | seed: 0 12 | -------------------------------------------------------------------------------- /mae/config/naip.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: NAIP 3 | length: 100000 4 | img_dir: 'naip' 5 | mean: [0.46921533, 0.46026663, 0.41329921] 6 | std: [0.1927, 0.1373, 0.1203] 7 | vis_factor: 1.0 -------------------------------------------------------------------------------- /mae/config/resisc.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: resisc 3 | # size: 224 4 | oversample: 1.0 5 | img_dir: 'resisc45/train.txt' 6 | mean: [0.368, 0.381, 0.3436] 7 | std: [0.2035, 0.1854, 0.1849] 8 | # actual mean/std, but imagenet numbers work a bit better 9 | # mean: [0.3916, 0.3936, 0.3658] 10 | # std: [0.2775, 0.2713, 0.2751] 11 | vis_factor: 1.0 12 | base_resolution: 2.5 13 | -------------------------------------------------------------------------------- /mae/config/sentinel2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: SENTINEL2 3 | size: 1024 4 | length: 100000 5 | img_dir: 'sentinel-2-diverse-plus-expanded-la' 6 | oversample: 1.5 7 | mean: [ 665.65558013, 983.1557218 , 1241.98676153] 8 | std: [ 852.4296, 1258.8298, 1590.1336] 9 | vis_factor: 6000 10 | base_resolution: 1.0 -------------------------------------------------------------------------------- /mae/config/xview2.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | type: XView2 3 | size: 224 4 | length: 100000 5 | img_dir: 'xview2' 6 | mean: [86.02207342, 89.45889871, 67.45251904, 84.22436819, 87.51666638, 
66.13790823] 7 | std: [45.36812982, 37.01190998, 34.63558733, 43.13346176, 36.14107798, 33.48689906] 8 | vis_factor: 1.0 9 | 10 | sampler: 11 | shuffle: true 12 | seed: 0 -------------------------------------------------------------------------------- /mae/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bair-climate-initiative/scale-mae/89280d830037ff27c20459cdab03e01e633e29bb/mae/dataloaders/__init__.py -------------------------------------------------------------------------------- /mae/dataloaders/airbus.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | import torchvision.transforms.functional as F 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | 12 | class Airbus(Dataset): 13 | """Airbus Ship Detection dataset.""" 14 | 15 | def __init__(self, root_dir, split_file, transform=None, classification=False): 16 | """ 17 | Args: 18 | root_dir (string): Directory with all the images. 19 | split_file (string): File with COCO JSON for the split. 20 | transform (callable, optional): Optional transform to be applied 21 | on a sample. 22 | """ 23 | if type(root_dir) is str: 24 | root_dir = Path(root_dir) 25 | 26 | self.root_dir = root_dir 27 | self.scenes = self._load_file_info(split_file) 28 | self.length = len(self.scenes) 29 | 30 | self.transform = transform 31 | self.classification = classification 32 | 33 | def _load_file_info(self, split_file): 34 | f = open(split_file) 35 | lines = f.readlines() 36 | # Skip header 37 | lines = lines[1:] 38 | pairs = [x.strip().split(",") for x in lines] 39 | 40 | # Remove corrupt files 41 | pairs = [x for x in pairs if x[0] != "6384c3e78.jpg"] 42 | grouped = [(k, list(v)) for k, v in itertools.groupby(pairs, lambda x: x[0])] 43 | 44 | return grouped 45 | 46 | def _rle2mask(self, mask_rle, shape=(768, 768)): 47 | """ 48 | mask_rle: run-length as string formated (start length) 49 | shape: (width,height) of array to return 50 | Returns numpy array, 1 - mask, 0 - background 51 | """ 52 | s = mask_rle.split() 53 | starts, lengths = (np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])) 54 | starts -= 1 55 | ends = starts + lengths 56 | img = np.zeros(shape[0] * shape[1], dtype=np.uint8) 57 | if len(mask_rle) == 0: 58 | return img.reshape(shape).T 59 | 60 | for lo, hi in zip(starts, ends): 61 | img[lo:hi] = 1 62 | return img.reshape(shape).T 63 | 64 | def _merge_masks(self, grouped_rle, shape=(768, 768)): 65 | mask = np.bitwise_or.reduce( 66 | [self._rle2mask(x[1], shape=shape) for x in grouped_rle] 67 | ) 68 | return mask 69 | 70 | def __len__(self): 71 | return self.length 72 | 73 | def __getitem__(self, idx): 74 | if torch.is_tensor(idx): 75 | idx = idx.tolist() 76 | 77 | ret = {} 78 | scene = self.scenes[idx] 79 | img = np.array(Image.open(self.root_dir / scene[0]), dtype=np.float32) 80 | img = np.transpose(img, (2, 0, 1)) 81 | label = self._merge_masks(scene[1]) 82 | 83 | label = np.expand_dims(label, axis=0) 84 | 85 | if self.transform: 86 | img = self.transform(img) 87 | label = transforms.ToTensor()(label) 88 | 89 | # TODO you must crop the label too! 
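        # the commented-out block below sketches one way to do that: draw the random
        # crop parameters once, then apply F.crop with the same params to both the
        # image and the label so they stay spatially aligned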
90 | # i, j, h, w = transforms.RandomCrop.get_params(img, output_size=(128, 128)) 91 | # img = F.crop(img, i, j, h, w) 92 | # label = F.crop(label, i, j, h, w) 93 | 94 | if self.classification: 95 | return img, 0 96 | 97 | return img, label 98 | -------------------------------------------------------------------------------- /mae/dataloaders/airound.py: -------------------------------------------------------------------------------- 1 | # ~1100 files from AiRound were skipped since they were not 224x224x3 2 | 3 | class AIROUND_DATASET_STATS: 4 | PIXEL_MEANS = [0.4015, 0.4149, 0.3817] # [0.485, 0.456, 0.406] # imagenet 5 | PIXEL_STD = [0.2340, 0.2096, 0.2124] # [0.229, 0.224, 0.225] # imagenet 6 | -------------------------------------------------------------------------------- /mae/dataloaders/cvbrct.py: -------------------------------------------------------------------------------- 1 | class CVBRCT_DATASET_STATS: 2 | # Calculated from only 3000 files out of 24,619 3 | # Divided by 255. 4 | # Alpha channel values elided (255/0) 5 | PIXEL_MEANS = [0.45239977, 0.44331489, 0.43006366] 6 | PIXEL_STD = [0.23977479, 0.21991943, 0.20902827] 7 | -------------------------------------------------------------------------------- /mae/dataloaders/eurosat.py: -------------------------------------------------------------------------------- 1 | class EUROSAT_DATASET_STATS: 2 | # Divided by 255. 3 | PIXEL_MEANS = [ 4 | 0.34436897, 5 | 0.38029233, 6 | 0.40777751, 7 | ] # [0.485, 0.456, 0.406] # imagenet 8 | PIXEL_STD = [ 9 | 0.20368513, 10 | 0.13663637, 11 | 0.11484352, 12 | ] # [0.229, 0.224, 0.225] # imagenet 13 | -------------------------------------------------------------------------------- /mae/dataloaders/fmow.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchvision.datasets import ImageFolder 4 | 5 | from .imagelist import ImageList 6 | 7 | 8 | class FMOW_DATASET_STATS: 9 | PIXEL_MEANS = [0.485, 0.456, 0.406] 10 | PIXEL_STD = [0.229, 0.224, 0.225] 11 | 12 | 13 | def is_fmow_rgb(fname: str) -> bool: 14 | return fname.endswith("_rgb.jpg") 15 | 16 | 17 | def build_fmow(data_root, transforms): 18 | if os.path.isdir(data_root): 19 | return ImageFolder( 20 | root=data_root, transform=transforms, is_valid_file=is_fmow_rgb 21 | ) 22 | return ImageList(data_root, transforms) 23 | -------------------------------------------------------------------------------- /mae/dataloaders/geo.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import glob 3 | import os 4 | import re 5 | import sys 6 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, cast 7 | 8 | import numpy as np 9 | import rasterio 10 | import rasterio.merge 11 | import torch 12 | from rasterio.crs import CRS 13 | from rasterio.io import DatasetReader 14 | from rasterio.vrt import WarpedVRT 15 | from rasterio.windows import from_bounds 16 | from torch import Tensor 17 | from torchgeo.datasets.geo import GeoDataset 18 | from torchgeo.datasets.utils import BoundingBox, disambiguate_timestamp 19 | 20 | from . import merge 21 | 22 | 23 | class CustomRasterDataset(GeoDataset): 24 | """Abstract base class for :class:`GeoDataset` stored as raster files.""" 25 | 26 | #: Glob expression used to search for files. 27 | #: 28 | #: This expression should be specific enough that it will not pick up files from 29 | #: other datasets. 
It should not include a file extension, as the dataset may be in 30 | #: a different file format than what it was originally downloaded as. 31 | filename_glob = "*" 32 | 33 | #: Regular expression used to extract date from filename. 34 | #: 35 | #: The expression should use named groups. The expression may contain any number of 36 | #: groups. The following groups are specifically searched for by the base class: 37 | #: 38 | #: * ``date``: used to calculate ``mint`` and ``maxt`` for ``index`` insertion 39 | #: 40 | #: When :attr:`separate_files`` is True, the following additional groups are 41 | #: searched for to find other files: 42 | #: 43 | #: * ``band``: replaced with requested band name 44 | filename_regex = ".*" 45 | 46 | #: Date format string used to parse date from filename. 47 | #: 48 | #: Not used if :attr:`filename_regex` does not contain a ``date`` group. 49 | date_format = "%Y%m%d" 50 | 51 | #: True if dataset contains imagery, False if dataset contains mask 52 | is_image = True 53 | 54 | #: True if data is stored in a separate file for each band, else False. 55 | separate_files = False 56 | 57 | #: Names of all available bands in the dataset 58 | all_bands: List[str] = [] 59 | 60 | #: Names of RGB bands in the dataset, used for plotting 61 | rgb_bands: List[str] = [] 62 | 63 | #: Color map for the dataset, used for plotting 64 | cmap: Dict[int, Tuple[int, int, int, int]] = {} 65 | 66 | def __init__( 67 | self, 68 | root: str = "data", 69 | crs: Optional[CRS] = None, 70 | res: Optional[float] = None, 71 | bands: Optional[Sequence[str]] = None, 72 | transforms: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None, 73 | cache: bool = True, 74 | ) -> None: 75 | """Initialize a new Dataset instance. 76 | Args: 77 | root: root directory where dataset can be found 78 | crs: :term:`coordinate reference system (CRS)` to warp to 79 | (defaults to the CRS of the first file found) 80 | res: resolution of the dataset in units of CRS 81 | (defaults to the resolution of the first file found) 82 | bands: bands to return (defaults to all bands) 83 | transforms: a function/transform that takes an input sample 84 | and returns a transformed version 85 | cache: if True, cache file handle to speed up repeated sampling 86 | Raises: 87 | FileNotFoundError: if no files are found in ``root`` 88 | """ 89 | super().__init__(transforms) 90 | 91 | self.root = root 92 | self.cache = cache 93 | self.nodata = None 94 | 95 | # Populate the dataset index 96 | i = 0 97 | pathname = os.path.join(root, "**", self.filename_glob) 98 | filename_regex = re.compile(self.filename_regex, re.VERBOSE) 99 | for filepath in glob.iglob(pathname, recursive=True): 100 | match = re.match(filename_regex, os.path.basename(filepath)) 101 | if match is not None: 102 | try: 103 | with rasterio.open(filepath) as src: 104 | # See if file has a color map 105 | if len(self.cmap) == 0: 106 | try: 107 | self.cmap = src.colormap(1) 108 | except ValueError: 109 | pass 110 | 111 | if crs is None: 112 | crs = src.crs 113 | if res is None: 114 | res = src.res[0] 115 | if self.nodata == None: 116 | self.nodata = src.nodata 117 | 118 | with WarpedVRT(src, crs=crs) as vrt: 119 | minx, miny, maxx, maxy = vrt.bounds 120 | except rasterio.errors.RasterioIOError: 121 | # Skip files that rasterio is unable to read 122 | continue 123 | else: 124 | mint: float = 0 125 | maxt: float = sys.maxsize 126 | if "date" in match.groupdict(): 127 | date = match.group("date") 128 | mint, maxt = disambiguate_timestamp(date, self.date_format) 129 | 130 | 
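                    # bounds follow torchgeo's (minx, maxx, miny, maxy, mint, maxt)
                    # ordering and are inserted into the R-tree index so spatial
                    # queries can locate this file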
coords = (minx, maxx, miny, maxy, mint, maxt) 131 | self.index.insert(i, coords, filepath) 132 | i += 1 133 | 134 | if i == 0: 135 | raise FileNotFoundError( 136 | f"No {self.__class__.__name__} data was found in '{root}'" 137 | ) 138 | 139 | if bands and self.all_bands: 140 | band_indexes = [self.all_bands.index(i) + 1 for i in bands] 141 | self.bands = bands 142 | assert len(band_indexes) == len(self.bands) 143 | elif bands: 144 | msg = ( 145 | f"{self.__class__.__name__} is missing an `all_bands` attribute," 146 | " so `bands` cannot be specified." 147 | ) 148 | raise AssertionError(msg) 149 | else: 150 | band_indexes = None 151 | self.bands = self.all_bands 152 | 153 | self.band_indexes = band_indexes 154 | self._crs = cast(CRS, crs) 155 | self.res = cast(float, res) 156 | 157 | def __getitem__(self, query: BoundingBox) -> Dict[str, Any]: 158 | """Retrieve image/mask and metadata indexed by query. 159 | Args: 160 | query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index 161 | Returns: 162 | sample of image/mask and metadata at that index 163 | Raises: 164 | IndexError: if query is not found in the index 165 | """ 166 | hits = self.index.intersection(tuple(query), objects=True) 167 | filepaths = [hit.object for hit in hits] 168 | 169 | if not filepaths: 170 | raise IndexError( 171 | f"query: {query} not found in index with bounds: {self.bounds}" 172 | ) 173 | 174 | if self.separate_files: 175 | data_list: List[Tensor] = [] 176 | valid_mask_list: List[Tensor] = [] 177 | filename_regex = re.compile(self.filename_regex, re.VERBOSE) 178 | for band in self.bands: 179 | band_filepaths = [] 180 | for filepath in filepaths: 181 | filename = os.path.basename(filepath) 182 | directory = os.path.dirname(filepath) 183 | match = re.match(filename_regex, filename) 184 | if match: 185 | if "date" in match.groupdict(): 186 | start = match.start("band") 187 | end = match.end("band") 188 | filename = filename[:start] + band + filename[end:] 189 | filepath = glob.glob(os.path.join(directory, filename))[0] 190 | band_filepaths.append(filepath) 191 | data_list_internal, valid_mask_list_internal = self._merge_files( 192 | band_filepaths, query 193 | ) 194 | data_list.append(data_list_internal) 195 | valid_mask_list.append(valid_mask_list_internal) 196 | data = torch.cat(data_list) 197 | valid_mask = torch.cat(valid_mask_list) 198 | else: 199 | data, valid_mask = self._merge_files(filepaths, query, self.band_indexes) 200 | 201 | key = "image" if self.is_image else "mask" 202 | sample = {key: data, "crs": self.crs, "bbox": query, "validmask": valid_mask} 203 | 204 | if self.transforms is not None: 205 | sample = self.transforms(sample) 206 | 207 | return sample 208 | 209 | def _merge_files( 210 | self, 211 | filepaths: Sequence[str], 212 | query: BoundingBox, 213 | band_indexes: Optional[Sequence[int]] = None, 214 | ) -> Tensor: 215 | """Load and merge one or more files. 
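        A single intersecting file is read directly through its (possibly cached,
        possibly warped) handle; multiple intersecting files are mosaicked with
        merge.merge(). The data is returned together with a boolean valid-pixel mask.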
216 | Args: 217 | filepaths: one or more files to load and merge 218 | query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index 219 | band_indexes: indexes of bands to be used 220 | Returns: 221 | image/mask at that index 222 | """ 223 | if self.cache: 224 | vrt_fhs = [self._cached_load_warp_file(fp) for fp in filepaths] 225 | else: 226 | vrt_fhs = [self._load_warp_file(fp) for fp in filepaths] 227 | 228 | bounds = (query.minx, query.miny, query.maxx, query.maxy) 229 | if len(vrt_fhs) == 1: 230 | src = vrt_fhs[0] 231 | out_width = round((query.maxx - query.minx) / self.res) 232 | out_height = round((query.maxy - query.miny) / self.res) 233 | count = len(band_indexes) if band_indexes else src.count 234 | out_shape = (count, out_height, out_width) 235 | 236 | window = from_bounds(*bounds, src.transform) 237 | dest = src.read(indexes=band_indexes, out_shape=out_shape, window=window) 238 | valid_msk = dest != src.nodata 239 | # valid_msk = src.read_masks( 240 | # indexes=band_indexes, 241 | # out_shape=out_shape, 242 | # window=window, 243 | # ) 244 | else: 245 | dest, _, valid_msk = merge.merge( 246 | vrt_fhs, bounds, self.res, indexes=band_indexes 247 | ) 248 | 249 | # fix numpy dtypes which are not supported by pytorch tensors 250 | if dest.dtype == np.uint16: 251 | dest = dest.astype(np.int32) 252 | elif dest.dtype == np.uint32: 253 | dest = dest.astype(np.int64) 254 | 255 | if valid_msk.dtype == np.uint16: 256 | valid_msk = valid_msk.astype(np.int32) 257 | elif valid_msk.dtype == np.uint32: 258 | valid_msk = valid_msk.astype(np.int64) 259 | 260 | tensor = torch.tensor(dest) 261 | msk_tensor = torch.tensor(valid_msk) 262 | 263 | assert tensor.shape == msk_tensor.shape 264 | 265 | return tensor, msk_tensor 266 | 267 | @functools.lru_cache(maxsize=128) 268 | def _cached_load_warp_file(self, filepath: str) -> DatasetReader: 269 | """Cached version of :meth:`_load_warp_file`. 270 | Args: 271 | filepath: file to load and warp 272 | Returns: 273 | file handle of warped VRT 274 | """ 275 | return self._load_warp_file(filepath) 276 | 277 | def _load_warp_file(self, filepath: str) -> DatasetReader: 278 | """Load and warp a file to the correct CRS and resolution. 
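        Files whose CRS already matches the dataset CRS are returned as plain
        rasterio datasets; anything else is wrapped in a WarpedVRT.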
279 | Args: 280 | filepath: file to load and warp 281 | Returns: 282 | file handle of warped VRT 283 | """ 284 | src = rasterio.open(filepath) 285 | 286 | # Only warp if necessary 287 | if src.crs != self.crs: 288 | vrt = WarpedVRT(src, crs=self.crs) 289 | src.close() 290 | return vrt 291 | else: 292 | return src 293 | -------------------------------------------------------------------------------- /mae/dataloaders/imagelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Any, Callable, Dict, List, Optional, Tuple, cast 3 | 4 | from torchvision.datasets import VisionDataset 5 | 6 | 7 | class ImageList(VisionDataset): 8 | """A generic data loader for a list of images in a text file""" 9 | 10 | def __init__( 11 | self, 12 | imglist_path: str, 13 | transform: Optional[Callable] = None, 14 | target_transform: Optional[Callable] = None, 15 | ) -> None: 16 | root = os.path.dirname(imglist_path) 17 | with open(imglist_path) as imglistr: 18 | self.imglist = [line.strip() for line in imglistr] 19 | super().__init__(root, transform=transform, target_transform=target_transform) 20 | 21 | classes, class_to_idx = self.find_classes(self.imglist) 22 | samples = [ 23 | (os.path.join(root, fn), class_to_idx[self.filename_to_class(fn)]) 24 | for fn in self.imglist 25 | ] 26 | print(classes) 27 | self.classes = classes 28 | self.class_to_idx = class_to_idx 29 | self.samples = samples 30 | self.targets = [s[1] for s in samples] 31 | 32 | def filename_to_class(self, fn: str) -> str: 33 | # hardcoded HACK that could break 34 | return os.path.dirname(fn).split("/")[1] 35 | 36 | def find_classes(self, filenames: List[str]) -> Tuple[List[str], Dict[str, int]]: 37 | """ """ 38 | classes = sorted(list({self.filename_to_class(fn) for fn in filenames})) 39 | if len(classes) == 0: 40 | raise FileNotFoundError(f"Couldn't find any classes in filenames.") 41 | 42 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 43 | return classes, class_to_idx 44 | 45 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 46 | """ 47 | Args: 48 | index (int): Index 49 | 50 | Returns: 51 | tuple: (sample, target) where target is class_index of the target class. 
52 | """ 53 | path, target = self.samples[index] 54 | sample = pil_loader(path) 55 | if self.transform is not None: 56 | sample = self.transform(sample) 57 | if self.target_transform is not None: 58 | target = self.target_transform(target) 59 | 60 | return sample, target 61 | 62 | def __len__(self) -> int: 63 | return len(self.samples) 64 | 65 | 66 | from PIL import Image 67 | 68 | 69 | def pil_loader(path: str) -> Image.Image: 70 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 71 | with open(path, "rb") as f: 72 | img = Image.open(f) 73 | return img.convert("RGB") 74 | -------------------------------------------------------------------------------- /mae/dataloaders/imagenet100.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.utils.data._utils.collate import default_collate 5 | from torchvision import datasets 6 | 7 | 8 | class ImageNet100Dataset(datasets.ImageFolder): 9 | def __init__(self, path, anno_file, transform) -> None: 10 | super().__init__(path, transform=transform) 11 | self.imgs = self.samples 12 | with open(anno_file) as f: 13 | files_100 = f.readlines() 14 | # breakpoint() 15 | files_100 = [x.replace("\n", "") for x in files_100] 16 | new_samples = [] 17 | for x, y in self.samples: 18 | if any([t in x for t in files_100]): 19 | new_samples.append((x, y)) 20 | self.samples = new_samples 21 | 22 | 23 | def build_imagenet_sampler(config, num_replicas, rank, transforms): 24 | img_dir = config["data"]["img_dir"] 25 | dataset = ImageNet100Dataset( 26 | os.path.join(img_dir, "train"), anno_file="anno_100.txt", transform=transforms 27 | ) 28 | sampler = torch.utils.data.DistributedSampler( 29 | dataset, num_replicas=num_replicas, rank=rank, shuffle=True 30 | ) 31 | collate_fn = default_collate 32 | return dataset, sampler, collate_fn 33 | -------------------------------------------------------------------------------- /mae/dataloaders/merge.py: -------------------------------------------------------------------------------- 1 | """Copy valid pixels from input files to an output file.""" 2 | 3 | import logging 4 | import math 5 | import os 6 | import warnings 7 | from contextlib import contextmanager 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import rasterio 12 | from rasterio import windows 13 | from rasterio.coords import disjoint_bounds 14 | from rasterio.enums import Resampling 15 | from rasterio.errors import RasterioDeprecationWarning 16 | from rasterio.transform import Affine 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def copy_first(merged_data, new_data, merged_mask, new_mask, **kwargs): 22 | """Returns the first available pixel.""" 23 | mask = np.empty_like(merged_mask, dtype="bool") 24 | np.logical_not(new_mask, out=mask) 25 | np.logical_and(merged_mask, mask, out=mask) 26 | np.copyto(merged_data, new_data, where=mask, casting="unsafe") 27 | 28 | 29 | def copy_last(merged_data, new_data, merged_mask, new_mask, **kwargs): 30 | """Returns the last available pixel.""" 31 | mask = np.empty_like(merged_mask, dtype="bool") 32 | np.logical_not(new_mask, out=mask) 33 | np.copyto(merged_data, new_data, where=mask, casting="unsafe") 34 | 35 | 36 | def copy_min(merged_data, new_data, merged_mask, new_mask, **kwargs): 37 | """Returns the minimum value pixel.""" 38 | mask = np.empty_like(merged_mask, dtype="bool") 39 | np.logical_or(merged_mask, new_mask, out=mask) 40 | np.logical_not(mask, out=mask) 41 | 
np.minimum(merged_data, new_data, out=merged_data, where=mask, casting="unsafe") 42 | np.logical_not(new_mask, out=mask) 43 | np.logical_and(merged_mask, mask, out=mask) 44 | np.copyto(merged_data, new_data, where=mask, casting="unsafe") 45 | 46 | 47 | def copy_max(merged_data, new_data, merged_mask, new_mask, **kwargs): 48 | """Returns the maximum value pixel.""" 49 | mask = np.empty_like(merged_mask, dtype="bool") 50 | np.logical_or(merged_mask, new_mask, out=mask) 51 | np.logical_not(mask, out=mask) 52 | np.maximum(merged_data, new_data, out=merged_data, where=mask, casting="unsafe") 53 | np.logical_not(new_mask, out=mask) 54 | np.logical_and(merged_mask, mask, out=mask) 55 | np.copyto(merged_data, new_data, where=mask, casting="unsafe") 56 | 57 | 58 | def copy_sum(merged_data, new_data, merged_mask, new_mask, **kwargs): 59 | """Returns the sum of all pixel values.""" 60 | mask = np.empty_like(merged_mask, dtype="bool") 61 | np.logical_or(merged_mask, new_mask, out=mask) 62 | np.logical_not(mask, out=mask) 63 | np.add(merged_data, new_data, out=merged_data, where=mask, casting="unsafe") 64 | np.logical_not(new_mask, out=mask) 65 | np.logical_and(merged_mask, mask, out=mask) 66 | np.copyto(merged_data, new_data, where=mask, casting="unsafe") 67 | 68 | 69 | def copy_count(merged_data, new_data, merged_mask, new_mask, **kwargs): 70 | """Returns the count of valid pixels.""" 71 | mask = np.empty_like(merged_mask, dtype="bool") 72 | np.logical_or(merged_mask, new_mask, out=mask) 73 | np.logical_not(mask, out=mask) 74 | np.add(merged_data, mask, out=merged_data, where=mask, casting="unsafe") 75 | np.logical_not(new_mask, out=mask) 76 | np.logical_and(merged_mask, mask, out=mask) 77 | np.copyto(merged_data, mask, where=mask, casting="unsafe") 78 | 79 | 80 | MERGE_METHODS = { 81 | "first": copy_first, 82 | "last": copy_last, 83 | "min": copy_min, 84 | "max": copy_max, 85 | "sum": copy_sum, 86 | "count": copy_count, 87 | } 88 | 89 | 90 | def merge( 91 | datasets, 92 | bounds=None, 93 | res=None, 94 | nodata=None, 95 | dtype=None, 96 | precision=None, 97 | indexes=None, 98 | output_count=None, 99 | resampling=Resampling.nearest, 100 | method="first", 101 | target_aligned_pixels=False, 102 | dst_path=None, 103 | dst_kwds=None, 104 | ): 105 | """Copy valid pixels from input files to an output file. 106 | 107 | All files must have the same number of bands, data type, and 108 | coordinate reference system. 109 | 110 | Input files are merged in their listed order using the reverse 111 | painter's algorithm (default) or another method. If the output file exists, 112 | its values will be overwritten by input values. 113 | 114 | Geospatial bounds and resolution of a new output file in the 115 | units of the input file coordinate reference system may be provided 116 | and are otherwise taken from the first input file. 117 | 118 | Parameters 119 | ---------- 120 | datasets : list of dataset objects opened in 'r' mode, filenames or PathLike objects 121 | source datasets to be merged. 122 | bounds: tuple, optional 123 | Bounds of the output image (left, bottom, right, top). 124 | If not set, bounds are determined from bounds of input rasters. 125 | res: tuple, optional 126 | Output resolution in units of coordinate reference system. If not set, 127 | the resolution of the first raster is used. If a single value is passed, 128 | output pixels will be square. 129 | nodata: float, optional 130 | nodata value to use in output file. If not set, uses the nodata value 131 | in the first input raster. 
132 | dtype: numpy.dtype or string 133 | dtype to use in outputfile. If not set, uses the dtype value in the 134 | first input raster. 135 | precision: int, optional 136 | This parameters is unused, deprecated in rasterio 1.3.0, and 137 | will be removed in version 2.0.0. 138 | indexes : list of ints or a single int, optional 139 | bands to read and merge 140 | output_count: int, optional 141 | If using callable it may be useful to have additional bands in the output 142 | in addition to the indexes specified for read 143 | resampling : Resampling, optional 144 | Resampling algorithm used when reading input files. 145 | Default: `Resampling.nearest`. 146 | method : str or callable 147 | pre-defined method: 148 | first: reverse painting 149 | last: paint valid new on top of existing 150 | min: pixel-wise min of existing and new 151 | max: pixel-wise max of existing and new 152 | or custom callable with signature: 153 | merged_data : array_like 154 | array to update with new_data 155 | new_data : array_like 156 | data to merge 157 | same shape as merged_data 158 | merged_mask, new_mask : array_like 159 | boolean masks where merged/new data pixels are invalid 160 | same shape as merged_data 161 | index: int 162 | index of the current dataset within the merged dataset collection 163 | roff: int 164 | row offset in base array 165 | coff: int 166 | column offset in base array 167 | 168 | target_aligned_pixels : bool, optional 169 | Whether to adjust output image bounds so that pixel coordinates 170 | are integer multiples of pixel size, matching the ``-tap`` 171 | options of GDAL utilities. Default: False. 172 | dst_path : str or PathLike, optional 173 | Path of output dataset 174 | dst_kwds : dict, optional 175 | Dictionary of creation options and other paramters that will be 176 | overlaid on the profile of the output dataset. 177 | 178 | Returns 179 | ------- 180 | tuple 181 | 182 | Two elements: 183 | 184 | dest: numpy.ndarray 185 | Contents of all input rasters in single array 186 | 187 | out_transform: affine.Affine() 188 | Information for mapping pixel coordinates in `dest` to another 189 | coordinate system 190 | 191 | """ 192 | if precision is not None: 193 | warnings.warn( 194 | "The precision parameter is unused, deprecated, and will be removed in 2.0.0.", 195 | RasterioDeprecationWarning, 196 | ) 197 | 198 | if method in MERGE_METHODS: 199 | copyto = MERGE_METHODS[method] 200 | elif callable(method): 201 | copyto = method 202 | else: 203 | raise ValueError( 204 | f"Unknown method {method}, must be one of {list(MERGE_METHODS.keys())} or callable" 205 | ) 206 | 207 | # Create a dataset_opener object to use in several places in this function. 
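    # paths are opened with rasterio.open, while datasets that are already open are
    # wrapped in a pass-through context manager so both cases work inside `with`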
208 | if isinstance(datasets[0], (str, os.PathLike)): 209 | dataset_opener = rasterio.open 210 | else: 211 | 212 | @contextmanager 213 | def nullcontext(obj): 214 | try: 215 | yield obj 216 | finally: 217 | pass 218 | 219 | dataset_opener = nullcontext 220 | 221 | with dataset_opener(datasets[0]) as first: 222 | first_profile = first.profile 223 | first_res = first.res 224 | nodataval = first.nodatavals[0] 225 | dt = first.dtypes[0] 226 | 227 | if indexes is None: 228 | src_count = first.count 229 | elif isinstance(indexes, int): 230 | src_count = indexes 231 | else: 232 | src_count = len(indexes) 233 | 234 | try: 235 | first_colormap = first.colormap(1) 236 | except ValueError: 237 | first_colormap = None 238 | 239 | if not output_count: 240 | output_count = src_count 241 | 242 | # Extent from option or extent of all inputs 243 | if bounds: 244 | dst_w, dst_s, dst_e, dst_n = bounds 245 | else: 246 | # scan input files 247 | xs = [] 248 | ys = [] 249 | for dataset in datasets: 250 | with dataset_opener(dataset) as src: 251 | left, bottom, right, top = src.bounds 252 | xs.extend([left, right]) 253 | ys.extend([bottom, top]) 254 | dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys) 255 | 256 | # Resolution/pixel size 257 | if not res: 258 | res = first_res 259 | elif not np.iterable(res): 260 | res = (res, res) 261 | elif len(res) == 1: 262 | res = (res[0], res[0]) 263 | 264 | if target_aligned_pixels: 265 | dst_w = math.floor(dst_w / res[0]) * res[0] 266 | dst_e = math.ceil(dst_e / res[0]) * res[0] 267 | dst_s = math.floor(dst_s / res[1]) * res[1] 268 | dst_n = math.ceil(dst_n / res[1]) * res[1] 269 | 270 | # Compute output array shape. We guarantee it will cover the output 271 | # bounds completely 272 | output_width = int(round((dst_e - dst_w) / res[0])) 273 | output_height = int(round((dst_n - dst_s) / res[1])) 274 | 275 | output_transform = Affine.translation(dst_w, dst_n) * Affine.scale(res[0], -res[1]) 276 | 277 | if dtype is not None: 278 | dt = dtype 279 | logger.debug("Set dtype: %s", dt) 280 | 281 | out_profile = first_profile 282 | out_profile.update(**(dst_kwds or {})) 283 | 284 | out_profile["transform"] = output_transform 285 | out_profile["height"] = output_height 286 | out_profile["width"] = output_width 287 | out_profile["count"] = output_count 288 | out_profile["dtype"] = dt 289 | if nodata is not None: 290 | out_profile["nodata"] = nodata 291 | 292 | # create destination array 293 | dest = np.zeros((output_count, output_height, output_width), dtype=dt) 294 | 295 | if nodata is not None: 296 | nodataval = nodata 297 | logger.debug("Set nodataval: %r", nodataval) 298 | 299 | if nodataval is not None: 300 | # Only fill if the nodataval is within dtype's range 301 | inrange = False 302 | if np.issubdtype(dt, np.integer): 303 | info = np.iinfo(dt) 304 | inrange = info.min <= nodataval <= info.max 305 | elif np.issubdtype(dt, np.floating): 306 | if math.isnan(nodataval): 307 | inrange = True 308 | else: 309 | info = np.finfo(dt) 310 | inrange = info.min <= nodataval <= info.max 311 | if inrange: 312 | dest.fill(nodataval) 313 | else: 314 | warnings.warn( 315 | "The nodata value, %s, is beyond the valid " 316 | "range of the chosen data type, %s. Consider overriding it " 317 | "using the --nodata option for better results." % (nodataval, dt) 318 | ) 319 | else: 320 | nodataval = 0 321 | 322 | for idx, dataset in enumerate(datasets): 323 | with dataset_opener(dataset) as src: 324 | # Real World (tm) use of boundless reads. 
325 | # This approach uses the maximum amount of memory to solve the 326 | # problem. Making it more efficient is a TODO. 327 | 328 | if disjoint_bounds((dst_w, dst_s, dst_e, dst_n), src.bounds): 329 | logger.debug("Skipping source: src=%r, window=%r", src) 330 | continue 331 | 332 | # 1. Compute spatial intersection of destination and source 333 | src_w, src_s, src_e, src_n = src.bounds 334 | 335 | int_w = src_w if src_w > dst_w else dst_w 336 | int_s = src_s if src_s > dst_s else dst_s 337 | int_e = src_e if src_e < dst_e else dst_e 338 | int_n = src_n if src_n < dst_n else dst_n 339 | 340 | # 2. Compute the source window 341 | src_window = windows.from_bounds(int_w, int_s, int_e, int_n, src.transform) 342 | 343 | # 3. Compute the destination window 344 | dst_window = windows.from_bounds( 345 | int_w, int_s, int_e, int_n, output_transform 346 | ) 347 | 348 | # 4. Read data in source window into temp 349 | src_window_rnd_shp = src_window.round_lengths() 350 | dst_window_rnd_shp = dst_window.round_lengths() 351 | dst_window_rnd_off = dst_window_rnd_shp.round_offsets() 352 | 353 | temp_height, temp_width = ( 354 | dst_window_rnd_off.height, 355 | dst_window_rnd_off.width, 356 | ) 357 | temp_shape = (src_count, temp_height, temp_width) 358 | 359 | temp_src = src.read( 360 | out_shape=temp_shape, 361 | window=src_window_rnd_shp, 362 | boundless=False, 363 | masked=True, 364 | indexes=indexes, 365 | resampling=resampling, 366 | ) 367 | 368 | # 5. Copy elements of temp into dest 369 | roff, coff = ( 370 | max(0, dst_window_rnd_off.row_off), 371 | max(0, dst_window_rnd_off.col_off), 372 | ) 373 | region = dest[:, roff : roff + temp_height, coff : coff + temp_width] 374 | 375 | # region_mask = true where there is NODATA 376 | if math.isnan(nodataval): 377 | region_mask = np.isnan(region) 378 | elif np.issubdtype(region.dtype, np.floating): 379 | region_mask = np.isclose(region, nodataval) 380 | else: 381 | region_mask = region == nodataval 382 | 383 | # Ensure common shape, resolving issue #2202. 384 | temp = temp_src[:, : region.shape[1], : region.shape[2]] 385 | temp_mask = np.ma.getmask(temp) 386 | copyto(region, temp, region_mask, temp_mask, index=idx, roff=roff, coff=coff) 387 | 388 | valid_mask = dest != nodataval 389 | 390 | if dst_path is None: 391 | return dest, output_transform, valid_mask 392 | 393 | else: 394 | with rasterio.open(dst_path, "w", **out_profile) as dst: 395 | dst.write(dest) 396 | if first_colormap: 397 | dst.write_colormap(1, first_colormap) 398 | -------------------------------------------------------------------------------- /mae/dataloaders/mlrsnet.py: -------------------------------------------------------------------------------- 1 | # 224 "beach" files were 600x600x3 instead of 224x224x3. 2 | 3 | 4 | class MLRSNET_DATASET_STATS: 5 | PIXEL_MEANS = [0.4002, 0.4116, 0.3874] 6 | PIXEL_STD = [0.2107, 0.1915, 0.1940] 7 | -------------------------------------------------------------------------------- /mae/dataloaders/naip.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | from lib.transforms import get_inputs_outputs 4 | from matplotlib import pyplot as plt 5 | from samplers.distributed import DistributedRandomGeoSampler 6 | from torchgeo.datasets import stack_samples 7 | from torchgeo.samplers import Units 8 | 9 | from .geo import CustomRasterDataset 10 | 11 | 12 | class NAIP(CustomRasterDataset): 13 | """National Agriculture Imagery Program (NAIP) dataset. 
14 | 15 | The `National Agriculture Imagery Program (NAIP) 16 | `_ 17 | acquires aerial imagery during the agricultural growing seasons in the continental 18 | U.S. A primary goal of the NAIP program is to make digital ortho photography 19 | available to governmental agencies and the public within a year of acquisition. 20 | 21 | NAIP is administered by the USDA's Farm Service Agency (FSA) through the Aerial 22 | Photography Field Office in Salt Lake City. This "leaf-on" imagery is used as a base 23 | layer for GIS programs in FSA's County Service Centers, and is used to maintain the 24 | Common Land Unit (CLU) boundaries. 25 | 26 | If you use this dataset in your research, please cite it using the following format: 27 | 28 | * https://www.fisheries.noaa.gov/inport/item/49508/citation 29 | """ 30 | 31 | # https://www.nrcs.usda.gov/Internet/FSE_DOCUMENTS/nrcs141p2_015644.pdf 32 | # https://planetarycomputer.microsoft.com/dataset/naip#Storage-Documentation 33 | filename_glob = "m_*.*" 34 | filename_regex = r""" 35 | ^m 36 | _(?P\d+) 37 | _(?P[a-z]+) 38 | _(?P\d+) 39 | _(?P\d+) 40 | _(?P\d+) 41 | (?:_(?P\d+))? 42 | \..*$ 43 | """ 44 | 45 | # Plotting 46 | all_bands = ["R", "G", "B", "NIR"] 47 | rgb_bands = ["R", "G", "B"] 48 | 49 | def plot( 50 | self, 51 | sample: Dict[str, Any], 52 | show_titles: bool = True, 53 | suptitle: Optional[str] = None, 54 | ) -> plt.Figure: 55 | """Plot a sample from the dataset. 56 | 57 | Args: 58 | sample: a sample returned by :meth:`RasterDataset.__getitem__` 59 | show_titles: flag indicating whether to show titles above each panel 60 | suptitle: optional string to use as a suptitle 61 | 62 | Returns: 63 | a matplotlib Figure with the rendered sample 64 | 65 | .. versionchanged:: 0.3 66 | Method now takes a sample dict, not a Tensor. Additionally, possible to 67 | show subplot titles and/or use a custom suptitle. 68 | """ 69 | image = sample["image"][0:3, :, :].permute(1, 2, 0) 70 | 71 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4)) 72 | 73 | ax.imshow(image) 74 | ax.axis("off") 75 | if show_titles: 76 | ax.set_title("Image") 77 | 78 | if suptitle is not None: 79 | plt.suptitle(suptitle) 80 | 81 | return fig 82 | 83 | 84 | from .geo import CustomRasterDataset 85 | 86 | 87 | class NAIP(CustomRasterDataset): 88 | """National Agriculture Imagery Program (NAIP) dataset. 89 | 90 | The `National Agriculture Imagery Program (NAIP) 91 | `_ 92 | acquires aerial imagery during the agricultural growing seasons in the continental 93 | U.S. A primary goal of the NAIP program is to make digital ortho photography 94 | available to governmental agencies and the public within a year of acquisition. 95 | 96 | NAIP is administered by the USDA's Farm Service Agency (FSA) through the Aerial 97 | Photography Field Office in Salt Lake City. This "leaf-on" imagery is used as a base 98 | layer for GIS programs in FSA's County Service Centers, and is used to maintain the 99 | Common Land Unit (CLU) boundaries. 100 | 101 | If you use this dataset in your research, please cite it using the following format: 102 | 103 | * https://www.fisheries.noaa.gov/inport/item/49508/citation 104 | """ 105 | 106 | # https://www.nrcs.usda.gov/Internet/FSE_DOCUMENTS/nrcs141p2_015644.pdf 107 | # https://planetarycomputer.microsoft.com/dataset/naip#Storage-Documentation 108 | filename_glob = "m_*.*" 109 | filename_regex = r""" 110 | ^m 111 | _(?P\d+) 112 | _(?P[a-z]+) 113 | _(?P\d+) 114 | _(?P\d+) 115 | _(?P\d+) 116 | (?:_(?P\d+))? 
117 | \..*$ 118 | """ 119 | 120 | # Plotting 121 | all_bands = ["R", "G", "B", "NIR"] 122 | rgb_bands = ["R", "G", "B"] 123 | 124 | def plot( 125 | self, 126 | sample: Dict[str, Any], 127 | show_titles: bool = True, 128 | suptitle: Optional[str] = None, 129 | ) -> plt.Figure: 130 | """Plot a sample from the dataset. 131 | 132 | Args: 133 | sample: a sample returned by :meth:`RasterDataset.__getitem__` 134 | show_titles: flag indicating whether to show titles above each panel 135 | suptitle: optional string to use as a suptitle 136 | 137 | Returns: 138 | a matplotlib Figure with the rendered sample 139 | 140 | .. versionchanged:: 0.3 141 | Method now takes a sample dict, not a Tensor. Additionally, possible to 142 | show subplot titles and/or use a custom suptitle. 143 | """ 144 | image = sample["image"][0:3, :, :].permute(1, 2, 0) 145 | 146 | fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(4, 4)) 147 | 148 | ax.imshow(image) 149 | ax.axis("off") 150 | if show_titles: 151 | ax.set_title("Image") 152 | 153 | if suptitle is not None: 154 | plt.suptitle(suptitle) 155 | 156 | return fig 157 | 158 | 159 | class NAIPStackSampleCollateFn: 160 | def __init__(self, transforms, base_resolution=1.0): 161 | self.transforms = transforms 162 | self.base_resolution = base_resolution 163 | 164 | def __call__(self, samples): 165 | targets = stack_samples(samples)["image"][:, :3, :, :] 166 | valid_masks = stack_samples(samples)["validmask"][:, :3, :, :] 167 | targets = targets.float() / 255 168 | if self.transforms is not None: 169 | targets, imgs_src, ratios, zero_ratio, valid_masks = self.transforms( 170 | targets, valid_masks 171 | ) 172 | targets_res = ratios * self.base_resolution 173 | imgs_src_res = targets_res * (targets.shape[-1] / imgs_src.shape[-1]) 174 | return get_inputs_outputs(imgs_src, imgs_src_res, targets, targets_res), dict( 175 | zero_ratio=zero_ratio, valid_masks=valid_masks 176 | ) 177 | 178 | 179 | def build_naip_sampler(config, args, num_replicas, rank, transforms): 180 | config = config["data"] 181 | naip = NAIP(config["img_dir"]) 182 | 183 | # To support multiple output sizes per batch, TorchGeo should crop to the largest possible target size first 184 | sampler = DistributedRandomGeoSampler( 185 | naip, 186 | size=1024, 187 | length=config["length"], 188 | units=Units.PIXELS, 189 | num_replicas=num_replicas, 190 | rank=rank, 191 | ) 192 | collate_fn = NAIPStackSampleCollateFn(transforms, args.base_resolution) 193 | return naip, sampler, collate_fn 194 | -------------------------------------------------------------------------------- /mae/dataloaders/optimal.py: -------------------------------------------------------------------------------- 1 | class OPTIMAL_DATASET_STATS: 2 | # Divided by 255 3 | PIXEL_MEANS = [ 4 | 0.3688422134901961, 5 | 0.3807842425882353, 6 | 0.3377770719607843, 7 | ] # [0.485, 0.456, 0.406] # imagenet 8 | PIXEL_STD = [ 9 | 0.20127227078431373, 10 | 0.1805812771764706, 11 | 0.1775864696862745, 12 | ] # [0.229, 0.224, 0.225] # imagenet 13 | -------------------------------------------------------------------------------- /mae/dataloaders/resic45.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from torchvision import transforms 4 | from torchvision.datasets import ImageFolder 5 | 6 | from .imagelist import ImageList 7 | 8 | 9 | class RESIC_DATASET_STATS: 10 | PIXEL_MEANS = [0.368, 0.381, 0.3436] 11 | PIXEL_STD = [0.2035, 0.1854, 0.1849] 12 | 13 | 14 | def build_resic(data_root, transforms): 15 | # 
backwards compatable -- pass in a folder or a list of images 16 | # this hardcoding isn't great 17 | if os.path.isdir(data_root): 18 | return ImageFolder(data_root, transform=transforms) 19 | return ImageList(data_root, transforms) 20 | 21 | 22 | def build_resic_gsd_resample(input_size=224, gsd_scale=1.0): 23 | if gsd_scale == 1.0: 24 | return [] 25 | 26 | resample = transforms.Resize(size=(input_size * gsd_scale, input_size * gsd_scale)) 27 | original = transforms.Resize(size=(input_size, input_size)) 28 | 29 | return (resample, original) 30 | -------------------------------------------------------------------------------- /mae/dataloaders/sentinel2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from lib.transforms import get_inputs_outputs 3 | from samplers.distributed import DistributedRandomGeoSampler 4 | from torchgeo.datasets import RasterDataset, stack_samples 5 | from torchgeo.samplers import Units 6 | 7 | from .geo import CustomRasterDataset 8 | 9 | 10 | class Sentinel2StackSampleCollateFn: 11 | def __init__(self, transforms, over_sample_factor=1.0, base_resolution=1.0): 12 | self.transforms = transforms 13 | self.over_sample_factor = over_sample_factor 14 | self.base_resolution = base_resolution 15 | 16 | def __call__(self, samples): 17 | imgs = stack_samples(samples)["image"][:, :3, :, :] 18 | valid_masks = stack_samples(samples)["validmask"] 19 | b, c, h, w = imgs.shape 20 | tgt_b = int(b / self.over_sample_factor) 21 | zero_ratio = (valid_masks == 0).sum((1, 2, 3)) / (h * w * c) 22 | zero_ratio_order = torch.argsort(zero_ratio, descending=False) 23 | imgs = imgs[zero_ratio_order][:tgt_b].contiguous() 24 | valid_masks = valid_masks[zero_ratio_order][:tgt_b].contiguous() 25 | assert imgs.shape == (tgt_b, c, h, w) 26 | imgs = imgs.float() # / 255 27 | if self.transforms is not None: 28 | imgs, imgs_src, ratios, zero_ratio, valid_masks = self.transforms( 29 | imgs, valid_masks 30 | ) # ratio is crop_dim / original_dim, so resolution should be 1/ ratios 31 | res = ratios * self.base_resolution 32 | imgs_src_res = res * (imgs.shape[-1] / imgs_src.shape[-1]) 33 | return get_inputs_outputs(imgs_src, imgs_src_res, imgs, res), dict( 34 | zero_ratio=zero_ratio, valid_masks=valid_masks 35 | ) 36 | 37 | 38 | class Sentinel2(CustomRasterDataset): 39 | filename_glob = "T*_B02_10m.tif" 40 | filename_regex = r"^.{6}_(?P\d{8}T\d{6})_(?PB0[\d])" 41 | date_format = "%Y%m%dT%H%M%S" 42 | is_image = True 43 | separate_files = True 44 | all_bands = ["B02", "B03", "B04"] 45 | rgb_bands = ["B04", "B03", "B02"] 46 | 47 | 48 | def build_sentinel_sampler(config, args, num_replicas, rank, transforms): 49 | config = config["data"] 50 | dataset = Sentinel2(config["img_dir"]) 51 | over_sample_factor = config["oversample"] 52 | sampler = DistributedRandomGeoSampler( 53 | dataset, 54 | size=config["size"], 55 | length=int(config["length"] * over_sample_factor), 56 | units=Units.PIXELS, 57 | num_replicas=num_replicas, 58 | rank=rank, 59 | ) 60 | collate_fn = Sentinel2StackSampleCollateFn( 61 | transforms, over_sample_factor, args.base_resolution 62 | ) 63 | return dataset, sampler, collate_fn 64 | -------------------------------------------------------------------------------- /mae/dataloaders/ucmerced.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | class UCMERCED_DATASET_STATS: 4 | # Divided by 255. 
5 | # Alpha channel values elided (255/0) 6 | PIXEL_MEANS = [0.4842, 0.4901, 0.4505] 7 | PIXEL_STD = [0.2180, 0.2021, 0.1958] 8 | -------------------------------------------------------------------------------- /mae/dataloaders/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import util.misc as misc 5 | from torchvision import datasets, transforms 6 | from torchvision.datasets import ImageFolder 7 | 8 | from .airound import AIROUND_DATASET_STATS 9 | from .cvbrct import CVBRCT_DATASET_STATS 10 | from .eurosat import EUROSAT_DATASET_STATS 11 | from .fmow import FMOW_DATASET_STATS, build_fmow 12 | from .imagelist import ImageList 13 | from .imagenet100 import build_imagenet_sampler 14 | from .mlrsnet import MLRSNET_DATASET_STATS 15 | from .naip import build_naip_sampler 16 | from .optimal import OPTIMAL_DATASET_STATS 17 | from .resic45 import RESIC_DATASET_STATS, build_resic 18 | from .sentinel2 import build_sentinel_sampler 19 | from .ucmerced import UCMERCED_DATASET_STATS 20 | from .whurs import WHURS_DATASET_STATS 21 | from .xview import build_xview2_sampler 22 | 23 | dataset_stats_lookup = { 24 | "airound": AIROUND_DATASET_STATS, 25 | "cvbrct": CVBRCT_DATASET_STATS, 26 | "mlrsnet": MLRSNET_DATASET_STATS, 27 | "resisc": RESIC_DATASET_STATS, 28 | "eurosat": EUROSAT_DATASET_STATS, 29 | "optimal-31": OPTIMAL_DATASET_STATS, 30 | "whu-rs19": WHURS_DATASET_STATS, 31 | "ucmerced": UCMERCED_DATASET_STATS, 32 | "fmow": FMOW_DATASET_STATS, 33 | } 34 | 35 | 36 | def get_dataset_and_sampler( 37 | args, 38 | config, 39 | split="train", 40 | num_replicas=None, 41 | rank=None, 42 | transforms=None, 43 | transforms_init=None, 44 | linprobe_finetune=False, 45 | ): 46 | dataset_type = config["data"]["type"] 47 | if dataset_type == "NAIP": 48 | return build_naip_sampler(config, args, num_replicas, rank, transforms) 49 | elif dataset_type == "SENTINEL2": 50 | return build_sentinel_sampler(config, args, num_replicas, rank, transforms) 51 | elif dataset_type == "XView2": 52 | return build_xview2_sampler( 53 | config=config, 54 | num_replicas=num_replicas, 55 | rank=rank, 56 | transforms=transforms, 57 | split=split, 58 | ) 59 | elif dataset_type == "ImageNet": 60 | return build_imagenet_sampler( 61 | config=config, num_replicas=num_replicas, rank=rank, transforms=transforms 62 | ) 63 | elif dataset_type in ["fmow"]: 64 | dataset = datasets.ImageFolder( 65 | root=config["data"]["img_dir"], 66 | transform=transforms_init, 67 | is_valid_file=is_fmow_rgb, 68 | ) 69 | sampler_train = torch.utils.data.DistributedSampler( 70 | dataset, num_replicas=num_replicas, rank=rank, shuffle=True 71 | ) 72 | 73 | if not linprobe_finetune: 74 | return ( 75 | dataset, 76 | sampler_train, 77 | TransformCollateFn(transforms, args.base_resolution), 78 | ) 79 | else: 80 | return ( 81 | dataset, 82 | sampler_train, 83 | TransformCollateFnLabel(transforms, args.base_resolution), 84 | ) 85 | elif dataset_type == "resisc": 86 | dataset = build_resic(config["data"]["img_dir"], transforms=transforms_init) 87 | sampler_train = torch.utils.data.DistributedSampler( 88 | dataset, num_replicas=num_replicas, rank=rank, shuffle=True 89 | ) 90 | if not linprobe_finetune: 91 | return ( 92 | dataset, 93 | sampler_train, 94 | TransformCollateFn(transforms, args.base_resolution), 95 | ) 96 | else: 97 | return ( 98 | dataset, 99 | sampler_train, 100 | TransformCollateFnLabel(transforms, args.base_resolution), 101 | ) 102 | elif dataset_type == "eurosat": 103 | dataset = 
datasets.ImageFolder( 104 | root=config["data"]["img_dir"], transform=transforms_init 105 | ) 106 | sampler_train = torch.utils.data.DistributedSampler( 107 | dataset, num_replicas=num_replicas, rank=rank, shuffle=True 108 | ) 109 | 110 | if not linprobe_finetune: 111 | return ( 112 | dataset, 113 | sampler_train, 114 | TransformCollateFn(transforms, args.base_resolution), 115 | ) 116 | else: 117 | return ( 118 | dataset, 119 | sampler_train, 120 | TransformCollateFnLabel(transforms, args.base_resolution), 121 | ) 122 | else: 123 | raise NotImplementedError 124 | 125 | 126 | def is_fmow_rgb(fname: str) -> bool: 127 | return fname.endswith("_rgb.jpg") 128 | 129 | 130 | class TransformCollateFn: 131 | def __init__(self, transforms, base_resolution=1.0): 132 | self.transforms = transforms 133 | self.base_resolution = base_resolution 134 | 135 | def __call__(self, samples): 136 | imgs = torch.stack(list(zip(*samples))[0]) 137 | imgs, imgs_src, ratios, _, _ = self.transforms(imgs) 138 | res = ratios * self.base_resolution 139 | imgs_src_res = res * (imgs.shape[-1] / imgs_src.shape[-1]) 140 | return (imgs_src, imgs_src_res, imgs, res), None 141 | 142 | 143 | class TransformCollateFnLabel: 144 | def __init__(self, transforms, base_resolution=1.0): 145 | self.transforms = transforms 146 | self.base_resolution = base_resolution 147 | 148 | def __call__(self, samples): 149 | imgs = torch.stack(list(zip(*samples))[0]) 150 | labels = torch.tensor([x[1] for x in samples]) 151 | imgs, imgs_src, ratios, _, _ = self.transforms(imgs) 152 | res = ratios * self.base_resolution 153 | imgs_src_res = res * (imgs.shape[-1] / imgs_src.shape[-1]) 154 | return (imgs_src, imgs_src_res, imgs, res, labels), None 155 | 156 | 157 | def get_eval_dataset_and_transform( 158 | eval_dataset_id="resisc", 159 | eval_dataset_path="~/data/resisc", 160 | transforms_init=None, 161 | args=None, 162 | ): 163 | # All of these datasets are ImageFolders 164 | if eval_dataset_id in [ 165 | "resisc", 166 | "mlrsnet", 167 | "airound", 168 | "cvbrct", 169 | "eurosat", 170 | "optimal-31", 171 | "whu-rs19", 172 | "ucmerced", 173 | ]: 174 | ds_stats = dataset_stats_lookup[eval_dataset_id] 175 | transform_normalize = transforms.Normalize( 176 | mean=ds_stats.PIXEL_MEANS, std=ds_stats.PIXEL_STD 177 | ) 178 | use_transforms = [transforms.ToTensor(), transform_normalize] 179 | if transforms_init: 180 | use_transforms.insert(0, transforms_init) 181 | if eval_dataset_id == 'ucmerced': 182 | use_transforms.insert(0, transforms.Resize((256,256))) 183 | transform_eval = transforms.Compose(use_transforms) 184 | 185 | if os.path.isdir(eval_dataset_path): 186 | dataset_eval = ImageFolder(eval_dataset_path, transform=transform_eval) 187 | else: 188 | dataset_eval = ImageList(eval_dataset_path, transform=transform_eval) 189 | 190 | elif eval_dataset_id == "fmow": 191 | ds_stats = dataset_stats_lookup[eval_dataset_id] 192 | if transforms_init and args: 193 | transform_eval = transforms.Compose( 194 | [ 195 | # Resize only the short side 196 | transforms.Resize(args.eval_scale), 197 | # TODO this may not be the right thing to do here. 198 | transforms.CenterCrop(args.eval_scale), 199 | transforms.ToTensor(), 200 | transforms.Normalize( 201 | mean=ds_stats.PIXEL_MEANS, std=ds_stats.PIXEL_STD 202 | ), 203 | ] 204 | ) 205 | else: 206 | transform_eval = transforms.Compose( 207 | [ 208 | # TODO remove hardcoding px size? 
209 | transforms.Resize(512), # downsample short side to 512 210 | transforms.CenterCrop(512), 211 | transforms.ToTensor(), 212 | transforms.Normalize( 213 | mean=ds_stats.PIXEL_MEANS, std=ds_stats.PIXEL_STD 214 | ), 215 | ] 216 | ) 217 | dataset_eval = build_fmow(eval_dataset_path, transforms=transform_eval) 218 | 219 | else: 220 | raise NotImplementedError 221 | 222 | return dataset_eval, transform_eval 223 | -------------------------------------------------------------------------------- /mae/dataloaders/whurs.py: -------------------------------------------------------------------------------- 1 | class WHURS_DATASET_STATS: 2 | # 4 files that were not 600x600 deleted from dataset 3 | # Divided by 255. 4 | PIXEL_MEANS = [0.42647115, 0.44830369, 0.40256118] 5 | PIXEL_STD = [0.24203978, 0.21878441, 0.22890343] 6 | -------------------------------------------------------------------------------- /mae/dataloaders/xview.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. All rights reserved. 2 | # Licensed under the MIT License. 3 | 4 | """xView2 dataset.""" 5 | 6 | import glob 7 | import os 8 | from typing import Callable, Dict, List, Optional 9 | 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | from torch import Tensor 15 | from torch.utils.data.distributed import DistributedSampler 16 | from torchgeo.datasets import stack_samples 17 | from torchgeo.datasets.geo import NonGeoDataset 18 | from torchgeo.datasets.utils import ( 19 | check_integrity, 20 | draw_semantic_segmentation_masks, 21 | extract_archive, 22 | ) 23 | 24 | 25 | class XView2StackSampleCollateFn: 26 | def __init__(self): 27 | pass 28 | 29 | def __call__(self, samples): 30 | imgs = stack_samples(samples)["image"] 31 | return imgs, None 32 | 33 | 34 | class XView2(NonGeoDataset): 35 | """xView2 dataset. 36 | 37 | The `xView2 `__ 38 | dataset is a dataset for building disaster change detection. This dataset object 39 | uses the "Challenge training set (~7.8 GB)" and "Challenge test set (~2.6 GB)" data 40 | from the xView2 website as the train and test splits. Note, the xView2 website 41 | contains other data under the xView2 umbrella that are _not_ included here. E.g. 42 | the "Tier3 training data", the "Challenge holdout set", and the "full data". 43 | 44 | Dataset format: 45 | 46 | * images are three-channel pngs 47 | * masks are single-channel pngs where the pixel values represent the class 48 | 49 | Dataset classes: 50 | 51 | 0. background 52 | 1. no damage 53 | 2. minor damage 54 | 3. major damage 55 | 4. destroyed 56 | 57 | If you use this dataset in your research, please cite the following paper: 58 | 59 | * https://arxiv.org/abs/1911.09296 60 | 61 | .. 
versionadded:: 0.2 62 | """ 63 | 64 | metadata = { 65 | "train": { 66 | "filename": "train_images_labels_targets.tar.gz", 67 | "md5": "a20ebbfb7eb3452785b63ad02ffd1e16", 68 | "directory": "train", 69 | }, 70 | "test": { 71 | "filename": "test_images_labels_targets.tar.gz", 72 | "md5": "1b39c47e05d1319c17cc8763cee6fe0c", 73 | "directory": "test", 74 | }, 75 | } 76 | classes = ["background", "no-damage", "minor-damage", "major-damage", "destroyed"] 77 | colormap = ["green", "blue", "orange", "red"] 78 | 79 | def __init__( 80 | self, 81 | root: str = "data", 82 | split: str = "train", 83 | transforms: Optional[Callable[[Dict[str, Tensor]], Dict[str, Tensor]]] = None, 84 | checksum: bool = False, 85 | ) -> None: 86 | """Initialize a new xView2 dataset instance. 87 | 88 | Args: 89 | root: root directory where dataset can be found 90 | split: one of "train" or "test" 91 | transforms: a function/transform that takes input sample and its target as 92 | entry and returns a transformed version 93 | checksum: if True, check the MD5 of the downloaded files (may be slow) 94 | """ 95 | assert split in self.metadata 96 | self.root = root 97 | self.split = split 98 | self.transforms = transforms 99 | self.checksum = checksum 100 | 101 | self._verify() 102 | 103 | self.class2idx = {c: i for i, c in enumerate(self.classes)} 104 | self.files = self._load_files(root, split) 105 | 106 | def __getitem__(self, index: int) -> Dict[str, Tensor]: 107 | """Return an index within the dataset. 108 | 109 | Args: 110 | index: index to return 111 | 112 | Returns: 113 | data and label at that index 114 | """ 115 | files = self.files[index] 116 | pre = self._load_image(files["pre"]) 117 | post = self._load_image(files["post"]) 118 | mask_pre = self._load_target(files["mask_pre"]) 119 | mask_post = self._load_target(files["mask_post"]) 120 | 121 | image = torch.stack(tensors=[pre, post], dim=0) 122 | mask = torch.stack(tensors=[mask_pre, mask_post], dim=0) 123 | sample = {"image": image, "mask": mask} 124 | 125 | if self.transforms is not None: 126 | sample = self.transforms(sample) 127 | 128 | return sample 129 | 130 | def __len__(self) -> int: 131 | """Return the number of data points in the dataset. 132 | 133 | Returns: 134 | length of the dataset 135 | """ 136 | return len(self.files) 137 | 138 | def _load_files(self, root: str, split: str) -> List[Dict[str, str]]: 139 | """Return the paths of the files in the dataset. 140 | 141 | Args: 142 | root: root dir of dataset 143 | split: subset of dataset, one of [train, test] 144 | 145 | Returns: 146 | list of dicts containing paths for each pair of images and masks 147 | """ 148 | files = [] 149 | directory = self.metadata[split]["directory"] 150 | image_root = os.path.join(root, directory, "images") 151 | mask_root = os.path.join(root, directory, "targets") 152 | images = glob.glob(os.path.join(image_root, "*.png")) 153 | basenames = [os.path.basename(f) for f in images] 154 | basenames = ["_".join(f.split("_")[:-2]) for f in basenames] 155 | for name in set(basenames): 156 | pre = os.path.join(image_root, f"{name}_pre_disaster.png") 157 | post = os.path.join(image_root, f"{name}_post_disaster.png") 158 | mask_pre = os.path.join(mask_root, f"{name}_pre_disaster_target.png") 159 | mask_post = os.path.join(mask_root, f"{name}_post_disaster_target.png") 160 | files.append( 161 | dict(pre=pre, post=post, mask_pre=mask_pre, mask_post=mask_post) 162 | ) 163 | return files 164 | 165 | def _load_image(self, path: str) -> Tensor: 166 | """Load a single image. 
167 | 168 | Args: 169 | path: path to the image 170 | 171 | Returns: 172 | the image of shape C X H X W 173 | """ 174 | filename = os.path.join(path) 175 | with Image.open(filename) as img: 176 | array: "np.typing.NDArray[np.int_]" = np.array(img.convert("RGB")) 177 | tensor = torch.from_numpy(array) 178 | # Convert from HxWxC to CxHxW 179 | tensor = tensor.permute((2, 0, 1)) 180 | return tensor 181 | 182 | def _load_target(self, path: str) -> Tensor: 183 | """Load the target mask for a single image. 184 | 185 | Args: 186 | path: path to the image 187 | 188 | Returns: 189 | the target mask 190 | """ 191 | filename = os.path.join(path) 192 | with Image.open(filename) as img: 193 | array: "np.typing.NDArray[np.int_]" = np.array(img.convert("L")) 194 | tensor = torch.from_numpy(array) 195 | tensor = tensor.to(torch.long) 196 | return tensor 197 | 198 | def _verify(self) -> None: 199 | """Verify the integrity of the dataset. 200 | 201 | Raises: 202 | RuntimeError: if checksum fails or the dataset is not downloaded 203 | """ 204 | # Check if the files already exist 205 | exists = [] 206 | for split_info in self.metadata.values(): 207 | for directory in ["images", "targets"]: 208 | exists.append( 209 | os.path.exists( 210 | os.path.join(self.root, split_info["directory"], directory) 211 | ) 212 | ) 213 | 214 | if all(exists): 215 | return 216 | 217 | # Check if .tar.gz files already exists (if so then extract) 218 | exists = [] 219 | for split_info in self.metadata.values(): 220 | filepath = os.path.join(self.root, split_info["filename"]) 221 | if os.path.isfile(filepath): 222 | if self.checksum and not check_integrity(filepath, split_info["md5"]): 223 | raise RuntimeError("Dataset found, but corrupted.") 224 | exists.append(True) 225 | extract_archive(filepath) 226 | else: 227 | exists.append(False) 228 | 229 | if all(exists): 230 | return 231 | 232 | # Check if the user requested to download the dataset 233 | raise RuntimeError( 234 | "Dataset not found in `root` directory, either specify a different" 235 | + " `root` directory or manually download the dataset to this directory." 236 | ) 237 | 238 | def plot( 239 | self, 240 | sample: Dict[str, Tensor], 241 | show_titles: bool = True, 242 | suptitle: Optional[str] = None, 243 | alpha: float = 0.5, 244 | ) -> plt.Figure: 245 | """Plot a sample from the dataset. 
246 | 247 | Args: 248 | sample: a sample returned by :meth:`__getitem__` 249 | show_titles: flag indicating whether to show titles above each panel 250 | suptitle: optional string to use as a suptitle 251 | alpha: opacity with which to render predictions on top of the imagery 252 | 253 | Returns: 254 | a matplotlib Figure with the rendered sample 255 | """ 256 | ncols = 2 257 | image_pre = draw_semantic_segmentation_masks( 258 | sample["image"][0], sample["mask"][0], alpha=alpha, colors=self.colormap 259 | ) 260 | image_post = draw_semantic_segmentation_masks( 261 | sample["image"][1], sample["mask"][1], alpha=alpha, colors=self.colormap 262 | ) 263 | if "prediction" in sample: # NOTE: this assumes predictions are made for post 264 | ncols += 1 265 | image3 = draw_semantic_segmentation_masks( 266 | sample["image"][1], 267 | sample["prediction"], 268 | alpha=alpha, 269 | colors=self.colormap, 270 | ) 271 | 272 | fig, axs = plt.subplots(ncols=ncols, figsize=(ncols * 10, 10)) 273 | axs[0].imshow(image_pre) 274 | axs[0].axis("off") 275 | axs[1].imshow(image_post) 276 | axs[1].axis("off") 277 | if ncols > 2: 278 | axs[2].imshow(image3) 279 | axs[2].axis("off") 280 | 281 | if show_titles: 282 | axs[0].set_title("Pre disaster") 283 | axs[1].set_title("Post disaster") 284 | if ncols > 2: 285 | axs[2].set_title("Predictions") 286 | 287 | if suptitle is not None: 288 | plt.suptitle(suptitle) 289 | 290 | return fig 291 | 292 | 293 | def build_xview2_sampler(config, num_replicas, rank, transforms, split="train"): 294 | xv_dataset = XView2( 295 | root=config["data"]["img_dir"], split=split, transforms=transforms 296 | ) 297 | sampler = DistributedSampler( 298 | xv_dataset, 299 | num_replicas=num_replicas, 300 | rank=rank, 301 | shuffle=config["sampler"]["shuffle"], 302 | seed=config["sampler"]["seed"], 303 | ) 304 | collate_function = XView2StackSampleCollateFn() 305 | return xv_dataset, sampler, collate_function 306 | -------------------------------------------------------------------------------- /mae/engine_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import math 13 | import sys 14 | from typing import Iterable, Optional 15 | 16 | import torch 17 | import util.lr_sched as lr_sched 18 | import util.misc as misc 19 | from timm.data import Mixup 20 | from timm.utils import accuracy 21 | 22 | 23 | def train_one_epoch( 24 | model: torch.nn.Module, 25 | criterion: torch.nn.Module, 26 | data_loader: Iterable, 27 | optimizer: torch.optim.Optimizer, 28 | device: torch.device, 29 | epoch: int, 30 | loss_scaler, 31 | max_norm: float = 0, 32 | mixup_fn: Optional[Mixup] = None, 33 | log_writer=None, 34 | args=None, 35 | ): 36 | model.train(True) 37 | metric_logger = misc.MetricLogger(delimiter=" ") 38 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 39 | header = f"Epoch: [{epoch}]" 40 | print_freq = 20 41 | 42 | accum_iter = args.accum_iter 43 | 44 | optimizer.zero_grad() 45 | 46 | if log_writer is not None: 47 | print(f"log_dir: {log_writer.log_dir}") 48 | 49 | for data_iter_step, ((samples, res, _, target_res, labels), metadata) in enumerate( 50 | metric_logger.log_every(data_loader, print_freq, header) 51 | ): 52 | 53 | # we use a per iteration (instead of per epoch) lr scheduler 54 | if data_iter_step % accum_iter == 0: 55 | lr_sched.adjust_learning_rate( 56 | optimizer, data_iter_step / len(data_loader) + epoch, args 57 | ) 58 | 59 | samples = samples.to(device, non_blocking=True) 60 | targets = labels.to(device, non_blocking=True) 61 | 62 | if mixup_fn is not None: 63 | samples, targets = mixup_fn(samples, targets) 64 | 65 | with torch.cuda.amp.autocast(): 66 | outputs = model(samples, input_res=res) 67 | loss = criterion(outputs, targets) 68 | 69 | loss_value = loss.item() 70 | 71 | if not math.isfinite(loss_value): 72 | print(f"Loss is {loss_value}, stopping training") 73 | sys.exit(1) 74 | 75 | loss /= accum_iter 76 | loss_scaler( 77 | loss, 78 | optimizer, 79 | clip_grad=max_norm, 80 | parameters=model.parameters(), 81 | create_graph=False, 82 | update_grad=(data_iter_step + 1) % accum_iter == 0, 83 | ) 84 | if (data_iter_step + 1) % accum_iter == 0: 85 | optimizer.zero_grad() 86 | 87 | torch.cuda.synchronize() 88 | 89 | metric_logger.update(loss=loss_value) 90 | min_lr = 10.0 91 | max_lr = 0.0 92 | for group in optimizer.param_groups: 93 | min_lr = min(min_lr, group["lr"]) 94 | max_lr = max(max_lr, group["lr"]) 95 | 96 | metric_logger.update(lr=max_lr) 97 | 98 | loss_value_reduce = misc.all_reduce_mean(loss_value) 99 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 100 | """We use epoch_1000x as the x-axis in tensorboard. 101 | This calibrates different curves when batch size changes. 
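For example, with len(data_loader) = 500 iterations per epoch, step 250 of epoch 3 is logged at x = int((250 / 500 + 3) * 1000) = 3500; the same fraction of an epoch maps to the same x even when a different batch size changes the number of iterations per epoch.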
102 | """ 103 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 104 | log_writer.add_scalar("loss", loss_value_reduce, epoch_1000x) 105 | log_writer.add_scalar("lr", max_lr, epoch_1000x) 106 | 107 | # gather the stats from all processes 108 | metric_logger.synchronize_between_processes() 109 | print("Averaged stats:", metric_logger) 110 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 111 | 112 | 113 | @torch.no_grad() 114 | def evaluate( 115 | data_loader, 116 | model, 117 | device, 118 | eval_base_resolution=1.0, 119 | gsd_embed=False, 120 | eval_scale=512, 121 | reference_size=512, 122 | ): 123 | gsd_ratio = eval_base_resolution 124 | if gsd_embed: 125 | gsd_ratio = gsd_ratio * (reference_size / eval_scale) 126 | 127 | criterion = torch.nn.CrossEntropyLoss() 128 | 129 | metric_logger = misc.MetricLogger(delimiter=" ") 130 | header = "Test:" 131 | 132 | # switch to evaluation mode 133 | model.eval() 134 | 135 | for (samples, labels) in metric_logger.log_every(data_loader, 10, header): 136 | images = samples 137 | target = labels 138 | images = images.to(device, non_blocking=True) 139 | target = target.to(device, non_blocking=True) 140 | 141 | # compute output 142 | with torch.cuda.amp.autocast(): 143 | output = model( 144 | images, 145 | input_res=torch.ones(len(images)).float().to(images.device) * gsd_ratio, 146 | ) 147 | loss = criterion(output, target) 148 | 149 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 150 | 151 | batch_size = images.shape[0] 152 | metric_logger.update(loss=loss.item()) 153 | metric_logger.meters["acc1"].update(acc1.item(), n=batch_size) 154 | metric_logger.meters["acc5"].update(acc5.item(), n=batch_size) 155 | # gather the stats from all processes 156 | metric_logger.synchronize_between_processes() 157 | print( 158 | "* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}".format( 159 | top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss 160 | ) 161 | ) 162 | 163 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 164 | -------------------------------------------------------------------------------- /mae/engine_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import math 12 | import sys 13 | from typing import Iterable 14 | 15 | import torch 16 | import util.lr_sched as lr_sched 17 | import util.misc as misc 18 | from wandb_log import wandb_dump_input_output, wandb_log_metadata 19 | 20 | 21 | def train_one_epoch( 22 | model: torch.nn.Module, 23 | data_loader: Iterable, 24 | optimizer: torch.optim.Optimizer, 25 | device: torch.device, 26 | epoch: int, 27 | loss_scaler, 28 | log_writer=None, 29 | args=None, 30 | scheduler=None, 31 | source_size_scheduler=None, 32 | fix_resolution_scheduler=None, 33 | ): 34 | model.train(True) 35 | 36 | metric_logger = misc.MetricLogger(delimiter=" ") 37 | metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}")) 38 | header = f"Epoch: [{epoch}]" 39 | print_freq = args.print_freq 40 | 41 | accum_iter = args.accum_iter 42 | 43 | optimizer.zero_grad() 44 | 45 | if log_writer is not None: 46 | print(f"log_dir: {log_writer.log_dir}") 47 | 48 | for data_iter_step, ((samples, res, targets, target_res), metadata) in enumerate( 49 | metric_logger.log_every(data_loader, print_freq, header) 50 | ): 51 | # we use a per iteration (instead of per epoch) lr scheduler 52 | if data_iter_step % accum_iter == 0: 53 | lr_sched.adjust_learning_rate( 54 | optimizer, data_iter_step / len(data_loader) + epoch, args 55 | ) 56 | samples = samples.to(device, non_blocking=True) 57 | targets = targets.to(device, non_blocking=True) 58 | 59 | with torch.cuda.amp.autocast(): 60 | target_size = scheduler.get_target_size(epoch) 61 | source_size = source_size_scheduler.get_target_size(epoch)[0] 62 | fix_decoding_size = fix_resolution_scheduler.get_target_size(epoch) 63 | model.module.set_target_size(target_size) 64 | model.module.set_fix_decoding_size(fix_decoding_size) 65 | loss, y, mask, mean, var, pos_emb, pos_emb_decoder, samples = model( 66 | samples, 67 | input_res=res, 68 | targets=targets, 69 | target_res=target_res, 70 | mask_ratio=args.mask_ratio, 71 | source_size=source_size, 72 | ) 73 | 74 | if data_iter_step % print_freq == 0: 75 | y = [ 76 | model.module.unpatchify(y_i)[0].permute(1, 2, 0).detach().cpu() 77 | for y_i in y 78 | ] 79 | x = torch.einsum("nchw->nhwc", samples[:1]).detach().cpu() 80 | wandb_dump_input_output( 81 | x[0], 82 | y, 83 | epoch, 84 | f"target-size:{target_size}-output_size:{fix_decoding_size}", 85 | ) 86 | if metadata: 87 | wandb_log_metadata(metadata) 88 | 89 | loss_value = loss.item() 90 | 91 | if not math.isfinite(loss_value): 92 | print(f"Loss is {loss_value}, stopping training") 93 | sys.exit(1) 94 | 95 | loss = loss / accum_iter 96 | loss_scaler( 97 | loss, 98 | optimizer, 99 | parameters=model.parameters(), 100 | update_grad=(data_iter_step + 1) % accum_iter == 0, 101 | ) 102 | if (data_iter_step + 1) % accum_iter == 0: 103 | optimizer.zero_grad() 104 | 105 | torch.cuda.synchronize() 106 | 107 | metric_logger.update(loss=loss_value) 108 | 109 | lr = optimizer.param_groups[0]["lr"] 110 | metric_logger.update(lr=lr) 111 | 112 | loss_value_reduce = misc.all_reduce_mean(loss_value) 113 | if log_writer is not None and (data_iter_step + 1) % accum_iter == 0: 114 | """We use epoch_1000x as the x-axis in tensorboard. 115 | This calibrates different curves when batch size changes. 
116 | """ 117 | epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000) 118 | log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x) 119 | log_writer.add_scalar("lr", lr, epoch_1000x) 120 | 121 | # gather the stats from all processes 122 | metric_logger.synchronize_between_processes() 123 | print("Averaged stats:", metric_logger) 124 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 125 | -------------------------------------------------------------------------------- /mae/eval/knn.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | import time 4 | 5 | import torch 6 | import torch.distributed 7 | import util.misc as misc 8 | import wandb 9 | from torch.distributed import all_reduce 10 | from torch.nn.functional import adaptive_avg_pool2d 11 | from tqdm.cli import tqdm 12 | from util.dist_utils import gather_from_all 13 | 14 | 15 | # utils 16 | @torch.no_grad() 17 | def concat_all_gather(tensor): 18 | """ 19 | Performs all_gather operation on the provided tensors. 20 | *** Warning ***: torch.distributed.all_gather has no gradient. 21 | """ 22 | tensors_gather = [ 23 | torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size()) 24 | ] 25 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 26 | 27 | output = torch.cat(tensors_gather, dim=1) 28 | return output 29 | 30 | 31 | def get_knn_iter(x, gpu): 32 | if gpu == 0: 33 | return tqdm(x) 34 | else: 35 | return x 36 | 37 | 38 | @torch.no_grad() 39 | def kNN( 40 | cmd_args=None, 41 | net=None, 42 | trainloader=None, 43 | testloader=None, 44 | sigma=0.07, 45 | feat_dim=768, 46 | eval_scale=256, 47 | eval_base_resolution=1.0, 48 | gsd_embed=False, 49 | ): 50 | is_dist = misc.is_dist_avail_and_initialized() 51 | net.eval() 52 | print(f"Starting KNN evaluation with K={cmd_args.knn}") 53 | gsd_ratio = eval_base_resolution 54 | if gsd_embed: 55 | gsd_ratio = gsd_ratio * (224 / eval_scale) 56 | 57 | st_time = time.time() 58 | trainFeatures = torch.zeros( 59 | [feat_dim + 1, len(trainloader) * trainloader.batch_size] 60 | ) 61 | if not hasattr(cmd_args, "gpu"): 62 | cmd_args.gpu = None 63 | 64 | if cmd_args.gpu is not None: 65 | trainFeatures = trainFeatures.cuda(cmd_args.gpu) 66 | else: 67 | trainFeatures = trainFeatures.cuda() 68 | 69 | for batch_idx, (inputs, targets) in get_knn_iter( 70 | enumerate(trainloader), cmd_args.gpu 71 | ): 72 | # print mean and std as a debugging sanity check 73 | if batch_idx == 0: 74 | print("Eval data mean (should be near 0):", inputs.mean()) 75 | print("Eval data std (should be near 1):", inputs.std()) 76 | 77 | # targets = targets.cuda(async=True) 78 | batchSize = inputs.size(0) 79 | if cmd_args.gpu is not None: 80 | inputs = inputs.cuda(cmd_args.gpu) 81 | else: 82 | inputs = inputs.cuda() 83 | inputs = torch.nn.functional.interpolate( 84 | inputs, (eval_scale, eval_scale), mode="area" 85 | ) 86 | features = net( 87 | inputs, 88 | input_res=torch.ones(len(inputs)).float().to(inputs.device) * gsd_ratio, 89 | knn_feats=True, 90 | ) 91 | # breakpoint() 92 | trainFeatures[ 93 | :-1, batch_idx * batchSize : batch_idx * batchSize + batchSize 94 | ] = features.T 95 | trainFeatures[ 96 | -1, batch_idx * batchSize : batch_idx * batchSize + batchSize 97 | ] = targets 98 | 99 | if is_dist: 100 | print(f"distributed world size: {torch.distributed.get_world_size()}") 101 | trainFeatures = gather_from_all( 102 | trainFeatures.permute(1, 0).contiguous() 103 | ).permute(1, 0) 104 | 
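# trainFeatures is a (feat_dim + 1) x N matrix filled column-by-column above: rows 0..feat_dim-1 hold each training image's encoder features and the last row holds its integer class label, so features and labels stay aligned through the distributed gather.
# The label row is peeled off below before the feature columns are L2-normalized for the cosine-similarity kNN lookup.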
105 | if not hasattr(cmd_args, "gpu") or cmd_args.gpu is None: 106 | trainLabels = torch.flatten(trainFeatures[-1, :]).cuda() 107 | trainFeatures = trainFeatures[:-1, :].cuda() 108 | else: 109 | trainLabels = torch.flatten(trainFeatures[-1, :]).cuda(cmd_args.gpu) 110 | trainFeatures = trainFeatures[:-1, :].cuda(cmd_args.gpu) 111 | 112 | trainFeatures = torch.nn.functional.normalize(trainFeatures, dim=0) 113 | 114 | print( 115 | f"Grabbing all kNN training features took {(time.time() - st_time): .1f} seconds" 116 | ) 117 | print(f"Shape of final train features {trainFeatures.shape}") 118 | top1 = torch.FloatTensor([0.0]) 119 | total = torch.FloatTensor([0.0]) 120 | if cmd_args.gpu is not None: 121 | top1 = top1.cuda(cmd_args.gpu) 122 | total = total.cuda(cmd_args.gpu) 123 | else: 124 | top1 = top1.cuda() 125 | total = total.cuda() 126 | C = int(trainLabels.max() + 1) 127 | st_time = time.time() 128 | with torch.no_grad(): 129 | retrieval_one_hot = torch.zeros(cmd_args.knn, C).cuda() 130 | for batch_idx, (inputs, targets) in get_knn_iter( 131 | enumerate(testloader), cmd_args.gpu 132 | ): 133 | 134 | # targets = targets.cuda(async=True) 135 | batchSize = inputs.size(0) 136 | if cmd_args.gpu is not None: 137 | inputs = inputs.cuda(cmd_args.gpu) 138 | targets = targets.cuda(cmd_args.gpu) 139 | else: 140 | inputs = inputs.cuda() 141 | targets = targets.cuda() 142 | inputs = torch.nn.functional.interpolate( 143 | inputs, (eval_scale, eval_scale), mode="area" 144 | ) 145 | 146 | features = net( 147 | inputs, 148 | input_res=torch.ones(len(inputs)).float().to(inputs.device) * gsd_ratio, 149 | knn_feats=True, 150 | ) 151 | features = torch.nn.functional.normalize(features, dim=1) 152 | dist = torch.mm(features, trainFeatures) 153 | # if misc.is_main_process(): 154 | # breakpoint() 155 | yd, yi = dist.topk(cmd_args.knn, dim=1, largest=True, sorted=True) 156 | candidates = trainLabels.view(1, -1).expand(batchSize, -1) 157 | retrieval = torch.gather(candidates, 1, yi).long() 158 | 159 | retrieval_one_hot.resize_(batchSize * cmd_args.knn, C).zero_() 160 | retrieval_one_hot.scatter_(1, retrieval.view(-1, 1), 1) 161 | yd_transform = yd.clone().div_(sigma).exp_() 162 | probs = torch.sum( 163 | torch.mul( 164 | retrieval_one_hot.view(batchSize, -1, C), 165 | yd_transform.view(batchSize, -1, 1), 166 | ), 167 | 1, 168 | ) 169 | _, predictions = probs.sort(1, True) 170 | 171 | # Find which predictions match the target 172 | correct = predictions.eq(targets.data.view(-1, 1)) 173 | 174 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 175 | 176 | total += targets.size(0) 177 | 178 | if is_dist: 179 | all_reduce(top1) 180 | all_reduce(total) 181 | top1 = top1.detach().cpu().numpy().item() # sum 182 | total = total.detach().cpu().numpy().item() # sum 183 | 184 | return top1 / total 185 | -------------------------------------------------------------------------------- /mae/exps/encoder-decoder-vanilla.sh: -------------------------------------------------------------------------------- 1 | # rm -rf ./jobs/pretrain/* 2 | export JOB_DIR=./jobs/pretrain 3 | export IMAGENET_DIR=/shared/group/ilsvrc 4 | export CUDA_VISIBLE_DEVICES=5 5 | set -x 6 | CUDA_VISIBLE_DEVICES=3 python -m torch.distributed.launch --nproc_per_node=1 --master_port=11085 main_pretrain.py\ 7 | --batch_size 4 \ 8 | --model mae_vit_base_patch16 \ 9 | --mask_ratio 0.75 \ 10 | --num_workers 0 \ 11 | --epochs 300 \ 12 | --target_size 224\ 13 | --input_size 224\ 14 | --self_attention\ 15 | --scale_min 0.2 \ 16 | --scale_max 1.0 \ 17 | --output_dir 
/home/jacklishufan/exps/output_encoder_decoder_4\ 18 | --log_dir /home/jacklishufan/exps/output_encoder_decoder_4\ 19 | --warmup_epochs 40 \ 20 | --blr 1.5e-4 --weight_decay 0.05 \ 21 | --config config/naip.yaml \ 22 | --decoder_aux_loss_layers 1\ 23 | --decoder_mode encoder\ 24 | --target_size_scheduler constant\ 25 | --decoder_depth 8 \ 26 | --no_autoresume \ 27 | --use_mask_token \ 28 | --loss_masking\ 29 | --skip_knn_eval \ 30 | --fixed_output_size_min 224\ 31 | --fixed_output_size_max 336\ 32 | --eval_train_fnames /shared/jacklishufan/resisc45/train.txt\ 33 | --eval_val_fnames /shared/jacklishufan/resisc45/val.txt \ 34 | --independent_fcn_head \ 35 | --absolute_scale \ 36 | $@ \ 37 | 38 | # --resume /shared/jacklishufan/mae/mae_visualize_vit_base.pth \ 39 | # --restart \ 40 | # --use_mask_token \ 41 | # CUDA_VISIBLE_DEVICES=5 python -m torch.distributed.launch --nproc_per_node=1 main_pretrain.py\ 42 | # --job_dir ${JOB_DIR} \ 43 | # --nodes 1 \ 44 | # --ngpus 1 \ 45 | # --batch_size 4 \ 46 | # --model mae_vit_base_patch16 \ 47 | # --norm_pix_loss \ 48 | # --mask_ratio 0.75 \ 49 | # --epochs 100 \ 50 | # --warmup_epochs 40 \ 51 | # --blr 1.5e-4 --weight_decay 0.05 \ 52 | # --data_path ${IMAGENET_DIR} 53 | 54 | -------------------------------------------------------------------------------- /mae/helpers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bair-climate-initiative/scale-mae/89280d830037ff27c20459cdab03e01e633e29bb/mae/helpers/__init__.py -------------------------------------------------------------------------------- /mae/lib/fpn.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class Norm2d(nn.Module): 5 | def __init__(self, embed_dim): 6 | super().__init__() 7 | self.ln = nn.LayerNorm(embed_dim, eps=1e-6) 8 | 9 | def forward(self, x): 10 | x = x.permute(0, 2, 3, 1).contiguous() 11 | x = self.ln(x) 12 | x = x.permute(0, 3, 1, 2).contiguous() 13 | return x 14 | 15 | 16 | class FPNHead(nn.Module): 17 | def __init__(self, embed_dim, share_weights=False) -> None: 18 | super().__init__() 19 | self.share_weights = share_weights 20 | if self.share_weights: 21 | self.fpn1 = nn.Sequential( 22 | Norm2d(embed_dim), 23 | nn.GELU(), 24 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 25 | ) 26 | self.do_fpn1 = lambda x: self.fpn1(self.fpn2(x)) 27 | else: 28 | self.fpn1 = nn.Sequential( 29 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 30 | Norm2d(embed_dim), 31 | nn.GELU(), 32 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2), 33 | ) 34 | 35 | self.fpn2 = nn.Sequential( 36 | nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=2, stride=2) 37 | ) 38 | 39 | # self.fpn3 = nn.Identity() 40 | 41 | # self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2) 42 | 43 | def forward(self, x): 44 | """ 45 | InputL B X C X H X W 46 | """ 47 | features = [] 48 | if self.share_weights: 49 | ops = [ 50 | self.do_fpn1, 51 | self.fpn2, 52 | # self.fpn3, self.fpn4 53 | ] 54 | else: 55 | ops = [ 56 | self.fpn1, 57 | self.fpn2, 58 | # self.fpn3, self.fpn4 59 | ] 60 | for i in range(len(ops)): 61 | features.append(ops[i](x)) 62 | 63 | return tuple(features) 64 | 65 | 66 | class HFFB(nn.Module): 67 | def __init__(self, hidden_dim) -> None: 68 | super().__init__() 69 | self.convs = nn.Sequential( 70 | nn.GELU(), 71 | nn.Conv2d( 72 | hidden_dim, hidden_dim // 2, 3, padding=1, groups=hidden_dim // 2 73 | ), 74 | 
nn.GELU(), 75 | nn.Conv2d(hidden_dim // 2, hidden_dim, 1, padding=0), 76 | ) 77 | self.residual = nn.Conv2d(hidden_dim, hidden_dim, 1) 78 | 79 | def forward(self, x): 80 | return self.convs(x) + self.residual(x) 81 | 82 | 83 | class FCNHead(nn.Module): 84 | def __init__(self, embed_dim, hidden_dim, num_layers, target_dim) -> None: 85 | super().__init__() 86 | self.proj = nn.Conv2d(embed_dim, hidden_dim, 1) 87 | convs = [] 88 | for _ in range(num_layers): 89 | convs.append(HFFB(hidden_dim)) 90 | self.conv_blocks = nn.Sequential(*convs) 91 | self.pred = nn.Sequential( 92 | Norm2d(hidden_dim), 93 | nn.ConvTranspose2d(hidden_dim, hidden_dim // 2, kernel_size=4, stride=4), 94 | nn.GELU(), 95 | nn.Conv2d( 96 | hidden_dim // 2, hidden_dim // 4, 3, padding=1, groups=hidden_dim // 4 97 | ), 98 | nn.GELU(), 99 | nn.Conv2d(hidden_dim // 4, hidden_dim // 2, 1, padding=0), 100 | nn.GELU(), 101 | nn.ConvTranspose2d(hidden_dim // 2, 3, kernel_size=2, stride=2), 102 | ) 103 | 104 | def forward(self, xp): 105 | """ 106 | InputL List[B X C X H X W], FPN features 107 | """ 108 | out = [] 109 | for x in xp: 110 | x = self.proj(x) 111 | out.append(self.pred(self.conv_blocks(x))) 112 | 113 | return out 114 | -------------------------------------------------------------------------------- /mae/lib/gpt.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from timm.models.vision_transformer import DropPath, Mlp 3 | 4 | 5 | class Attention(nn.Module): 6 | def __init__( 7 | self, 8 | dim, 9 | num_heads=8, 10 | qkv_bias=False, 11 | qk_scale=None, 12 | attn_drop=0.0, 13 | proj_drop=0.0, 14 | ): 15 | super().__init__() 16 | self.num_heads = num_heads 17 | head_dim = dim // num_heads 18 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 19 | self.scale = qk_scale or head_dim**-0.5 20 | 21 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 22 | self.attn_drop = nn.Dropout(attn_drop) 23 | self.proj = nn.Linear(dim, dim) 24 | self.proj_drop = nn.Dropout(proj_drop) 25 | 26 | def forward(self, x, mask=None): 27 | B, N, C = x.shape 28 | qkv = ( 29 | self.qkv(x) 30 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 31 | .permute(2, 0, 3, 1, 4) 32 | ) 33 | q, k, v = ( 34 | qkv[0], 35 | qkv[1], 36 | qkv[2], 37 | ) # make torchscript happy (cannot use tensor as tuple) 38 | 39 | attn = (q @ k.transpose(-2, -1)) * self.scale 40 | if mask is not None: 41 | attn += mask 42 | attn = attn.softmax(dim=-1) 43 | attn = self.attn_drop(attn) 44 | 45 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 46 | x = self.proj(x) 47 | x = self.proj_drop(x) 48 | return x 49 | 50 | 51 | class Block(nn.Module): 52 | def __init__( 53 | self, 54 | dim, 55 | num_heads, 56 | mlp_ratio=4.0, 57 | qkv_bias=False, 58 | qk_scale=None, 59 | drop=0.0, 60 | attn_drop=0.0, 61 | drop_path=0.0, 62 | act_layer=nn.GELU, 63 | norm_layer=nn.LayerNorm, 64 | ): 65 | super().__init__() 66 | self.norm1 = norm_layer(dim) 67 | self.attn = Attention( 68 | dim, 69 | num_heads=num_heads, 70 | qkv_bias=qkv_bias, 71 | qk_scale=qk_scale, 72 | attn_drop=attn_drop, 73 | proj_drop=drop, 74 | ) 75 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 76 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 77 | self.norm2 = norm_layer(dim) 78 | mlp_hidden_dim = int(dim * mlp_ratio) 79 | self.mlp = Mlp( 80 | in_features=dim, 81 | hidden_features=mlp_hidden_dim, 82 | act_layer=act_layer, 83 | drop=drop, 84 | ) 
)
85 | 86 | def forward(self, x, mask=None): 87 | x = x + self.drop_path(self.attn(self.norm1(x), mask)) 88 | x = x + self.drop_path(self.mlp(self.norm2(x))) 89 | return x 90 | -------------------------------------------------------------------------------- /mae/lib/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ResolutionScheduler: 5 | def __init__(self, *args, **kwargs): 6 | pass 7 | 8 | def get_target_size(self, epoch): 9 | raise NotImplemented 10 | 11 | 12 | class ConstantResolutionScheduler(ResolutionScheduler): 13 | def __init__(self, target_size): 14 | self.target_size = target_size 15 | 16 | def get_target_size(self, epoch): 17 | return self.target_size 18 | 19 | 20 | class RandomResolutionScheduler(ResolutionScheduler): 21 | def __init__(self, target_size, n=1): 22 | self.target_size = target_size 23 | self.n = n 24 | 25 | def get_target_size(self, epoch): 26 | return sorted(np.random.choice(self.target_size, self.n).tolist(), reverse=True) 27 | -------------------------------------------------------------------------------- /mae/lib/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Copy-paste from DETR repo 4 | ------------- 5 | DETR Transformer class. 6 | Copy-paste from torch.nn.Transformer with modifications: 7 | * positional encodings are passed in MHattention 8 | * extra LN at the end of encoder is removed 9 | * decoder returns a stack of activations from all decoding layers 10 | """ 11 | import copy 12 | from typing import List, Optional 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import Tensor, nn 17 | 18 | 19 | class Transformer(nn.Module): 20 | def __init__( 21 | self, 22 | d_model=512, 23 | nhead=8, 24 | num_encoder_layers=6, 25 | num_decoder_layers=6, 26 | dim_feedforward=2048, 27 | dropout=0.1, 28 | activation="relu", 29 | normalize_before=False, 30 | return_intermediate_dec=False, 31 | ): 32 | super().__init__() 33 | 34 | encoder_layer = TransformerEncoderLayer( 35 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 36 | ) 37 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 38 | self.encoder = TransformerEncoder( 39 | encoder_layer, num_encoder_layers, encoder_norm 40 | ) 41 | 42 | decoder_layer = TransformerDecoderLayer( 43 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 44 | ) 45 | decoder_norm = nn.LayerNorm(d_model) 46 | self.decoder = TransformerDecoder( 47 | decoder_layer, 48 | num_decoder_layers, 49 | decoder_norm, 50 | return_intermediate=return_intermediate_dec, 51 | ) 52 | 53 | self._reset_parameters() 54 | 55 | self.d_model = d_model 56 | self.nhead = nhead 57 | 58 | def _reset_parameters(self): 59 | for p in self.parameters(): 60 | if p.dim() > 1: 61 | nn.init.xavier_uniform_(p) 62 | 63 | def forward(self, src, mask, query_embed, pos_embed): 64 | # flatten NxCxHxW to HWxNxC 65 | bs, c, h, w = src.shape 66 | src = src.flatten(2).permute(2, 0, 1) 67 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 68 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 69 | mask = mask.flatten(1) 70 | 71 | tgt = torch.zeros_like(query_embed) 72 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 73 | hs = self.decoder( 74 | tgt, 75 | memory, 76 | memory_key_padding_mask=mask, 77 | pos=pos_embed, 78 | query_pos=query_embed, 79 | ) 80 | 
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 81 | 82 | 83 | class TransformerEncoder(nn.Module): 84 | def __init__(self, encoder_layer, num_layers, norm=None): 85 | super().__init__() 86 | self.layers = _get_clones(encoder_layer, num_layers) 87 | self.num_layers = num_layers 88 | self.norm = norm 89 | 90 | def forward( 91 | self, 92 | src, 93 | mask: Optional[Tensor] = None, 94 | src_key_padding_mask: Optional[Tensor] = None, 95 | pos: Optional[Tensor] = None, 96 | ): 97 | output = src 98 | 99 | for layer in self.layers: 100 | output = layer( 101 | output, 102 | src_mask=mask, 103 | src_key_padding_mask=src_key_padding_mask, 104 | pos=pos, 105 | ) 106 | 107 | if self.norm is not None: 108 | output = self.norm(output) 109 | 110 | return output 111 | 112 | 113 | class TransformerDecoder(nn.Module): 114 | def __init__( 115 | self, 116 | decoder_layer, 117 | num_layers, 118 | norm=None, 119 | return_intermediate=False, 120 | return_layers=0, 121 | ): 122 | super().__init__() 123 | self.layers = _get_clones(decoder_layer, num_layers) 124 | self.num_layers = num_layers 125 | self.norm = norm 126 | self.return_intermediate = return_intermediate 127 | self.return_layers = return_layers # 0 to return all layers 128 | 129 | def forward( 130 | self, 131 | tgt, 132 | memory, 133 | tgt_mask: Optional[Tensor] = None, 134 | memory_mask: Optional[Tensor] = None, 135 | tgt_key_padding_mask: Optional[Tensor] = None, 136 | memory_key_padding_mask: Optional[Tensor] = None, 137 | pos: Optional[Tensor] = None, 138 | query_pos: Optional[Tensor] = None, 139 | ): 140 | output = tgt 141 | 142 | intermediate = [] 143 | 144 | for layer in self.layers: 145 | output = layer( 146 | output, 147 | memory, 148 | tgt_mask=tgt_mask, 149 | memory_mask=memory_mask, 150 | tgt_key_padding_mask=tgt_key_padding_mask, 151 | memory_key_padding_mask=memory_key_padding_mask, 152 | pos=pos, 153 | query_pos=query_pos, 154 | ) 155 | if self.return_intermediate: 156 | intermediate.append(self.norm(output)) 157 | 158 | if self.norm is not None: 159 | output = self.norm(output) 160 | if self.return_intermediate: 161 | intermediate.pop() 162 | intermediate.append(output) 163 | 164 | if self.return_intermediate: 165 | stacked_output = torch.stack(intermediate) 166 | if self.return_layers > 0: 167 | stacked_output = stacked_output[-self.return_layers :] 168 | return stacked_output 169 | 170 | return output.unsqueeze(0) 171 | 172 | 173 | class MAEDecoder(nn.Module): 174 | def __init__( 175 | self, 176 | d_model=512, 177 | nhead=8, 178 | num_decoder_layers=6, 179 | dim_feedforward=2048, 180 | dropout=0.1, 181 | activation="relu", 182 | normalize_before=False, 183 | return_intermediate_dec=False, 184 | return_layers=0, 185 | ) -> None: 186 | super().__init__() 187 | decoder_layer = TransformerDecoderLayer( 188 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 189 | ) 190 | decoder_norm = nn.LayerNorm(d_model) 191 | self.decoder = TransformerDecoder( 192 | decoder_layer, 193 | num_decoder_layers, 194 | decoder_norm, 195 | return_intermediate=return_intermediate_dec, 196 | return_layers=return_layers, 197 | ) 198 | 199 | def forward(self, x, tgt): 200 | """ 201 | x: N X L X d_emb 202 | tgt: N X T X d_emb 203 | out: T X N X d_emb or N_layers X T X N X d_emb 204 | """ 205 | n, k, d_emb = x.shape 206 | x = x.permute(1, 0, 2) # L X N X d_emb 207 | tgt = tgt.permute(1, 0, 2) 208 | x = self.decoder(tgt, x) # N_Layer X T X N X d_emb 209 | return x 210 | 211 | 212 | class TransformerEncoderLayer(nn.Module): 
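"""DETR-style encoder layer: multi-head self-attention followed by a two-layer feed-forward block, each wrapped in a residual connection and LayerNorm; normalize_before selects the pre-norm (forward_pre) or post-norm (forward_post) ordering."""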
213 | def __init__( 214 | self, 215 | d_model, 216 | nhead, 217 | dim_feedforward=2048, 218 | dropout=0.1, 219 | activation="relu", 220 | normalize_before=False, 221 | ): 222 | super().__init__() 223 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 224 | # Implementation of Feedforward model 225 | self.linear1 = nn.Linear(d_model, dim_feedforward) 226 | self.dropout = nn.Dropout(dropout) 227 | self.linear2 = nn.Linear(dim_feedforward, d_model) 228 | 229 | self.norm1 = nn.LayerNorm(d_model) 230 | self.norm2 = nn.LayerNorm(d_model) 231 | self.dropout1 = nn.Dropout(dropout) 232 | self.dropout2 = nn.Dropout(dropout) 233 | 234 | self.activation = _get_activation_fn(activation) 235 | self.normalize_before = normalize_before 236 | 237 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 238 | return tensor if pos is None else tensor + pos 239 | 240 | def forward_post( 241 | self, 242 | src, 243 | src_mask: Optional[Tensor] = None, 244 | src_key_padding_mask: Optional[Tensor] = None, 245 | pos: Optional[Tensor] = None, 246 | ): 247 | q = k = self.with_pos_embed(src, pos) 248 | src2 = self.self_attn( 249 | q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 250 | )[0] 251 | src = src + self.dropout1(src2) 252 | src = self.norm1(src) 253 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 254 | src = src + self.dropout2(src2) 255 | src = self.norm2(src) 256 | return src 257 | 258 | def forward_pre( 259 | self, 260 | src, 261 | src_mask: Optional[Tensor] = None, 262 | src_key_padding_mask: Optional[Tensor] = None, 263 | pos: Optional[Tensor] = None, 264 | ): 265 | src2 = self.norm1(src) 266 | q = k = self.with_pos_embed(src2, pos) 267 | src2 = self.self_attn( 268 | q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 269 | )[0] 270 | src = src + self.dropout1(src2) 271 | src2 = self.norm2(src) 272 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 273 | src = src + self.dropout2(src2) 274 | return src 275 | 276 | def forward( 277 | self, 278 | src, 279 | src_mask: Optional[Tensor] = None, 280 | src_key_padding_mask: Optional[Tensor] = None, 281 | pos: Optional[Tensor] = None, 282 | ): 283 | if self.normalize_before: 284 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 285 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 286 | 287 | 288 | class TransformerDecoderLayer(nn.Module): 289 | def __init__( 290 | self, 291 | d_model, 292 | nhead, 293 | dim_feedforward=2048, 294 | dropout=0.1, 295 | activation="relu", 296 | normalize_before=False, 297 | ): 298 | super().__init__() 299 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 300 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 301 | # Implementation of Feedforward model 302 | self.linear1 = nn.Linear(d_model, dim_feedforward) 303 | self.dropout = nn.Dropout(dropout) 304 | self.linear2 = nn.Linear(dim_feedforward, d_model) 305 | 306 | self.norm1 = nn.LayerNorm(d_model) 307 | self.norm2 = nn.LayerNorm(d_model) 308 | self.norm3 = nn.LayerNorm(d_model) 309 | self.dropout1 = nn.Dropout(dropout) 310 | self.dropout2 = nn.Dropout(dropout) 311 | self.dropout3 = nn.Dropout(dropout) 312 | 313 | self.activation = _get_activation_fn(activation) 314 | self.normalize_before = normalize_before 315 | 316 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 317 | return tensor if pos is None else tensor + pos 318 | 319 | def forward_post( 320 | self, 321 
| tgt, 322 | memory, 323 | tgt_mask: Optional[Tensor] = None, 324 | memory_mask: Optional[Tensor] = None, 325 | tgt_key_padding_mask: Optional[Tensor] = None, 326 | memory_key_padding_mask: Optional[Tensor] = None, 327 | pos: Optional[Tensor] = None, 328 | query_pos: Optional[Tensor] = None, 329 | ): 330 | q = k = self.with_pos_embed(tgt, query_pos) 331 | tgt2 = self.self_attn( 332 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 333 | )[0] 334 | tgt = tgt + self.dropout1(tgt2) 335 | tgt = self.norm1(tgt) 336 | tgt2 = self.multihead_attn( 337 | query=self.with_pos_embed(tgt, query_pos), 338 | key=self.with_pos_embed(memory, pos), 339 | value=memory, 340 | attn_mask=memory_mask, 341 | key_padding_mask=memory_key_padding_mask, 342 | )[0] 343 | tgt = tgt + self.dropout2(tgt2) 344 | tgt = self.norm2(tgt) 345 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 346 | tgt = tgt + self.dropout3(tgt2) 347 | tgt = self.norm3(tgt) 348 | return tgt 349 | 350 | def forward_pre( 351 | self, 352 | tgt, 353 | memory, 354 | tgt_mask: Optional[Tensor] = None, 355 | memory_mask: Optional[Tensor] = None, 356 | tgt_key_padding_mask: Optional[Tensor] = None, 357 | memory_key_padding_mask: Optional[Tensor] = None, 358 | pos: Optional[Tensor] = None, 359 | query_pos: Optional[Tensor] = None, 360 | ): 361 | tgt2 = self.norm1(tgt) 362 | q = k = self.with_pos_embed(tgt2, query_pos) 363 | tgt2 = self.self_attn( 364 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 365 | )[0] 366 | tgt = tgt + self.dropout1(tgt2) 367 | tgt2 = self.norm2(tgt) 368 | tgt2 = self.multihead_attn( 369 | query=self.with_pos_embed(tgt2, query_pos), 370 | key=self.with_pos_embed(memory, pos), 371 | value=memory, 372 | attn_mask=memory_mask, 373 | key_padding_mask=memory_key_padding_mask, 374 | )[0] 375 | tgt = tgt + self.dropout2(tgt2) 376 | tgt2 = self.norm3(tgt) 377 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 378 | tgt = tgt + self.dropout3(tgt2) 379 | return tgt 380 | 381 | def forward( 382 | self, 383 | tgt, 384 | memory, 385 | tgt_mask: Optional[Tensor] = None, 386 | memory_mask: Optional[Tensor] = None, 387 | tgt_key_padding_mask: Optional[Tensor] = None, 388 | memory_key_padding_mask: Optional[Tensor] = None, 389 | pos: Optional[Tensor] = None, 390 | query_pos: Optional[Tensor] = None, 391 | ): 392 | if self.normalize_before: 393 | return self.forward_pre( 394 | tgt, 395 | memory, 396 | tgt_mask, 397 | memory_mask, 398 | tgt_key_padding_mask, 399 | memory_key_padding_mask, 400 | pos, 401 | query_pos, 402 | ) 403 | return self.forward_post( 404 | tgt, 405 | memory, 406 | tgt_mask, 407 | memory_mask, 408 | tgt_key_padding_mask, 409 | memory_key_padding_mask, 410 | pos, 411 | query_pos, 412 | ) 413 | 414 | 415 | def _get_clones(module, N): 416 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 417 | 418 | 419 | def build_transformer(args): 420 | return Transformer( 421 | d_model=args.hidden_dim, 422 | dropout=args.dropout, 423 | nhead=args.nheads, 424 | dim_feedforward=args.dim_feedforward, 425 | num_encoder_layers=args.enc_layers, 426 | num_decoder_layers=args.dec_layers, 427 | normalize_before=args.pre_norm, 428 | return_intermediate_dec=True, 429 | ) 430 | 431 | 432 | def _get_activation_fn(activation): 433 | """Return an activation function given a string""" 434 | if activation == "relu": 435 | return F.relu 436 | if activation == "gelu": 437 | return F.gelu 438 | if activation == "glu": 439 | return F.glu 440 | raise 
RuntimeError(f"activation should be relu/gelu, not {activation}.") 441 | -------------------------------------------------------------------------------- /mae/lib/transforms.py: -------------------------------------------------------------------------------- 1 | from kornia.augmentation import RandomGaussianBlur 2 | 3 | 4 | class CustomCompose: 5 | def __init__(self, rescale_transform, other_transforms, src_transform): 6 | self.rescale_transform = rescale_transform 7 | self.other_transforms = other_transforms 8 | self.src_transform = src_transform 9 | 10 | def __call__(self, x, valid_masks=None): 11 | if valid_masks is not None: 12 | nodata = (x * (1 - valid_masks.float())).max() 13 | x_aug = self.rescale_transform(x) 14 | parms = self.rescale_transform._params 15 | # sanity check, comment if this is working 16 | # valid_masks = self.rescale_transform(valid_masks.float(), params=parms) 17 | # assert (x_aug==self.rescale_transform(x, params=parms)).all() # 18 | 19 | if valid_masks is not None: 20 | valid_masks = x_aug != nodata 21 | _, c, h, w = x_aug.shape 22 | zero_ratio = ((valid_masks == 0).sum((1, 2, 3)) / (h * w * c)).cpu().numpy() 23 | else: 24 | zero_ratio = -1 25 | 26 | if self.other_transforms: 27 | x_aug = self.other_transforms(x_aug) 28 | x_src = self.src_transform(x_aug) 29 | dx = parms["src"][:, 1, 0] - parms["src"][:, 0, 0] 30 | # dy = (parms['src'][:,2,1] - parms['src'][:,1,1]) 31 | # assert (dx == dy).all() 32 | h, w = x_aug.shape[-2:] 33 | # assert h == w 34 | return x_aug, x_src, dx / h, zero_ratio, valid_masks 35 | 36 | 37 | blur = RandomGaussianBlur((3, 3), (2.0, 2.0), p=0.5) 38 | 39 | 40 | def get_inputs_outputs(img, res, target=None, target_res=None, strategy="naive"): 41 | # TODO: More strategies 42 | if target is not None: 43 | return img, res, target, target_res 44 | else: 45 | target = img 46 | target_res = res 47 | img = blur(img) 48 | return img, res, target, target_res 49 | -------------------------------------------------------------------------------- /mae/main_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | import argparse 12 | import datetime 13 | import json 14 | import os 15 | import sys 16 | import time 17 | from pathlib import Path 18 | from re import L 19 | 20 | import numpy as np 21 | import timm 22 | import torch 23 | import torch.backends.cudnn as cudnn 24 | import torchvision.transforms as tv_transforms 25 | import wandb 26 | from torch.utils.tensorboard import SummaryWriter 27 | 28 | import os 29 | import tempfile 30 | 31 | import kornia.augmentation as K 32 | import matplotlib.pyplot as plt 33 | import models_mae 34 | import numpy as np 35 | import timm.optim.optim_factory as optim_factory 36 | import util.misc as misc 37 | import yaml 38 | from dataloaders.resic45 import build_resic 39 | from dataloaders.utils import get_dataset_and_sampler 40 | from engine_pretrain import train_one_epoch 41 | from eval.knn import kNN 42 | from kornia.augmentation import AugmentationSequential 43 | from kornia.constants import Resample 44 | from lib.scheduler import ConstantResolutionScheduler, RandomResolutionScheduler 45 | from lib.transforms import CustomCompose 46 | from PIL import Image 47 | from torch.distributed.elastic.multiprocessing.errors import record 48 | from torch.utils.data import DataLoader, Subset 49 | from torchgeo.datasets import NAIP, stack_samples 50 | from torchgeo.datasets.utils import download_url 51 | from torchgeo.samplers import RandomGeoSampler, Units 52 | from torchvision import transforms 53 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 54 | from util.misc import is_main_process 55 | from wandb_log import WANDB_LOG_IMG_CONFIG 56 | 57 | Image.MAX_IMAGE_PIXELS = 1000000000 58 | -------------------------------------------------------------------------------- /mae/main_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import argparse 13 | import datetime 14 | import json 15 | import os 16 | import time 17 | from pathlib import Path 18 | 19 | import numpy as np 20 | import timm 21 | import torch 22 | import torch.backends.cudnn as cudnn 23 | from torch.utils.tensorboard import SummaryWriter 24 | 25 | import models_vit 26 | import util.lr_decay as lrd 27 | import util.misc as misc 28 | from engine_finetune import evaluate, train_one_epoch 29 | from timm.data.mixup import Mixup 30 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 31 | from timm.models.layers import trunc_normal_ 32 | from util.datasets import build_dataset 33 | from util.misc import NativeScalerWithGradNormCount as NativeScaler 34 | from util.pos_embed import interpolate_pos_embed 35 | 36 | 37 | def get_args_parser(): 38 | parser = argparse.ArgumentParser( 39 | "MAE fine-tuning for image classification", add_help=False 40 | ) 41 | parser.add_argument( 42 | "--batch_size", 43 | default=64, 44 | type=int, 45 | help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus", 46 | ) 47 | parser.add_argument("--epochs", default=50, type=int) 48 | parser.add_argument( 49 | "--accum_iter", 50 | default=1, 51 | type=int, 52 | help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)", 53 | ) 54 | 55 | # Model parameters 56 | parser.add_argument( 57 | "--model", 58 | default="vit_large_patch16", 59 | type=str, 60 | metavar="MODEL", 61 | help="Name of model to train", 62 | ) 63 | 64 | parser.add_argument("--input_size", default=224, type=int, help="images input size") 65 | 66 | parser.add_argument( 67 | "--drop_path", 68 | type=float, 69 | default=0.1, 70 | metavar="PCT", 71 | help="Drop path rate (default: 0.1)", 72 | ) 73 | 74 | # Optimizer parameters 75 | parser.add_argument( 76 | "--clip_grad", 77 | type=float, 78 | default=None, 79 | metavar="NORM", 80 | help="Clip gradient norm (default: None, no clipping)", 81 | ) 82 | parser.add_argument( 83 | "--weight_decay", type=float, default=0.05, help="weight decay (default: 0.05)" 84 | ) 85 | 86 | parser.add_argument( 87 | "--lr", 88 | type=float, 89 | default=None, 90 | metavar="LR", 91 | help="learning rate (absolute lr)", 92 | ) 93 | parser.add_argument( 94 | "--blr", 95 | type=float, 96 | default=1e-3, 97 | metavar="LR", 98 | help="base learning rate: absolute_lr = base_lr * total_batch_size / 256", 99 | ) 100 | parser.add_argument( 101 | "--layer_decay", 102 | type=float, 103 | default=0.75, 104 | help="layer-wise lr decay from ELECTRA/BEiT", 105 | ) 106 | 107 | parser.add_argument( 108 | "--min_lr", 109 | type=float, 110 | default=1e-6, 111 | metavar="LR", 112 | help="lower lr bound for cyclic schedulers that hit 0", 113 | ) 114 | 115 | parser.add_argument( 116 | "--warmup_epochs", type=int, default=5, metavar="N", help="epochs to warmup LR" 117 | ) 118 | 119 | # Augmentation parameters 120 | parser.add_argument( 121 | "--color_jitter", 122 | type=float, 123 | default=None, 124 | metavar="PCT", 125 | help="Color jitter factor (enabled only when not using Auto/RandAug)", 126 | ) 127 | parser.add_argument( 128 | "--aa", 129 | type=str, 130 | default="rand-m9-mstd0.5-inc1", 131 | metavar="NAME", 132 | help='Use AutoAugment policy. 
"v0" or "original". " + "(default: rand-m9-mstd0.5-inc1)', 133 | ), 134 | parser.add_argument( 135 | "--smoothing", type=float, default=0.1, help="Label smoothing (default: 0.1)" 136 | ) 137 | 138 | # * Random Erase params 139 | parser.add_argument( 140 | "--reprob", 141 | type=float, 142 | default=0.25, 143 | metavar="PCT", 144 | help="Random erase prob (default: 0.25)", 145 | ) 146 | parser.add_argument( 147 | "--remode", 148 | type=str, 149 | default="pixel", 150 | help='Random erase mode (default: "pixel")', 151 | ) 152 | parser.add_argument( 153 | "--recount", type=int, default=1, help="Random erase count (default: 1)" 154 | ) 155 | parser.add_argument( 156 | "--resplit", 157 | action="store_true", 158 | default=False, 159 | help="Do not random erase first (clean) augmentation split", 160 | ) 161 | 162 | # * Mixup params 163 | parser.add_argument( 164 | "--mixup", type=float, default=0, help="mixup alpha, mixup enabled if > 0." 165 | ) 166 | parser.add_argument( 167 | "--cutmix", type=float, default=0, help="cutmix alpha, cutmix enabled if > 0." 168 | ) 169 | parser.add_argument( 170 | "--cutmix_minmax", 171 | type=float, 172 | nargs="+", 173 | default=None, 174 | help="cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)", 175 | ) 176 | parser.add_argument( 177 | "--mixup_prob", 178 | type=float, 179 | default=1.0, 180 | help="Probability of performing mixup or cutmix when either/both is enabled", 181 | ) 182 | parser.add_argument( 183 | "--mixup_switch_prob", 184 | type=float, 185 | default=0.5, 186 | help="Probability of switching to cutmix when both mixup and cutmix enabled", 187 | ) 188 | parser.add_argument( 189 | "--mixup_mode", 190 | type=str, 191 | default="batch", 192 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"', 193 | ) 194 | 195 | # * Finetuning params 196 | parser.add_argument("--finetune", default="", help="finetune from checkpoint") 197 | parser.add_argument("--global_pool", action="store_true") 198 | parser.set_defaults(global_pool=True) 199 | parser.add_argument( 200 | "--cls_token", 201 | action="store_false", 202 | dest="global_pool", 203 | help="Use class token instead of global pool for classification", 204 | ) 205 | 206 | # Dataset parameters 207 | parser.add_argument( 208 | "--data_path", 209 | default="/datasets01/imagenet_full_size/061417/", 210 | type=str, 211 | help="dataset path", 212 | ) 213 | parser.add_argument( 214 | "--nb_classes", 215 | default=1000, 216 | type=int, 217 | help="number of the classification types", 218 | ) 219 | 220 | parser.add_argument( 221 | "--output_dir", 222 | default="./output_dir", 223 | help="path where to save, empty for no saving", 224 | ) 225 | parser.add_argument( 226 | "--log_dir", default="./output_dir", help="path where to tensorboard log" 227 | ) 228 | parser.add_argument( 229 | "--device", default="cuda", help="device to use for training / testing" 230 | ) 231 | parser.add_argument("--seed", default=0, type=int) 232 | parser.add_argument("--resume", default="", help="resume from checkpoint") 233 | 234 | parser.add_argument( 235 | "--start_epoch", default=0, type=int, metavar="N", help="start epoch" 236 | ) 237 | parser.add_argument("--eval", action="store_true", help="Perform evaluation only") 238 | parser.add_argument( 239 | "--dist_eval", 240 | action="store_true", 241 | default=False, 242 | help="Enabling distributed evaluation (recommended during training for faster monitor", 243 | ) 244 | parser.add_argument("--num_workers", default=10, type=int) 245 | parser.add_argument( 246 | "--pin_mem", 247 | action="store_true", 248 | help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.", 249 | ) 250 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 251 | parser.set_defaults(pin_mem=True) 252 | 253 | # distributed training parameters 254 | parser.add_argument( 255 | "--world_size", default=1, type=int, help="number of distributed processes" 256 | ) 257 | parser.add_argument("--local_rank", default=-1, type=int) 258 | parser.add_argument("--dist_on_itp", action="store_true") 259 | parser.add_argument( 260 | "--dist_url", default="env://", help="url used to set up distributed training" 261 | ) 262 | 263 | return parser 264 | 265 | 266 | def main(args): 267 | misc.init_distributed_mode(args) 268 | 269 | print(f"job dir: {os.path.dirname(os.path.realpath(__file__))}") 270 | print(f"{args}".replace(", ", ",\n")) 271 | 272 | device = torch.device(args.device) 273 | 274 | # fix the seed for reproducibility 275 | seed = args.seed + misc.get_rank() 276 | torch.manual_seed(seed) 277 | np.random.seed(seed) 278 | 279 | cudnn.benchmark = True 280 | 281 | dataset_train = build_dataset(is_train=True, args=args) 282 | dataset_val = build_dataset(is_train=False, args=args) 283 | 284 | if True: # args.distributed: 285 | num_tasks = misc.get_world_size() 286 | global_rank = misc.get_rank() 287 | sampler_train = torch.utils.data.DistributedSampler( 288 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 289 | ) 290 | print("Sampler_train = %s" % str(sampler_train)) 291 | if args.dist_eval: 292 | if len(dataset_val) % num_tasks != 0: 293 | print( 294 | "Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. 
" 295 | "This will slightly alter validation results as extra duplicate entries are added to achieve " 296 | "equal num of samples per-process." 297 | ) 298 | sampler_val = torch.utils.data.DistributedSampler( 299 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True 300 | ) # shuffle=True to reduce monitor bias 301 | else: 302 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 303 | else: 304 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 305 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 306 | 307 | if global_rank == 0 and args.log_dir is not None and not args.eval: 308 | os.makedirs(args.log_dir, exist_ok=True) 309 | log_writer = SummaryWriter(log_dir=args.log_dir) 310 | else: 311 | log_writer = None 312 | 313 | data_loader_train = torch.utils.data.DataLoader( 314 | dataset_train, 315 | sampler=sampler_train, 316 | batch_size=args.batch_size, 317 | num_workers=args.num_workers, 318 | pin_memory=args.pin_mem, 319 | drop_last=True, 320 | ) 321 | 322 | data_loader_val = torch.utils.data.DataLoader( 323 | dataset_val, 324 | sampler=sampler_val, 325 | batch_size=args.batch_size, 326 | num_workers=args.num_workers, 327 | pin_memory=args.pin_mem, 328 | drop_last=False, 329 | ) 330 | 331 | mixup_fn = None 332 | mixup_active = args.mixup > 0 or args.cutmix > 0.0 or args.cutmix_minmax is not None 333 | if mixup_active: 334 | print("Mixup is activated!") 335 | mixup_fn = Mixup( 336 | mixup_alpha=args.mixup, 337 | cutmix_alpha=args.cutmix, 338 | cutmix_minmax=args.cutmix_minmax, 339 | prob=args.mixup_prob, 340 | switch_prob=args.mixup_switch_prob, 341 | mode=args.mixup_mode, 342 | label_smoothing=args.smoothing, 343 | num_classes=args.nb_classes, 344 | ) 345 | 346 | model = models_vit.__dict__[args.model]( 347 | num_classes=args.nb_classes, 348 | drop_path_rate=args.drop_path, 349 | global_pool=args.global_pool, 350 | ) 351 | 352 | if args.finetune and not args.eval: 353 | checkpoint = torch.load(args.finetune, map_location="cpu") 354 | 355 | print("Load pre-trained checkpoint from: %s" % args.finetune) 356 | checkpoint_model = checkpoint["model"] 357 | state_dict = model.state_dict() 358 | for k in ["head.weight", "head.bias"]: 359 | if ( 360 | k in checkpoint_model 361 | and checkpoint_model[k].shape != state_dict[k].shape 362 | ): 363 | print(f"Removing key {k} from pretrained checkpoint") 364 | del checkpoint_model[k] 365 | 366 | # interpolate position embedding 367 | interpolate_pos_embed(model, checkpoint_model) 368 | 369 | # load pre-trained model 370 | msg = model.load_state_dict(checkpoint_model, strict=False) 371 | print(msg) 372 | 373 | if args.global_pool: 374 | assert set(msg.missing_keys) == { 375 | "head.weight", 376 | "head.bias", 377 | "fc_norm.weight", 378 | "fc_norm.bias", 379 | } 380 | else: 381 | assert set(msg.missing_keys) == {"head.weight", "head.bias"} 382 | 383 | # manually initialize fc layer 384 | trunc_normal_(model.head.weight, std=2e-5) 385 | 386 | model.to(device) 387 | 388 | model_without_ddp = model 389 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 390 | 391 | print("Model = %s" % str(model_without_ddp)) 392 | print("number of params (M): %.2f" % (n_parameters / 1.0e6)) 393 | 394 | eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size() 395 | 396 | if args.lr is None: # only base_lr is specified 397 | args.lr = args.blr * eff_batch_size / 256 398 | 399 | print("base lr: %.2e" % (args.lr * 256 / eff_batch_size)) 400 | print("actual lr: %.2e" % args.lr) 
401 | 402 | print("accumulate grad iterations: %d" % args.accum_iter) 403 | print("effective batch size: %d" % eff_batch_size) 404 | 405 | if args.distributed: 406 | model = torch.nn.parallel.DistributedDataParallel( 407 | model, device_ids=[args.gpu], find_unused_parameters=True 408 | ) 409 | model_without_ddp = model.module 410 | 411 | # build optimizer with layer-wise lr decay (lrd) 412 | param_groups = lrd.param_groups_lrd( 413 | model_without_ddp, 414 | args.weight_decay, 415 | no_weight_decay_list=model_without_ddp.no_weight_decay(), 416 | layer_decay=args.layer_decay, 417 | ) 418 | optimizer = torch.optim.AdamW(param_groups, lr=args.lr) 419 | loss_scaler = NativeScaler() 420 | 421 | if mixup_fn is not None: 422 | # smoothing is handled with mixup label transform 423 | criterion = SoftTargetCrossEntropy() 424 | elif args.smoothing > 0.0: 425 | criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 426 | else: 427 | criterion = torch.nn.CrossEntropyLoss() 428 | 429 | print("criterion = %s" % str(criterion)) 430 | 431 | misc.load_model( 432 | args=args, 433 | model_without_ddp=model_without_ddp, 434 | optimizer=optimizer, 435 | loss_scaler=loss_scaler, 436 | ) 437 | 438 | if args.eval: 439 | test_stats = evaluate(data_loader_val, model, device) 440 | print( 441 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" 442 | ) 443 | exit(0) 444 | 445 | print(f"Start training for {args.epochs} epochs") 446 | start_time = time.time() 447 | max_accuracy = 0.0 448 | for epoch in range(args.start_epoch, args.epochs): 449 | if args.distributed: 450 | data_loader_train.sampler.set_epoch(epoch) 451 | train_stats = train_one_epoch( 452 | model, 453 | criterion, 454 | data_loader_train, 455 | optimizer, 456 | device, 457 | epoch, 458 | loss_scaler, 459 | args.clip_grad, 460 | mixup_fn, 461 | log_writer=log_writer, 462 | args=args, 463 | ) 464 | if args.output_dir: 465 | misc.save_model( 466 | args=args, 467 | model=model, 468 | model_without_ddp=model_without_ddp, 469 | optimizer=optimizer, 470 | loss_scaler=loss_scaler, 471 | epoch=epoch, 472 | ) 473 | 474 | test_stats = evaluate(data_loader_val, model, device) 475 | print( 476 | f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%" 477 | ) 478 | max_accuracy = max(max_accuracy, test_stats["acc1"]) 479 | print(f"Max accuracy: {max_accuracy:.2f}%") 480 | 481 | if log_writer is not None: 482 | log_writer.add_scalar("perf/test_acc1", test_stats["acc1"], epoch) 483 | log_writer.add_scalar("perf/test_acc5", test_stats["acc5"], epoch) 484 | log_writer.add_scalar("perf/test_loss", test_stats["loss"], epoch) 485 | 486 | log_stats = { 487 | **{f"train_{k}": v for k, v in train_stats.items()}, 488 | **{f"test_{k}": v for k, v in test_stats.items()}, 489 | "epoch": epoch, 490 | "n_parameters": n_parameters, 491 | } 492 | 493 | if args.output_dir and misc.is_main_process(): 494 | if log_writer is not None: 495 | log_writer.flush() 496 | with open( 497 | os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8" 498 | ) as f: 499 | f.write(json.dumps(log_stats) + "\n") 500 | 501 | total_time = time.time() - start_time 502 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 503 | print(f"Training time {total_time_str}") 504 | 505 | 506 | if __name__ == "__main__": 507 | args = get_args_parser() 508 | args = args.parse_args() 509 | if args.output_dir: 510 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 511 | main(args) 512 | 
-------------------------------------------------------------------------------- /mae/models_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 9 | # DeiT: https://github.com/facebookresearch/deit 10 | # -------------------------------------------------------- 11 | 12 | from functools import partial 13 | 14 | import timm.models.vision_transformer 15 | import torch 16 | import torch.nn as nn 17 | from timm.models.vision_transformer import Block, PatchEmbed 18 | from util.pos_embed import get_2d_sincos_pos_embed_with_resolution 19 | 20 | 21 | class PatchEmbedUnSafe(PatchEmbed): 22 | """Image to Patch Embedding""" 23 | 24 | def forward(self, x): 25 | B, C, H, W = x.shape 26 | # Dropped size check in timm 27 | # assert H == self.img_size[0] and W == self.img_size[1], \ 28 | # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 29 | x = self.proj(x).flatten(2).transpose(1, 2) 30 | return x 31 | 32 | 33 | class VisionTransformer(timm.models.vision_transformer.VisionTransformer): 34 | """Vision Transformer with support for global average pooling""" 35 | 36 | def __init__( 37 | self, global_pool=False, patch_size=16, in_chans=3, embed_dim=1024, **kwargs 38 | ): 39 | super().__init__(embed_dim=embed_dim, **kwargs) 40 | 41 | self.patch_embed = PatchEmbedUnSafe( 42 | img_size=kwargs["img_size"], 43 | patch_size=patch_size, 44 | in_chans=in_chans, 45 | embed_dim=embed_dim, 46 | ) 47 | 48 | self.global_pool = global_pool 49 | if self.global_pool: 50 | norm_layer = kwargs["norm_layer"] 51 | embed_dim = embed_dim 52 | self.fc_norm = norm_layer(embed_dim) 53 | 54 | del self.norm # remove the original norm 55 | 56 | def forward_features(self, x, input_res=None): 57 | B, _, h, w = x.shape 58 | x = self.patch_embed(x) 59 | input_res = input_res.cpu() 60 | 61 | num_patches = int( 62 | (h * w) / (self.patch_embed.patch_size[0] * self.patch_embed.patch_size[1]) 63 | ) 64 | pos_embed = get_2d_sincos_pos_embed_with_resolution( 65 | x.shape[-1], 66 | int(num_patches**0.5), 67 | input_res, 68 | cls_token=True, 69 | device=x.device, 70 | ) 71 | 72 | cls_tokens = self.cls_token.expand( 73 | B, -1, -1 74 | ) # stole cls_tokens impl from Phil Wang, thanks 75 | x = torch.cat((cls_tokens, x), dim=1) 76 | x = x + pos_embed 77 | x = self.pos_drop(x) 78 | 79 | for blk in self.blocks: 80 | x = blk(x) 81 | 82 | if self.global_pool: 83 | x = x[:, 1:, :].mean(dim=1) # global pool without cls token 84 | outcome = self.fc_norm(x) 85 | else: 86 | x = self.norm(x) 87 | outcome = x[:, 0] 88 | 89 | return outcome 90 | 91 | def forward(self, x, input_res=None): 92 | x = self.forward_features(x, input_res=input_res) 93 | x = self.head(x) 94 | return x 95 | 96 | 97 | def vit_base_patch16(**kwargs): 98 | model = VisionTransformer( 99 | patch_size=16, 100 | embed_dim=768, 101 | depth=12, 102 | num_heads=12, 103 | mlp_ratio=4, 104 | qkv_bias=True, 105 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 106 | **kwargs 107 | ) 108 | return model 109 | 110 | 111 | def vit_large_patch16(**kwargs): 112 | model = VisionTransformer( 113 | patch_size=16, 114 | embed_dim=1024, 115 | depth=24, 116 | 
num_heads=16, 117 | mlp_ratio=4, 118 | qkv_bias=True, 119 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 120 | **kwargs 121 | ) 122 | return model 123 | 124 | 125 | def vit_huge_patch14(**kwargs): 126 | model = VisionTransformer( 127 | patch_size=14, 128 | embed_dim=1280, 129 | depth=32, 130 | num_heads=16, 131 | mlp_ratio=4, 132 | qkv_bias=True, 133 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 134 | **kwargs 135 | ) 136 | return model 137 | -------------------------------------------------------------------------------- /mae/samplers/distributed.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torchgeo.datasets import BoundingBox, GeoDataset 7 | from torchgeo.samplers import RandomGeoSampler, Units 8 | from torchgeo.samplers.utils import get_random_bounding_box 9 | 10 | 11 | class DistributedRandomGeoSampler(RandomGeoSampler): 12 | """Samples elements from a region of interest randomly. 13 | 14 | This is particularly useful during training when you want to maximize the size of 15 | the dataset and return as many random :term:`chips ` as possible. 16 | 17 | This sampler is not recommended for use with tile-based datasets. Use 18 | :class:`RandomBatchGeoSampler` instead. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | dataset: GeoDataset, 24 | size: Union[Tuple[float, float], float], 25 | length: int, 26 | roi: Optional[BoundingBox] = None, 27 | units: Units = Units.PIXELS, 28 | num_replicas: Optional[int] = None, 29 | rank: Optional[int] = None, 30 | seed: int = 0, 31 | ) -> None: 32 | 33 | if num_replicas is None: 34 | if not dist.is_available(): 35 | raise RuntimeError("Requires distributed package to be available") 36 | num_replicas = dist.get_world_size() 37 | if rank is None: 38 | if not dist.is_available(): 39 | raise RuntimeError("Requires distributed package to be available") 40 | rank = dist.get_rank() 41 | if rank >= num_replicas or rank < 0: 42 | raise ValueError( 43 | "Invalid rank {}, rank should be in the interval" 44 | " [0, {}]".format(rank, num_replicas - 1) 45 | ) 46 | self.dataset = dataset 47 | self.num_replicas = num_replicas 48 | self.rank = rank 49 | self.epoch = 0 50 | # Pad the last batch 51 | self.num_samples = math.ceil(length / self.num_replicas) 52 | self.total_size = self.num_samples * self.num_replicas 53 | self.seed = seed 54 | 55 | ##### TODO: Check if this is actually working#### 56 | super().__init__(dataset, size, length, roi, units) 57 | 58 | def __iter__(self) -> Iterator[BoundingBox]: 59 | """Return the index of a dataset. 60 | 61 | Returns: 62 | (minx, maxx, miny, maxy, mint, maxt) coordinates to index a dataset 63 | """ 64 | g = torch.Generator() 65 | g.manual_seed(self.seed + self.epoch) 66 | indices = torch.multinomial( 67 | self.areas, self.total_size, generator=g, replacement=True 68 | ).tolist() 69 | assert len(indices) == self.total_size 70 | indices = indices[self.rank : self.total_size : self.num_replicas] 71 | assert len(indices) == self.num_samples 72 | 73 | for idx in indices: 74 | # Choose a random tile, weighted by area 75 | hit = self.hits[idx] 76 | bounds = BoundingBox(*hit.bounds) 77 | 78 | # Choose a random index within that tile 79 | bounding_box = get_random_bounding_box(bounds, self.size, self.res) 80 | 81 | yield bounding_box 82 | 83 | def __len__(self) -> int: 84 | """Return the number of samples in a single epoch. 
85 | 86 | Returns: 87 | length of the epoch 88 | """ 89 | return self.num_samples 90 | 91 | def set_epoch(self, epoch: int) -> None: 92 | r""" 93 | Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas 94 | use a different random ordering for each epoch. Otherwise, the next iteration of this 95 | sampler will yield the same ordering. 96 | 97 | Args: 98 | epoch (int): Epoch number. 99 | """ 100 | self.epoch = epoch 101 | -------------------------------------------------------------------------------- /mae/scripts/eval_launcher.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from argparse import Namespace 4 | 5 | import torch 6 | import wandb 7 | import yaml 8 | 9 | api = wandb.Api() 10 | import sys 11 | sys.path.append('/home/jacklishufan/scale-mae/mae') 12 | import util.misc as misc 13 | from main_pretrain import get_args_parser as pretrain_get_args_parser 14 | from main_pretrain import main as main_pretrain 15 | 16 | 17 | def get_args_parser(): 18 | parser = argparse.ArgumentParser("Eval controller", add_help=False) 19 | parser.add_argument( 20 | "--eval_config", 21 | default=os.path.join(os.path.dirname(__file__), "evalconf/demo-conf.yaml"), 22 | type=str, 23 | help="Eval config file", 24 | ) 25 | 26 | ################################### 27 | # DISTRIBUTED TRAINING PARAMETERS # 28 | ################################### 29 | parser.add_argument( 30 | "--world_size", default=1, type=int, help="number of distributed processes" 31 | ) 32 | parser.add_argument("--local_rank", default=-1, type=int) 33 | parser.add_argument("--dist_on_itp", action="store_true") 34 | parser.add_argument( 35 | "--dist_url", default="env://", help="url used to set up distributed training" 36 | ) 37 | parser.add_argument( 38 | "--knn", default=20, type=int, help="Number of neighbors to use for KNN" 39 | ) 40 | 41 | parser.add_argument( 42 | "--eval_gsd", 43 | action="store_true", 44 | help="USE GSD Relative Embedding with base=224x224", 45 | ) 46 | parser.add_argument( 47 | "--no-eval_gsd", 48 | action="store_false", 49 | help="USE GSD Relative Embedding with base=224x224", 50 | dest='eval_gsd' 51 | ) 52 | parser.set_defaults(eval_gsd=True) 53 | parser.add_argument( 54 | "--eval_base_resolution", 55 | default=1.0, 56 | type=float, 57 | help="Global Multiplication factor of Positional Embedding Resolution in KNN", 58 | ) 59 | return parser 60 | 61 | 62 | def main(args): 63 | print("Starting eval") 64 | with open(args.eval_config) as f: 65 | config = yaml.safe_load(f.read()) 66 | 67 | if "exp_ids" in config: 68 | exp_ids = config["exp_ids"] 69 | else: 70 | # TODO collect all run ids or specify other conditions 71 | pass 72 | misc.init_distributed_mode(args) 73 | is_main = False 74 | if args.rank == 0: 75 | wandb_args = dict( 76 | project="scale-mae-knn-reproduce", 77 | entity="bair-climate-initiative", 78 | resume="allow", 79 | ) 80 | run = wandb.init(**wandb_args) 81 | run_id = run.id 82 | is_main= True 83 | default_args = pretrain_get_args_parser().parse_args([]) 84 | for expid in exp_ids: 85 | try: 86 | # load the latest checkpoint 87 | # TODO allow different epochs, scales, datasets 88 | mdl_path = os.path.join(config["root"], str(expid), "checkpoint-latest.pth") 89 | mdl = torch.load(mdl_path, map_location="cpu") 90 | margs = mdl["args"] if "args" in mdl else Namespace() 91 | nepochs = mdl["epoch"] if "epoch" in mdl else 100 92 | if nepochs < 90: 93 | print(f"Skipping {expid} because it only has {nepochs} 
epochs") 94 | continue 95 | 96 | # add all of the eval params 97 | for k, v in config.items(): 98 | setattr(margs, k, v) 99 | # set all of the distributed bits 100 | margs.eval_gsd = args.eval_gsd 101 | margs.eval_base_resolution = args.eval_base_resolution 102 | margs.knn = args.knn 103 | margs.local_rank = args.local_rank 104 | margs.dist_on_itp = args.dist_on_itp 105 | margs.dist_url = args.dist_url 106 | margs.world_size = args.world_size 107 | # only do evaluation 108 | margs.resume = mdl_path 109 | 110 | for eval_data in config["evals"]: 111 | eval_id = eval_data["id"] 112 | margs.eval_scale = eval_data["scales"] 113 | margs.eval_dataset = eval_id 114 | print(f"Starting {margs.eval_dataset} {margs.eval_scale}: {eval_id}") 115 | margs.eval_only = True 116 | margs.eval_train_fnames = os.path.join(eval_data["path"], "train.txt") 117 | margs.eval_val_fnames = os.path.join(eval_data["path"], "val.txt") 118 | 119 | arg_vals = {**vars(default_args), **vars(margs)} 120 | use_args = Namespace(**arg_vals) 121 | use_args.base_resolution = 2.0 122 | res = main_pretrain(use_args) 123 | 124 | if is_main: 125 | wandb_run = api.run( 126 | f"bair-climate-initiative/scale-mae-knn-reproduce/{run_id}" 127 | ) 128 | for scale, acc in res.items(): 129 | wandb_run.summary[f"{eval_id}-knn-acc-{scale}"]= acc * 100.0 130 | wandb_run.summary.update() 131 | print("Sent results", res) 132 | print("HERE") 133 | except Exception as err: 134 | print(f"Unable to process (will skip) {expid}: {err}") 135 | 136 | 137 | if __name__ == "__main__": 138 | args = get_args_parser().parse_args() 139 | main(args) 140 | -------------------------------------------------------------------------------- /mae/scripts/evalconf/demo-conf.yaml: -------------------------------------------------------------------------------- 1 | # Specify the root directory that contains your checkpoints 2 | root: experiments 3 | 4 | evals: 5 | - id: fmow 6 | path: data/fdata/val 7 | scales: 8 | - 56 9 | - 112 10 | - 224 11 | - id: airound 12 | path: data/aerial 13 | scales: 14 | - 224 15 | - 112 16 | - 56 17 | - id: mlrsnet 18 | path: data/mlrsnet/Images 19 | scales: 20 | - 224 21 | - 112 22 | - 56 23 | - id: resisc 24 | path: data/resisc45 25 | scales: 26 | - 224 27 | - 112 28 | - 56 29 | 30 | exp_ids: 31 | - 66042715 32 | - 66042713 33 | - 66042712 34 | - 66033181 35 | - 66033180 36 | - 66033178 37 | - 66033177 38 | - 66032972 39 | - 66032971 40 | - 66032970 41 | - 66032969 42 | - 66032967 43 | - 66032965 44 | - 66032962 45 | - 66032961 46 | - 66032687 47 | - 66032686 48 | - 66032685 49 | - 66032684 50 | - 66032677 51 | - 66032676 52 | - 66031607 53 | - 66031599 54 | - 66001340 55 | - 66001200 56 | - 65896740 57 | - 65891704 58 | - 65889622 59 | - 65888830 60 | - 65829475 61 | - 65828768 62 | - 65801974 63 | - 65801948 64 | - 65797652 65 | - 65794283 66 | - 65788556 67 | - 65774431 68 | - 65772750 69 | - 65756430 70 | - 65741967 71 | - 65741789 72 | - 65741786 73 | - 65741784 74 | - 65722636 75 | - 65717963 76 | - 65717792 77 | - 65715972 78 | - 65710678 79 | - 65693964 80 | - 65678678 81 | - 65678677 82 | - 65678673 83 | - 65678630 84 | - 65677698 85 | - 65677004 86 | - 65675349 87 | - 65626892 88 | - 65613328 89 | - 65613327 90 | - 65600410 91 | - 65599927 92 | - 65578864 93 | - 65578863 94 | - 65578733 95 | - 65578732 96 | - 65578731 97 | - 65557997 98 | - 65541427 99 | - 65541425 100 | - 65541383 -------------------------------------------------------------------------------- /mae/scripts/evalconf/dgx-conf.yaml: 
-------------------------------------------------------------------------------- 1 | # Specify the root directory that contains your checkpoints 2 | root: /exps 3 | 4 | # will pull checkpoint from ${root}/${exp_ids}/checkpoint-latest.pth 5 | exp_ids: 6 | - eval1 7 | 8 | evals: 9 | - id: resisc 10 | path: resisc45 11 | scales: 12 | - 56 13 | - 112 14 | - 224 15 | 16 | -------------------------------------------------------------------------------- /mae/scripts/gen_scale_perf_plots.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | 5 | import matplotlib.pyplot as plt 6 | import pandas as pd 7 | import wandb 8 | import yaml 9 | 10 | api = wandb.Api() 11 | 12 | 13 | max_res = { 14 | "resisc": 256, 15 | "ucmerced": 256, 16 | "whu-rs19": 256, 17 | "airound": 496, 18 | "mlrsnet": 256, 19 | "cvbrct": 496, 20 | "eurosat": 64, 21 | "optimal-31": 256, 22 | } 23 | 24 | res_lims = { 25 | "resisc": [0.3, 0.9], 26 | "ucmerced": [0.4, 0.9], 27 | "whu-rs19": [0.4, 1.0], 28 | "airound": [0.35, 0.8], 29 | "mlrsnet": [0.5, 1.0], 30 | "cvbrct": [0.4, 0.8], 31 | "eurosat": [0.5, 1.0], 32 | "optimal-31": [0.35, 0.8], 33 | } 34 | 35 | 36 | def get_args_parser(): 37 | parser = argparse.ArgumentParser("Eval controller", add_help=False) 38 | parser.add_argument("--runs", nargs="*", type=str, default=["1dcghih0", "2y9klhll"]) 39 | 40 | parser.add_argument("--evals", nargs="*", type=str, default=sorted(max_res.keys())) 41 | 42 | parser.add_argument( 43 | "--px", nargs="*", type=int, default=[16, 32, 64, 128, 256, 496] 44 | ) 45 | 46 | parser.add_argument("--names", nargs="*", type=str, default=["Scale-MAE", "SatMAE"]) 47 | 48 | parser.add_argument("--output", type=str, default="results.png") 49 | 50 | return parser 51 | 52 | 53 | def main(args): 54 | pd.DataFrame() 55 | 56 | # for each eval dataset 57 | ## for each method 58 | ## for each resolution 59 | if not os.path.exists("cached.pkl"): 60 | all_data = [] 61 | for i, rid in enumerate(args.runs): 62 | wandb_run = api.run(f"bair-climate-initiative/multiscale_mae/{rid}") 63 | name = args.names[i] 64 | print(name) 65 | for eval in args.evals: 66 | print(eval) 67 | wtable = api.artifact( 68 | f"bair-climate-initiative/multiscale_mae/run-{rid}-{eval.replace('-','')}_eval:latest" 69 | ).get(f"{eval}_eval") 70 | if wtable is None: 71 | import ipdb 72 | 73 | ipdb.set_trace() 74 | for datum in wtable.data: 75 | data = {name: val for val, name in zip(datum, wtable.columns)} 76 | if data["val_resolution"] in args.px: 77 | all_data.append( 78 | dict( 79 | result=data["acc"], 80 | px=data["val_resolution"], 81 | name=name, 82 | valset=eval, 83 | ) 84 | ) 85 | data = pd.DataFrame(all_data) 86 | pd.to_pickle(data, "cached.pkl") 87 | else: 88 | print("using cache") 89 | data = pickle.load(open("cached.pkl", "rb")) 90 | 91 | # generate the performance plots 92 | nvals = data.valset.nunique() 93 | ncols = 4 94 | nrows = (nvals + ncols - 1) // ncols 95 | fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=0, figsize=(21, 6)) 96 | ct = 0 97 | for ax in axs.reshape(-1): 98 | if ct >= nvals: 99 | break 100 | subdata = data[(data.valset == args.evals[ct]) & (data.result > -1)] 101 | subdata.px /= float(max_res[args.evals[ct]]) 102 | subdata = subdata.groupby(["name"]) 103 | 104 | legend = ct == 0 # nvals - 1 105 | subdata.plot( 106 | kind="line", 107 | x="px", 108 | y="result", 109 | ax=ax, 110 | legend=False, 111 | title=args.evals[ct].upper(), 112 | style=".-", 113 | ) 114 | if legend: 115 | 
handles, _ = ax.get_legend_handles_labels() 116 | ax.legend(handles[::-1], [x for x in subdata.groups.keys()][::-1]) 117 | ax.set_xticks( 118 | ticks=[0, 0.25, 0.5, 0.75, 1.0], labels=["0", "25%", "50%", "75%", "100%"] 119 | ) 120 | ax.yaxis.get_major_locator().set_params(integer=True) 121 | ax.set_ylim(res_lims[args.evals[ct]]) 122 | if ct % ncols == 0: 123 | ax.set_ylabel("KNN acc.") 124 | else: 125 | ax.set_ylabel("") 126 | if ct // ncols == nrows - 1: 127 | ax.set_xlabel("Relative GSD") 128 | else: 129 | ax.set_xlabel("") 130 | ct += 1 131 | plt.subplots_adjust(hspace=0.3) 132 | fig.savefig(args.output, bbox_inches="tight") 133 | 134 | 135 | if __name__ == "__main__": 136 | print("starting eval") 137 | main(get_args_parser().parse_args()) 138 | -------------------------------------------------------------------------------- /mae/util/crop.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | import torch 10 | from torchvision import transforms 11 | from torchvision.transforms import functional as F 12 | 13 | 14 | class RandomResizedCrop(transforms.RandomResizedCrop): 15 | """ 16 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 17 | This may lead to results different with torchvision's version. 18 | Following BYOL's TF code: 19 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 20 | """ 21 | 22 | @staticmethod 23 | def get_params(img, scale, ratio): 24 | width, height = F.get_image_size(img) 25 | area = height * width 26 | 27 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 28 | log_ratio = torch.log(torch.tensor(ratio)) 29 | aspect_ratio = torch.exp( 30 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 31 | ).item() 32 | 33 | w = int(round(math.sqrt(target_area * aspect_ratio))) 34 | h = int(round(math.sqrt(target_area / aspect_ratio))) 35 | 36 | w = min(w, width) 37 | h = min(h, height) 38 | 39 | i = torch.randint(0, height - h + 1, size=(1,)).item() 40 | j = torch.randint(0, width - w + 1, size=(1,)).item() 41 | 42 | return i, j, h, w 43 | -------------------------------------------------------------------------------- /mae/util/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # -------------------------------------------------------- 10 | 11 | import os 12 | 13 | import PIL 14 | from timm.data import create_transform 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from torchvision import datasets, transforms 17 | 18 | 19 | def build_dataset(is_train, args): 20 | transform = build_transform(is_train, args) 21 | 22 | root = os.path.join(args.data_path, "train" if is_train else "val") 23 | dataset = datasets.ImageFolder(root, transform=transform) 24 | 25 | print(dataset) 26 | 27 | return dataset 28 | 29 | 30 | def build_transform(is_train, args, config): 31 | mean = IMAGENET_DEFAULT_MEAN 32 | std = IMAGENET_DEFAULT_STD 33 | # train transform 34 | if is_train: 35 | # this should always dispatch to transforms_imagenet_train 36 | transform = create_transform( 37 | input_size=config["data"]["input_size"], 38 | is_training=True, 39 | color_jitter=args.color_jitter, 40 | auto_augment=args.aa, 41 | interpolation="bicubic", 42 | re_prob=args.reprob, 43 | re_mode=args.remode, 44 | re_count=args.recount, 45 | mean=mean, 46 | std=std, 47 | ) 48 | return transform 49 | 50 | # eval transform 51 | t = [] 52 | if config["data"]["input_size"] <= 224: 53 | crop_pct = 224 / 256 54 | else: 55 | crop_pct = 1.0 56 | size = int(config["data"]["input_size"] / crop_pct) 57 | t.append( 58 | transforms.Resize( 59 | size, interpolation=PIL.Image.BICUBIC 60 | ) # to maintain same ratio w.r.t. 224 images 61 | ) 62 | t.append(transforms.CenterCrop(config["data"]["input_size"])) 63 | 64 | t.append(transforms.ToTensor()) 65 | t.append(transforms.Normalize(mean, std)) 66 | return transforms.Compose(t) 67 | -------------------------------------------------------------------------------- /mae/util/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import random 7 | from typing import List 8 | 9 | import numpy as np 10 | import torch 11 | import torch.distributed as dist 12 | from classy_vision.generic.distributed_util import ( 13 | convert_to_distributed_tensor, 14 | convert_to_normal_tensor, 15 | is_distributed_training_run, 16 | ) 17 | 18 | 19 | def set_seed(seed: int = 0) -> None: 20 | """ 21 | Set random seed for PyTorch, Python random, NumPy 22 | Sets CUDA convolution algorithm to be deterministic (see https://pytorch.org/docs/stable/notes/randomness.html#avoiding-nondeterministic-algorithms) 23 | """ 24 | torch.manual_seed(seed) 25 | random.seed(seed) 26 | np.random.seed(seed) 27 | 28 | torch.backends.cudnn.deterministic = True 29 | 30 | 31 | class GatherLayer(torch.autograd.Function): 32 | """ 33 | Gather tensors from all workers with support for backward propagation: 34 | This implementation does not cut the gradients as torch.distributed.all_gather does. 
35 | """ 36 | 37 | @staticmethod 38 | def forward(ctx, x): 39 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 40 | dist.all_gather(output, x) 41 | return tuple(output) 42 | 43 | @staticmethod 44 | def backward(ctx, *grads): 45 | all_gradients = torch.stack(grads) 46 | dist.all_reduce(all_gradients) 47 | return all_gradients[dist.get_rank()] 48 | 49 | 50 | def gather_from_all(tensor: torch.Tensor) -> torch.Tensor: 51 | """ 52 | Similar to classy_vision.generic.distributed_util.gather_from_all 53 | except that it does not cut the gradients 54 | """ 55 | if tensor.ndim == 0: 56 | # 0 dim tensors cannot be gathered. so unsqueeze 57 | tensor = tensor.unsqueeze(0) 58 | 59 | if is_distributed_training_run(): 60 | tensor, orig_device = convert_to_distributed_tensor(tensor) 61 | gathered_tensors = GatherLayer.apply(tensor) 62 | gathered_tensors = [ 63 | convert_to_normal_tensor(_tensor, orig_device) 64 | for _tensor in gathered_tensors 65 | ] 66 | else: 67 | gathered_tensors = [tensor] 68 | gathered_tensor = torch.cat(gathered_tensors, 0) 69 | return gathered_tensor 70 | 71 | 72 | def all_gather_sizes(x: torch.Tensor) -> List[int]: 73 | """ 74 | Get the first dimension sizes of the the tensor to gather on each 75 | of the distributed workers 76 | """ 77 | dist_rank = torch.distributed.get_rank() 78 | world_size = torch.distributed.get_world_size() 79 | current_device = torch.device("cuda", torch.cuda.current_device()) 80 | sizes = torch.zeros(size=(world_size,), device=current_device, dtype=torch.int64) 81 | sizes[dist_rank] = x.shape[0] 82 | torch.distributed.all_reduce(sizes) 83 | return list(sizes.cpu().numpy()) 84 | 85 | 86 | def all_gather_heterogeneous(sizes: List[int], x: torch.Tensor) -> List[torch.Tensor]: 87 | """ 88 | Gather a list of heterogeenous tensors shape in the first 89 | dimension (different batch sizes) 90 | """ 91 | current_device = torch.device("cuda", torch.cuda.current_device()) 92 | shape = x.shape[1:] 93 | all_x = [ 94 | torch.zeros(size=(sizes[i], *shape), device=current_device, dtype=x.dtype) 95 | for i in range(torch.distributed.get_world_size()) 96 | ] 97 | torch.distributed.all_gather(all_x, x) 98 | return all_x 99 | -------------------------------------------------------------------------------- /mae/util/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # LARS optimizer, implementation from MoCo v3: 8 | # https://github.com/facebookresearch/moco-v3 9 | # -------------------------------------------------------- 10 | 11 | import torch 12 | 13 | 14 | class LARS(torch.optim.Optimizer): 15 | """ 16 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 
17 | """ 18 | 19 | def __init__( 20 | self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001 21 | ): 22 | defaults = dict( 23 | lr=lr, 24 | weight_decay=weight_decay, 25 | momentum=momentum, 26 | trust_coefficient=trust_coefficient, 27 | ) 28 | super().__init__(params, defaults) 29 | 30 | @torch.no_grad() 31 | def step(self): 32 | for g in self.param_groups: 33 | for p in g["params"]: 34 | dp = p.grad 35 | 36 | if dp is None: 37 | continue 38 | 39 | if p.ndim > 1: # if not normalization gamma/beta or bias 40 | dp = dp.add(p, alpha=g["weight_decay"]) 41 | param_norm = torch.norm(p) 42 | update_norm = torch.norm(dp) 43 | one = torch.ones_like(param_norm) 44 | q = torch.where( 45 | param_norm > 0.0, 46 | torch.where( 47 | update_norm > 0, 48 | (g["trust_coefficient"] * param_norm / update_norm), 49 | one, 50 | ), 51 | one, 52 | ) 53 | dp = dp.mul(q) 54 | 55 | param_state = self.state[p] 56 | if "mu" not in param_state: 57 | param_state["mu"] = torch.zeros_like(p) 58 | mu = param_state["mu"] 59 | mu.mul_(g["momentum"]).add_(dp) 60 | p.add_(mu, alpha=-g["lr"]) 61 | -------------------------------------------------------------------------------- /mae/util/lr_decay.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # ELECTRA https://github.com/google-research/electra 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import json 13 | 14 | 15 | def param_groups_lrd( 16 | model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=0.75 17 | ): 18 | """ 19 | Parameter groups for layer-wise lr decay 20 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58 21 | """ 22 | param_group_names = {} 23 | param_groups = {} 24 | 25 | num_layers = len(model.blocks) + 1 26 | 27 | layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1)) 28 | 29 | for n, p in model.named_parameters(): 30 | if not p.requires_grad: 31 | continue 32 | 33 | # no decay: all 1D parameters and model specific ones 34 | if p.ndim == 1 or n in no_weight_decay_list: 35 | g_decay = "no_decay" 36 | this_decay = 0.0 37 | else: 38 | g_decay = "decay" 39 | this_decay = weight_decay 40 | 41 | layer_id = get_layer_id_for_vit(n, num_layers) 42 | group_name = "layer_%d_%s" % (layer_id, g_decay) 43 | 44 | if group_name not in param_group_names: 45 | this_scale = layer_scales[layer_id] 46 | 47 | param_group_names[group_name] = { 48 | "lr_scale": this_scale, 49 | "weight_decay": this_decay, 50 | "params": [], 51 | } 52 | param_groups[group_name] = { 53 | "lr_scale": this_scale, 54 | "weight_decay": this_decay, 55 | "params": [], 56 | } 57 | 58 | param_group_names[group_name]["params"].append(n) 59 | param_groups[group_name]["params"].append(p) 60 | 61 | # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2)) 62 | 63 | return list(param_groups.values()) 64 | 65 | 66 | def get_layer_id_for_vit(name, num_layers): 67 | """ 68 | Assign a parameter with its layer id 69 | Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 70 | """ 71 | if name in ["cls_token", "pos_embed"]: 72 | return 0 73 | elif 
name.startswith("patch_embed"): 74 | return 0 75 | elif name.startswith("blocks"): 76 | return int(name.split(".")[1]) + 1 77 | else: 78 | return num_layers 79 | -------------------------------------------------------------------------------- /mae/util/lr_sched.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | 9 | 10 | def adjust_learning_rate(optimizer, epoch, args): 11 | """Decay the learning rate with half-cycle cosine after warmup""" 12 | if epoch < args.warmup_epochs: 13 | lr = args.lr * epoch / args.warmup_epochs 14 | else: 15 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * ( 16 | 1.0 17 | + math.cos( 18 | math.pi 19 | * (epoch - args.warmup_epochs) 20 | / (args.epochs - args.warmup_epochs) 21 | ) 22 | ) 23 | for param_group in optimizer.param_groups: 24 | if "lr_scale" in param_group: 25 | param_group["lr"] = lr * param_group["lr_scale"] 26 | else: 27 | param_group["lr"] = lr 28 | return lr 29 | -------------------------------------------------------------------------------- /mae/util/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # References: 8 | # DeiT: https://github.com/facebookresearch/deit 9 | # BEiT: https://github.com/microsoft/unilm/tree/master/beit 10 | # -------------------------------------------------------- 11 | 12 | import builtins 13 | import datetime 14 | import os 15 | import time 16 | from collections import defaultdict, deque 17 | from pathlib import Path 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch import inf 22 | 23 | 24 | class SmoothedValue: 25 | """Track a series of values and provide access to smoothed values over a 26 | window or the global series average. 27 | """ 28 | 29 | def __init__(self, window_size=20, fmt=None): 30 | if fmt is None: 31 | fmt = "{median:.4f} ({global_avg:.4f})" 32 | self.deque = deque(maxlen=window_size) 33 | self.total = 0.0 34 | self.count = 0 35 | self.fmt = fmt 36 | 37 | def update(self, value, n=1): 38 | self.deque.append(value) 39 | self.count += n 40 | self.total += value * n 41 | 42 | def synchronize_between_processes(self): 43 | """ 44 | Warning: does not synchronize the deque! 
45 | """ 46 | if not is_dist_avail_and_initialized(): 47 | return 48 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 49 | dist.barrier() 50 | dist.all_reduce(t) 51 | t = t.tolist() 52 | self.count = int(t[0]) 53 | self.total = t[1] 54 | 55 | @property 56 | def median(self): 57 | d = torch.tensor(list(self.deque)) 58 | return d.median().item() 59 | 60 | @property 61 | def avg(self): 62 | d = torch.tensor(list(self.deque), dtype=torch.float32) 63 | return d.mean().item() 64 | 65 | @property 66 | def global_avg(self): 67 | return self.total / self.count 68 | 69 | @property 70 | def max(self): 71 | return max(self.deque) 72 | 73 | @property 74 | def value(self): 75 | return self.deque[-1] 76 | 77 | def __str__(self): 78 | return self.fmt.format( 79 | median=self.median, 80 | avg=self.avg, 81 | global_avg=self.global_avg, 82 | max=self.max, 83 | value=self.value, 84 | ) 85 | 86 | 87 | class MetricLogger: 88 | def __init__(self, delimiter="\t"): 89 | self.meters = defaultdict(SmoothedValue) 90 | self.delimiter = delimiter 91 | 92 | def update(self, **kwargs): 93 | for k, v in kwargs.items(): 94 | if v is None: 95 | continue 96 | if isinstance(v, torch.Tensor): 97 | v = v.item() 98 | assert isinstance(v, (float, int)) 99 | self.meters[k].update(v) 100 | 101 | def __getattr__(self, attr): 102 | if attr in self.meters: 103 | return self.meters[attr] 104 | if attr in self.__dict__: 105 | return self.__dict__[attr] 106 | raise AttributeError( 107 | f"'{type(self).__name__}' object has no attribute '{attr}'" 108 | ) 109 | 110 | def __str__(self): 111 | loss_str = [] 112 | for name, meter in self.meters.items(): 113 | loss_str.append(f"{name}: {str(meter)}") 114 | return self.delimiter.join(loss_str) 115 | 116 | def synchronize_between_processes(self): 117 | for meter in self.meters.values(): 118 | meter.synchronize_between_processes() 119 | 120 | def add_meter(self, name, meter): 121 | self.meters[name] = meter 122 | 123 | def log_every(self, iterable, print_freq, header=None): 124 | i = 0 125 | if not header: 126 | header = "" 127 | start_time = time.time() 128 | end = time.time() 129 | iter_time = SmoothedValue(fmt="{avg:.4f}") 130 | data_time = SmoothedValue(fmt="{avg:.4f}") 131 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 132 | log_msg = [ 133 | header, 134 | "[{0" + space_fmt + "}/{1}]", 135 | "eta: {eta}", 136 | "{meters}", 137 | "time: {time}", 138 | "data: {data}", 139 | ] 140 | if torch.cuda.is_available(): 141 | log_msg.append("max mem: {memory:.0f}") 142 | log_msg = self.delimiter.join(log_msg) 143 | MB = 1024.0 * 1024.0 144 | for obj in iterable: 145 | data_time.update(time.time() - end) 146 | yield obj 147 | iter_time.update(time.time() - end) 148 | if i % print_freq == 0 or i == len(iterable) - 1: 149 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 150 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 151 | if torch.cuda.is_available(): 152 | print( 153 | log_msg.format( 154 | i, 155 | len(iterable), 156 | eta=eta_string, 157 | meters=str(self), 158 | time=str(iter_time), 159 | data=str(data_time), 160 | memory=torch.cuda.max_memory_allocated() / MB, 161 | ) 162 | ) 163 | else: 164 | print( 165 | log_msg.format( 166 | i, 167 | len(iterable), 168 | eta=eta_string, 169 | meters=str(self), 170 | time=str(iter_time), 171 | data=str(data_time), 172 | ) 173 | ) 174 | i += 1 175 | end = time.time() 176 | total_time = time.time() - start_time 177 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 178 | print( 
179 | "{} Total time: {} ({:.4f} s / it)".format( 180 | header, total_time_str, total_time / len(iterable) 181 | ) 182 | ) 183 | 184 | 185 | setup = False 186 | 187 | 188 | def setup_for_distributed(is_master): 189 | """ 190 | This function disables printing when not in master process 191 | """ 192 | global setup 193 | if not setup: 194 | builtin_print = builtins.print 195 | 196 | def print(*args, **kwargs): 197 | force = kwargs.pop("force", False) 198 | force = force or (get_world_size() > 8) 199 | if is_master or force: 200 | now = datetime.datetime.now().time() 201 | builtin_print(f"[{now}] ", end="") # print with time stamp 202 | builtin_print(*args, **kwargs) 203 | 204 | builtins.print = print 205 | setup = True 206 | 207 | 208 | def is_dist_avail_and_initialized(): 209 | if not dist.is_available(): 210 | return False 211 | if not dist.is_initialized(): 212 | return False 213 | return True 214 | 215 | 216 | def get_world_size(): 217 | if not is_dist_avail_and_initialized(): 218 | return 1 219 | return dist.get_world_size() 220 | 221 | 222 | def get_rank(): 223 | if not is_dist_avail_and_initialized(): 224 | return 0 225 | return dist.get_rank() 226 | 227 | 228 | def is_main_process(): 229 | return get_rank() == 0 230 | 231 | 232 | def save_on_master(*args, **kwargs): 233 | if is_main_process(): 234 | torch.save(*args, **kwargs) 235 | 236 | 237 | def init_distributed_mode(args): 238 | if args.dist_on_itp: 239 | args.rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 240 | args.world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) 241 | args.gpu = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) 242 | args.dist_url = "tcp://{}:{}".format( 243 | os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"] 244 | ) 245 | os.environ["LOCAL_RANK"] = str(args.gpu) 246 | os.environ["RANK"] = str(args.rank) 247 | os.environ["WORLD_SIZE"] = str(args.world_size) 248 | # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"] 249 | elif "RANK" in os.environ and "WORLD_SIZE" in os.environ: 250 | args.rank = int(os.environ["RANK"]) 251 | args.world_size = int(os.environ["WORLD_SIZE"]) 252 | args.gpu = int(os.environ["LOCAL_RANK"]) 253 | elif "SLURM_PROCID" in os.environ: 254 | args.rank = int(os.environ["SLURM_PROCID"]) 255 | args.gpu = args.rank % torch.cuda.device_count() 256 | else: 257 | print("Not using distributed mode") 258 | # setup_for_distributed(is_master=True) # hack 259 | args.distributed = False 260 | return 261 | 262 | args.distributed = True 263 | 264 | torch.cuda.set_device(args.gpu) 265 | args.dist_backend = "nccl" 266 | print( 267 | "| distributed init (rank {}): {}, gpu {}".format( 268 | args.rank, args.dist_url, args.gpu 269 | ), 270 | flush=True, 271 | ) 272 | if not torch.distributed.is_initialized(): 273 | torch.distributed.init_process_group( 274 | backend=args.dist_backend, 275 | init_method=args.dist_url, 276 | world_size=args.world_size, 277 | rank=args.rank, 278 | ) 279 | torch.distributed.barrier() 280 | setup_for_distributed(args.rank == 0) 281 | 282 | 283 | class NativeScalerWithGradNormCount: 284 | state_dict_key = "amp_scaler" 285 | 286 | def __init__(self): 287 | self._scaler = torch.cuda.amp.GradScaler() 288 | 289 | def __call__( 290 | self, 291 | loss, 292 | optimizer, 293 | clip_grad=None, 294 | parameters=None, 295 | create_graph=False, 296 | update_grad=True, 297 | ): 298 | self._scaler.scale(loss).backward(create_graph=create_graph) 299 | if update_grad: 300 | if clip_grad is not None: 301 | assert parameters is not None 302 | self._scaler.unscale_( 303 | 
optimizer 304 | ) # unscale the gradients of optimizer's assigned params in-place 305 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 306 | else: 307 | self._scaler.unscale_(optimizer) 308 | norm = get_grad_norm_(parameters) 309 | self._scaler.step(optimizer) 310 | self._scaler.update() 311 | else: 312 | norm = None 313 | return norm 314 | 315 | def state_dict(self): 316 | return self._scaler.state_dict() 317 | 318 | def load_state_dict(self, state_dict): 319 | self._scaler.load_state_dict(state_dict) 320 | 321 | 322 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 323 | if isinstance(parameters, torch.Tensor): 324 | parameters = [parameters] 325 | parameters = [p for p in parameters if p.grad is not None] 326 | norm_type = float(norm_type) 327 | if len(parameters) == 0: 328 | return torch.tensor(0.0) 329 | device = parameters[0].grad.device 330 | if norm_type == inf: 331 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 332 | else: 333 | total_norm = torch.norm( 334 | torch.stack( 335 | [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters] 336 | ), 337 | norm_type, 338 | ) 339 | return total_norm 340 | 341 | 342 | def save_model( 343 | args, epoch, model, model_without_ddp, optimizer, loss_scaler, latest=False 344 | ): 345 | output_dir = Path(args.output_dir) 346 | if latest: 347 | epoch_name = "latest" 348 | else: 349 | epoch_name = str(epoch) 350 | if loss_scaler is not None: 351 | checkpoint_paths = [output_dir / ("checkpoint-%s.pth" % epoch_name)] 352 | for checkpoint_path in checkpoint_paths: 353 | to_save = { 354 | "model": model_without_ddp.state_dict(), 355 | "optimizer": optimizer.state_dict(), 356 | "epoch": epoch, 357 | "scaler": loss_scaler.state_dict(), 358 | "args": args, 359 | } 360 | 361 | save_on_master(to_save, checkpoint_path) 362 | else: 363 | client_state = {"epoch": epoch} 364 | model.save_checkpoint( 365 | save_dir=args.output_dir, 366 | tag="checkpoint-%s" % epoch_name, 367 | client_state=client_state, 368 | ) 369 | 370 | 371 | def load_model(args, model_without_ddp, optimizer, loss_scaler, strict=True): 372 | if args.resume: 373 | if args.resume.startswith("https"): 374 | checkpoint = torch.hub.load_state_dict_from_url( 375 | args.resume, map_location="cpu", check_hash=True 376 | ) 377 | else: 378 | checkpoint = torch.load(args.resume, map_location="cpu") 379 | model_without_ddp.load_state_dict(checkpoint["model"], strict=False) 380 | print("Resume checkpoint %s" % args.resume) 381 | if ( 382 | "optimizer" in checkpoint 383 | and "epoch" in checkpoint 384 | and not (hasattr(args, "eval") and args.eval) 385 | ): 386 | optimizer.load_state_dict(checkpoint["optimizer"]) 387 | if args.restart: 388 | args.start_epoch = 0 389 | else: 390 | args.start_epoch = checkpoint["epoch"] + 1 391 | if "scaler" in checkpoint: 392 | loss_scaler.load_state_dict(checkpoint["scaler"]) 393 | print("With optim & sched!") 394 | if "wandb_id" in checkpoint.get("args", []) and not args.no_autoresume: 395 | args.wandb_id = checkpoint["args"].wandb_id 396 | 397 | 398 | def all_reduce_mean(x): 399 | world_size = get_world_size() 400 | if world_size > 1: 401 | x_reduce = torch.tensor(x).cuda() 402 | dist.all_reduce(x_reduce) 403 | x_reduce /= world_size 404 | return x_reduce.item() 405 | else: 406 | return x 407 | -------------------------------------------------------------------------------- /mae/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | # 
Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | import torch 12 | 13 | 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_with_resolution( 39 | embed_dim, grid_size, res, cls_token=False, device="cpu" 40 | ): 41 | """ 42 | grid_size: int of the grid height and width 43 | res: array of size n, representing the resolution of a pixel (say, in meters), 44 | return: 45 | pos_embed: [n,grid_size*grid_size, embed_dim] or [n,1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 46 | """ 47 | # res = torch.FloatTensor(res).to(device) 48 | res = res.to(device) 49 | grid_h = torch.arange(grid_size, dtype=torch.float32, device=device) 50 | grid_w = torch.arange(grid_size, dtype=torch.float32, device=device) 51 | grid = torch.meshgrid( 52 | grid_w, grid_h, indexing="xy" 53 | ) # here h goes first,direction reversed for numpy 54 | grid = torch.stack(grid, dim=0) # 2 x h x w 55 | 56 | # grid = grid.reshape([2, 1, grid_size, grid_size]) 57 | grid = torch.einsum("chw,n->cnhw", grid, res) # 2 x n x h x w 58 | _, n, h, w = grid.shape 59 | pos_embed = get_2d_sincos_pos_embed_from_grid_torch( 60 | embed_dim, grid 61 | ) # # (nxH*W, D/2) 62 | pos_embed = pos_embed.reshape(n, h * w, embed_dim) 63 | if cls_token: 64 | pos_embed = torch.cat( 65 | [ 66 | torch.zeros( 67 | [n, 1, embed_dim], dtype=torch.float32, device=pos_embed.device 68 | ), 69 | pos_embed, 70 | ], 71 | dim=1, 72 | ) 73 | return pos_embed 74 | 75 | 76 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 77 | assert embed_dim % 2 == 0 78 | 79 | # use half of dimensions to encode grid_h 80 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 81 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 82 | 83 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 84 | return emb 85 | 86 | 87 | def get_2d_sincos_pos_embed_from_grid_torch(embed_dim, grid): 88 | assert embed_dim % 2 == 0 89 | 90 | # use half of dimensions to encode grid_h 91 | emb_h = get_1d_sincos_pos_embed_from_grid_torch( 92 | embed_dim // 2, grid[0] 93 | ) # (H*W, D/2) 94 | emb_w = get_1d_sincos_pos_embed_from_grid_torch( 95 | 
embed_dim // 2, grid[1] 96 | ) # (H*W, D/2) 97 | 98 | emb = torch.cat([emb_h, emb_w], dim=1) # (H*W, D) 99 | return emb 100 | 101 | 102 | def get_1d_sincos_pos_embed_from_grid_torch(embed_dim, pos): 103 | """ 104 | embed_dim: output dimension for each position 105 | pos: a list of positions to be encoded: size (M,) 106 | out: (M, D) 107 | """ 108 | assert embed_dim % 2 == 0 109 | old_shape = pos 110 | omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device) 111 | omega /= embed_dim / 2.0 112 | omega = 1.0 / 10000**omega # (D/2,) 113 | 114 | pos = pos.reshape(-1) # (M,) 115 | out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product 116 | 117 | emb_sin = torch.sin(out) # (M, D/2) 118 | emb_cos = torch.cos(out) # (M, D/2) 119 | 120 | emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) 121 | return emb 122 | 123 | 124 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 125 | """ 126 | embed_dim: output dimension for each position 127 | pos: a list of positions to be encoded: size (M,) 128 | out: (M, D) 129 | """ 130 | assert embed_dim % 2 == 0 131 | omega = np.arange(embed_dim // 2, dtype=np.float32) 132 | omega /= embed_dim / 2.0 133 | omega = 1.0 / 10000**omega # (D/2,) 134 | 135 | pos = pos.reshape(-1) # (M,) 136 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 137 | 138 | emb_sin = np.sin(out) # (M, D/2) 139 | emb_cos = np.cos(out) # (M, D/2) 140 | 141 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 142 | return emb 143 | 144 | 145 | # -------------------------------------------------------- 146 | # Interpolate position embeddings for high-resolution 147 | # References: 148 | # DeiT: https://github.com/facebookresearch/deit 149 | # -------------------------------------------------------- 150 | def interpolate_pos_embed(model, checkpoint_model): 151 | if "pos_embed" in checkpoint_model: 152 | pos_embed_checkpoint = checkpoint_model["pos_embed"] 153 | embedding_size = pos_embed_checkpoint.shape[-1] 154 | num_patches = model.patch_embed.num_patches 155 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 156 | # height (== width) for the checkpoint position embedding 157 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 158 | # height (== width) for the new position embedding 159 | new_size = int(num_patches**0.5) 160 | # class_token and dist_token are kept unchanged 161 | if orig_size != new_size: 162 | print( 163 | "Position interpolate from %dx%d to %dx%d" 164 | % (orig_size, orig_size, new_size, new_size) 165 | ) 166 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 167 | # only the position tokens are interpolated 168 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 169 | pos_tokens = pos_tokens.reshape( 170 | -1, orig_size, orig_size, embedding_size 171 | ).permute(0, 3, 1, 2) 172 | pos_tokens = torch.nn.functional.interpolate( 173 | pos_tokens, 174 | size=(new_size, new_size), 175 | mode="bicubic", 176 | align_corners=False, 177 | ) 178 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 179 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 180 | checkpoint_model["pos_embed"] = new_pos_embed 181 | -------------------------------------------------------------------------------- /mae/util/resolution_sched.py: -------------------------------------------------------------------------------- 1 | from lib.scheduler import ConstantResolutionScheduler, RandomResolutionScheduler 2 | import numpy as np 3 | import re 4 | 5 | 6 | def 
get_output_size_scheduler(args): 7 | if args.fixed_output_size_min or args.fixed_output_size_max: 8 | assert ( 9 | args.fixed_output_size_min > 0 10 | and args.fixed_output_size_max > 0 11 | and args.fixed_output_size_max >= args.fixed_output_size_min 12 | ) 13 | output_size_scheduler = RandomResolutionScheduler( 14 | target_size=np.arange( 15 | args.fixed_output_size_min, args.fixed_output_size_max + 1, 16 16 | ) 17 | ) 18 | else: 19 | output_size_scheduler = ConstantResolutionScheduler(target_size=0) 20 | 21 | return output_size_scheduler 22 | 23 | 24 | def get_target_size_scheduler(args): 25 | if args.target_size_scheduler == "random": 26 | target_size_scheduler = RandomResolutionScheduler(args.target_size) 27 | elif args.target_size_scheduler == "constant": 28 | target_size_scheduler = ConstantResolutionScheduler(args.target_size) 29 | else: 30 | match = re.compile("random:([0-9])").findall(args.target_size_scheduler) 31 | if match: 32 | target_size_scheduler = RandomResolutionScheduler( 33 | args.target_size, int(match[0]) 34 | ) 35 | else: 36 | raise NotImplementedError 37 | 38 | return target_size_scheduler 39 | 40 | 41 | def get_source_size_scheduler(args): 42 | if args.source_size_scheduler == "random": 43 | source_size_scheduler = RandomResolutionScheduler(args.source_size) 44 | elif args.source_size_scheduler == "constant": 45 | source_size_scheduler = ConstantResolutionScheduler(args.source_size) 46 | else: 47 | raise NotImplementedError 48 | 49 | return source_size_scheduler 50 | -------------------------------------------------------------------------------- /mae/wandb_log.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import util.misc as misc 5 | import wandb 6 | 7 | # def equalize(x): 8 | # x = (x - x.min()) / (x.max()-x.min()+1e-6) * 255 9 | # x = x.astype(np.uint8) 10 | # r_image, g_image, b_image = cv2.split(x) 11 | # r_image_eq = cv2.equalizeHist(r_image) 12 | # g_image_eq = cv2.equalizeHist(g_image) 13 | # b_image_eq = cv2.equalizeHist(b_image) 14 | 15 | # image_eq = cv2.merge((r_image_eq, g_image_eq, b_image_eq)) 16 | # return image_eq 17 | 18 | 19 | class WANDB_LOG_IMG_CONFIG: 20 | mean = np.zeros(3) 21 | std = np.ones(3) 22 | factor = 1.0 23 | 24 | 25 | def equalize(x): 26 | x = x * WANDB_LOG_IMG_CONFIG.std.reshape( 27 | 1, 1, 3 28 | ) + WANDB_LOG_IMG_CONFIG.mean.reshape(1, 1, 3) 29 | if x.max() > 2.0: 30 | x /= WANDB_LOG_IMG_CONFIG.factor 31 | return x 32 | 33 | 34 | def wandb_dump_input_output(x, ys, epoch=0, texts=""): 35 | """ 36 | x: H X W X C 37 | y: H X W X C 38 | """ 39 | if misc.is_main_process(): 40 | n_imgs = 1 + len(ys) 41 | x = x.numpy() 42 | ys = [y.numpy().astype(float) for y in ys] 43 | ys = [equalize(y) for y in ys] 44 | x = equalize(x) 45 | fig, axes = plt.subplots(1, n_imgs, figsize=(5 * n_imgs, 5)) 46 | if texts: 47 | fig.suptitle(texts) 48 | axes[0].imshow(x) 49 | axes[0].title.set_text(f"({x.shape[0]}, {x.shape[1]})") 50 | for idx, y in enumerate(ys): 51 | axes[1 + idx].imshow(y) 52 | axes[1 + idx].title.set_text(f"({y.shape[0]}, {y.shape[1]})") 53 | wandb.log({"vis": wandb.Image(fig), "epoch": epoch}) 54 | plt.close(fig) 55 | 56 | 57 | def wandb_dump_images(imgs, name="vis", epoch=0): 58 | """ 59 | x: H X W X C 60 | y: H X W X C 61 | """ 62 | if misc.is_main_process(): 63 | n_imgs = len(imgs) 64 | fig, axes = plt.subplots(1, n_imgs, figsize=(5 * n_imgs, 5)) 65 | for idx, img in enumerate(imgs): 66 | axes[idx].imshow(img) 67 | 
wandb.log({name: wandb.Image(fig), "epoch": epoch}) 68 | plt.close(fig) 69 | 70 | 71 | def compare_pos_embedding(posa, posb, ns=[0]): 72 | """ 73 | posa,posa: N X (L+1) X d_emb 74 | """ 75 | n, l1, d = posa.shape 76 | _, l2, _ = posb.shape 77 | dim1 = int((l1 - 1) ** 0.5) 78 | dim2 = int((l2 - 1) ** 0.5) 79 | idx = [0, d // 4, d // 2, d // 4 * 3, d - 1] 80 | for j in ns: 81 | imgs = [] 82 | for i in idx: 83 | a = posa[j, 1:, i].reshape(dim1, dim1).cpu().numpy() 84 | b = posb[j, 1:, i].reshape(dim2, dim2).cpu().numpy() 85 | imgs.append(a) 86 | imgs.append(b) 87 | wandb_dump_images(imgs) 88 | 89 | 90 | def wandb_log_metadata(metadata): 91 | if misc.is_main_process(): 92 | payload = {} 93 | if "zero_ratio" in metadata: 94 | zero_ratio = metadata.get("zero_ratio") 95 | payload.update( 96 | dict( 97 | zero_ratio_mean=zero_ratio.mean(), 98 | zero_ratio_max=zero_ratio.max(), 99 | zero_ratio_min=zero_ratio.min(), 100 | zero_ratio_std=zero_ratio.std(), 101 | ) 102 | ) 103 | 104 | wandb.log(payload) 105 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0,<64"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | target-version = ["py38", "py39", "py310"] 7 | color = true 8 | skip_magic_trailing_comma = true 9 | 10 | [tool.isort] 11 | profile = "black" 12 | known_first_party = ["docs", "test", "src"] 13 | skip_gitignore = true 14 | color_output = true -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | # https://setuptools.readthedocs.io/en/latest/userguide/declarative_config.html 2 | [metadata] 3 | name=scalemae 4 | version = 0.1.0 5 | author = Colorado J Reed, Ritwik Gupta, Shufan Li 6 | author_email = colorado.j.reed@gmail.com 7 | description = Scale-MAE: A Scale-Aware Masked Autoencoder for Multiscale Geospatial Representation Learning 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | url = https://github.com/bair-climate-initiative/scale-mae 11 | classifiers = 12 | Development Status :: 3 - Alpha 13 | Intended Audience :: Science/Research 14 | Programming Language :: Python :: 3 15 | Programming Language :: Python :: 3.8 16 | Programming Language :: Python :: 3.9 17 | Programming Language :: Python :: 3.10 18 | Operating System :: OS Independent 19 | Topic :: Scientific/Engineering :: GIS 20 | keywords = remote sensing, representation learning 21 | 22 | [options] 23 | install_requires = 24 | # BSD 3-Clause License 25 | # numpy, version 1.21 needed for datetime64[ns] support 26 | numpy >=1.21.0,<2.0.0 27 | timm == 0.6.12 28 | wandb 29 | 30 | python_requires = ~= 3.9 31 | 32 | packages = find: 33 | 34 | [options.packages.find] 35 | include = mae* 36 | 37 | [options.extras_require] 38 | dev = 39 | # pre-commit, version 2.15 needed for python 3.10 support 40 | pre-commit>=2.15.0,<3.0.0 41 | # black 21.8+ required for Jupyter support 42 | black[jupyter]>=21.8,<23 43 | # flake8 3.8+ depends on pyflakes 2.2+, which fixes a bug with mypy error code ignores: 44 | # https://github.com/PyCQA/pyflakes/pull/455 45 | flake8>=3.8,<5 46 | # isort 5.8+ required for extend_skip option 47 | isort[colors]>=5.8,<6 48 | # pydocstyle 6.1+ required for pyproject.toml support 49 | pydocstyle[toml]>=6.1,<7 50 | # pyupgrade 1.24+ required for --py37-plus flag 
51 | pyupgrade>=1.24,<3 52 | # pytest 6.1.2+ required by nbmake 53 | pytest>=6.1.2,<8 54 | # pytest-cov 2.4+ required for pytest --cov flags 55 | pytest-cov>=2.4,<4 56 | # for mypy 57 | types-requests>=2.28.9,<2.30.0 58 | types-python-dateutil>=2.8.19,<2.9.0 59 | 60 | 61 | [flake8] 62 | max-line-length = 100 63 | extend-ignore = 64 | # See https://github.com/PyCQA/pycodestyle/issues/373 65 | E203, 66 | exclude = 67 | # Source 68 | data/, 69 | images 70 | logs/, 71 | output/, 72 | 73 | # Python 74 | build/, 75 | dist/, 76 | .cache/, 77 | .mypy_cache/, 78 | .pytest_cache/, 79 | __pycache__/, 80 | *.egg-info/, 81 | 82 | # Git 83 | .git/, 84 | .github/ 85 | --------------------------------------------------------------------------------
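
The GSD-conditioned positional embedding in `mae/util/pos_embed.py` (`get_2d_sincos_pos_embed_with_resolution`) is the scale-aware piece of the encoder: the same token grid receives a different sin-cos code depending on the per-image resolution passed through `res`. Below is a minimal standalone sketch of calling it; it is not part of the repository, the import path assumes the `mae/` directory is on `PYTHONPATH` (matching the repo's own `util.*` imports), and the dimensions and resolutions are illustrative only.

```python
# Hedged sketch (not repository code): exercise the GSD-conditioned positional
# embedding from mae/util/pos_embed.py. Assumes mae/ is on PYTHONPATH.
import torch

from util.pos_embed import get_2d_sincos_pos_embed_with_resolution

embed_dim, grid_size = 768, 14          # e.g. ViT-B width and a 14x14 patch grid
res = torch.tensor([0.3, 1.0, 10.0])    # hypothetical metres/pixel for three images

# One embedding per input resolution; coarser GSD stretches the grid coordinates,
# so lower-resolution images receive lower-frequency position codes.
pos = get_2d_sincos_pos_embed_with_resolution(embed_dim, grid_size, res, cls_token=True)
print(pos.shape)  # torch.Size([3, 197, 768]) = [n, 1 + grid_size**2, embed_dim]
```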
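
The training utilities in `mae/util/misc.py` are driven from `engine_pretrain.py` / `engine_finetune.py`, which are not reproduced above, so the sketch below shows one hedged, single-process way `MetricLogger` and `NativeScalerWithGradNormCount` fit together. The toy model, optimizer, and fake loader are invented for illustration, and a CUDA device is assumed for the AMP scaler; this is not the repository's actual engine code.

```python
# Hedged single-process sketch (toy model/data, CUDA assumed) of the utilities
# defined in mae/util/misc.py; not the repository's training engine.
import torch

from util.misc import MetricLogger, NativeScalerWithGradNormCount, all_reduce_mean

model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_scaler = NativeScalerWithGradNormCount()

# A fake "data loader": 100 batches of random regression data.
loader = [(torch.randn(8, 16), torch.randn(8, 1)) for _ in range(100)]

logger = MetricLogger(delimiter="  ")
for x, y in logger.log_every(loader, print_freq=20, header="Epoch [0]"):
    x, y = x.cuda(), y.cuda()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(x), y)

    optimizer.zero_grad()
    # Scales the loss, backprops, clips gradients, then steps the optimizer.
    grad_norm = loss_scaler(
        loss, optimizer, clip_grad=1.0, parameters=model.parameters(), update_grad=True
    )
    logger.update(loss=all_reduce_mean(loss.item()), grad_norm=grad_norm)

logger.synchronize_between_processes()  # no-op unless torch.distributed is initialized
print("Averaged stats:", logger)
```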