├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── config.py ├── configs ├── swin_base__100ep │ ├── simmim_finetune__swin_base__img192_window6__100ep.yaml │ ├── simmim_finetune__swin_base__img224_window7__100ep.yaml │ └── simmim_pretrain__swin_base__img192_window6__100ep.yaml ├── swin_base__800ep │ ├── simmim_finetune__swin_base__img224_window7__800ep.yaml │ └── simmim_pretrain__swin_base__img192_window6__800ep.yaml ├── swin_large__800ep │ ├── simmim_finetune__swin_large__img224_window14__800ep.yaml │ └── simmim_pretrain__swin_large__img192_window12__800ep.yaml └── vit_base__800ep │ ├── simmim_finetune__vit_base__img224__800ep.yaml │ └── simmim_pretrain__vit_base__img224__800ep.yaml ├── data ├── __init__.py ├── data_finetune.py └── data_simmim.py ├── figures └── teaser.jpg ├── logger.py ├── lr_scheduler.py ├── main_finetune.py ├── main_simmim.py ├── models ├── __init__.py ├── build.py ├── simmim.py ├── swin_transformer.py └── vision_transformer.py ├── optimizer.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimMIM 2 | 3 | By [Zhenda Xie](https://zdaxie.github.io)\*, [Zheng Zhang](https://stupidzz.github.io/)\*, [Yue Cao](http://yue-cao.me)\*, [Yutong Lin](https://github.com/impiga), [Jianmin Bao](https://jianminbao.github.io/), [Zhuliang Yao](https://github.com/Howal), [Qi Dai](https://www.microsoft.com/en-us/research/people/qid/) and [Han Hu](https://ancientmooner.github.io/)\*. 4 | 5 | This repo is the official implementation of ["SimMIM: A Simple Framework for Masked Image Modeling"](https://arxiv.org/abs/2111.09886). 6 | 7 | ## Updates 8 | 9 | ***09/29/2022*** 10 | 11 | SimMIM was merged to [Swin Transformer repo on GitHub](https://github.com/microsoft/Swin-Transformer). 12 | 13 | ***03/02/2022*** 14 | 15 | SimMIM got accepted by CVPR 2022. 
SimMIM was used in ["Swin Transformer V2"](https://github.com/microsoft/Swin-Transformer) to alleviate the data hungry problem for large-scale vision model training. 16 | 17 | ***12/09/2021*** 18 | 19 | Initial commits: 20 | 21 | 1. Pre-trained and fine-tuned models on ImageNet-1K (`Swin Base`, `Swin Large`, and `ViT Base`) are provided. 22 | 2. The supported code for ImageNet-1K pre-training and fine-tuneing is provided. 23 | 24 | ## Introduction 25 | 26 | **SimMIM** is initially described in [arxiv](https://arxiv.org/abs/2111.09886), which serves as a 27 | simple framework for masked image modeling. From systematically study, we find that simple designs of each component have revealed very strong representation learning performance: 1) random masking of the input image with a moderately large masked patch size (e.g., 32) makes a strong pre-text task; 2) predicting raw pixels of RGB values by direct regression performs no worse than the patch classification approaches with complex designs; 3) the prediction head can be as light as a linear layer, with no worse performance than heavier ones. 28 | 29 |
30 | <div align="center"> <img src="figures/teaser.jpg"/> </div>
32 | 33 | ## Main Results on ImageNet 34 | 35 | ### Swin Transformer 36 | 37 | **ImageNet-1K Pre-trained and Fine-tuned Models** 38 | 39 | | name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model | 40 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 41 | | Swin-Base | 100 | 192x192 | 192x192 | 82.8 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1RsgHfjB4B1ZYblXEQVT-FPX3WSvBrxcs/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml) | 42 | | Swin-Base | 100 | 192x192 | 224x224 | 83.5 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1mb43BkW56F5smwiX-g7QUUD7f1Rftq8u/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml) | 43 | | Swin-Base | 800 | 192x192 | 224x224 | 84.0 | [google](https://drive.google.com/file/d/15zENvGjHlM71uKQ3d2FbljWPubtrPtjl/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml) | [google](https://drive.google.com/file/d/1xEKyfMTsdh6TfnYhk5vbw0Yz7a-viZ0w/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml) | 44 | | Swin-Large | 800 | 192x192 | 224x224 | 85.4 | [google](https://drive.google.com/file/d/1qDxrTl2YUDB0505_4QrU5LU2R1kKmcBP/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml) | [google](https://drive.google.com/file/d/1mf0ZpXttEvFsH87Www4oQ-t8Kwr0x485/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml) | 45 | | SwinV2-Huge | 800 | 192x192 | 224x224 | 85.7 | / | / | 46 | | SwinV2-Huge | 800 | 192x192 | 512x512 | 87.1 | / | / | 47 | 48 | ### Vision Transformer 49 | 50 | **ImageNet-1K Pre-trained and Fine-tuned Models** 51 | 52 | | name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model | 53 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 54 | | ViT-Base | 800 | 224x224 | 224x224 | 83.8 | [google](https://drive.google.com/file/d/1dJn6GYkwMIcoP3zqOEyW1_iQfpBi8UOw/view?usp=sharing)/[config](configs/vit_base__800ep/simmim_pretrain__vit_base__img224__800ep.yaml) | [google](https://drive.google.com/file/d/1fKgDYd0tRgyHyTnyB1CleYxjo0Gn5tEB/view?usp=sharing)/[config](configs/vit_base__800ep/simmim_finetune__vit_base__img224__800ep.yaml) | 55 | 56 | ## Citing SimMIM 57 | 58 | ``` 59 | @inproceedings{xie2021simmim, 60 | title={SimMIM: A Simple Framework for Masked Image Modeling}, 61 | author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han}, 62 | booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)}, 63 | year={2022} 64 | } 65 | ``` 66 | 67 | ## Getting Started 68 | 69 | ### Installation 70 | 71 | - Install `CUDA 11.3` with `cuDNN 8` following the official installation guide of [CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [cuDNN](https://developer.nvidia.com/rdp/cudnn-archive). 
72 | 73 | - Setup conda environment: 74 | ```bash 75 | # Create environment 76 | conda create -n SimMIM python=3.8 -y 77 | conda activate SimMIM 78 | 79 | # Install requirements 80 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -y 81 | 82 | # Install apex 83 | git clone https://github.com/NVIDIA/apex 84 | cd apex 85 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 86 | cd .. 87 | 88 | # Clone SimMIM 89 | git clone https://github.com/microsoft/SimMIM 90 | cd SimMIM 91 | 92 | # Install other requirements 93 | pip install -r requirements.txt 94 | ``` 95 | 96 | ### Evaluating provided models 97 | 98 | To evaluate a provided model on the ImageNet validation set, run: 99 | ```bash 100 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_finetune.py \ 101 | --eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 102 | ``` 103 | 104 | For example, to evaluate the `Swin Base` model on a single GPU, run: 105 | ```bash 106 | python -m torch.distributed.launch --nproc_per_node 1 main_finetune.py \ 107 | --eval --cfg configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path> 108 | ``` 109 | 110 | ### Pre-training with SimMIM 111 | To pre-train models with `SimMIM`, run: 112 | ```bash 113 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim.py \ 114 | --cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 115 | ``` 116 | 117 | For example, to pre-train `Swin Base` for 800 epochs on one DGX-2 server, run: 118 | ```bash 119 | python -m torch.distributed.launch --nproc_per_node 16 main_simmim.py \ 120 | --cfg configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>] 121 | ``` 122 | 123 | ### Fine-tuning pre-trained models 124 | To fine-tune models pre-trained by `SimMIM`, run: 125 | ```bash 126 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_finetune.py \ 127 | --cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 128 | ``` 129 | 130 | For example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run: 131 | ```bash 132 | python -m torch.distributed.launch --nproc_per_node 16 main_finetune.py \ 133 | --cfg configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>] 134 | ``` 135 | 136 | ## Contributing 137 | 138 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 139 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 140 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 141 | 142 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 143 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 144 | provided by the bot. You will only need to do this once across all repos using our CLA. 145 | 146 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 147 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 148 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
149 | 150 | ## Trademarks 151 | 152 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 153 | trademarks or logos is subject to and must follow 154 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 155 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 156 | Any use of third-party trademarks or logos are subject to those third-party's policies. 157 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 
36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | 13 | _C = CN() 14 | 15 | # Base config files 16 | _C.BASE = [''] 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Data settings 20 | # ----------------------------------------------------------------------------- 21 | _C.DATA = CN() 22 | # Batch size for a single GPU, could be overwritten by command line argument 23 | _C.DATA.BATCH_SIZE = 128 24 | # Path to dataset, could be overwritten by command line argument 25 | _C.DATA.DATA_PATH = '' 26 | # Dataset name 27 | _C.DATA.DATASET = 'imagenet' 28 | # Input image size 29 | _C.DATA.IMG_SIZE = 224 30 | # Interpolation to resize image (random, bilinear, bicubic) 31 | _C.DATA.INTERPOLATION = 'bicubic' 32 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 
33 | _C.DATA.PIN_MEMORY = True 34 | # Number of data loading threads 35 | _C.DATA.NUM_WORKERS = 8 36 | 37 | # [SimMIM] Mask patch size for MaskGenerator 38 | _C.DATA.MASK_PATCH_SIZE = 32 39 | # [SimMIM] Mask ratio for MaskGenerator 40 | _C.DATA.MASK_RATIO = 0.6 41 | 42 | # ----------------------------------------------------------------------------- 43 | # Model settings 44 | # ----------------------------------------------------------------------------- 45 | _C.MODEL = CN() 46 | # Model type 47 | _C.MODEL.TYPE = 'swin' 48 | # Model name 49 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 50 | # Checkpoint to resume, could be overwritten by command line argument 51 | _C.MODEL.RESUME = '' 52 | # Number of classes, overwritten in data preparation 53 | _C.MODEL.NUM_CLASSES = 1000 54 | # Dropout rate 55 | _C.MODEL.DROP_RATE = 0.0 56 | # Drop path rate 57 | _C.MODEL.DROP_PATH_RATE = 0.1 58 | # Label Smoothing 59 | _C.MODEL.LABEL_SMOOTHING = 0.1 60 | 61 | # Swin Transformer parameters 62 | _C.MODEL.SWIN = CN() 63 | _C.MODEL.SWIN.PATCH_SIZE = 4 64 | _C.MODEL.SWIN.IN_CHANS = 3 65 | _C.MODEL.SWIN.EMBED_DIM = 96 66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 67 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 68 | _C.MODEL.SWIN.WINDOW_SIZE = 7 69 | _C.MODEL.SWIN.MLP_RATIO = 4. 70 | _C.MODEL.SWIN.QKV_BIAS = True 71 | _C.MODEL.SWIN.QK_SCALE = None 72 | _C.MODEL.SWIN.APE = False 73 | _C.MODEL.SWIN.PATCH_NORM = True 74 | 75 | # Vision Transformer parameters 76 | _C.MODEL.VIT = CN() 77 | _C.MODEL.VIT.PATCH_SIZE = 16 78 | _C.MODEL.VIT.IN_CHANS = 3 79 | _C.MODEL.VIT.EMBED_DIM = 768 80 | _C.MODEL.VIT.DEPTH = 12 81 | _C.MODEL.VIT.NUM_HEADS = 12 82 | _C.MODEL.VIT.MLP_RATIO = 4 83 | _C.MODEL.VIT.QKV_BIAS = True 84 | _C.MODEL.VIT.INIT_VALUES = 0.1 85 | _C.MODEL.VIT.USE_APE = False 86 | _C.MODEL.VIT.USE_RPB = False 87 | _C.MODEL.VIT.USE_SHARED_RPB = True 88 | _C.MODEL.VIT.USE_MEAN_POOLING = False 89 | 90 | # ----------------------------------------------------------------------------- 91 | # Training settings 92 | # ----------------------------------------------------------------------------- 93 | _C.TRAIN = CN() 94 | _C.TRAIN.START_EPOCH = 0 95 | _C.TRAIN.EPOCHS = 300 96 | _C.TRAIN.WARMUP_EPOCHS = 20 97 | _C.TRAIN.WEIGHT_DECAY = 0.05 98 | _C.TRAIN.BASE_LR = 5e-4 99 | _C.TRAIN.WARMUP_LR = 5e-7 100 | _C.TRAIN.MIN_LR = 5e-6 101 | # Clip gradient norm 102 | _C.TRAIN.CLIP_GRAD = 5.0 103 | # Auto resume from latest checkpoint 104 | _C.TRAIN.AUTO_RESUME = True 105 | # Gradient accumulation steps 106 | # could be overwritten by command line argument 107 | _C.TRAIN.ACCUMULATION_STEPS = 0 108 | # Whether to use gradient checkpointing to save memory 109 | # could be overwritten by command line argument 110 | _C.TRAIN.USE_CHECKPOINT = False 111 | 112 | # LR scheduler 113 | _C.TRAIN.LR_SCHEDULER = CN() 114 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 115 | # Epoch interval to decay LR, used in StepLRScheduler 116 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 117 | # LR decay rate, used in StepLRScheduler 118 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 119 | # Gamma / Multi steps value, used in MultiStepLRScheduler 120 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 121 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] 122 | 123 | # Optimizer 124 | _C.TRAIN.OPTIMIZER = CN() 125 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 126 | # Optimizer Epsilon 127 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 128 | # Optimizer Betas 129 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 130 | # SGD momentum 131 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 132 | 133 | # [SimMIM] Layer decay for fine-tuning 134 | 
_C.TRAIN.LAYER_DECAY = 1.0 135 | 136 | # ----------------------------------------------------------------------------- 137 | # Augmentation settings 138 | # ----------------------------------------------------------------------------- 139 | _C.AUG = CN() 140 | # Color jitter factor 141 | _C.AUG.COLOR_JITTER = 0.4 142 | # Use AutoAugment policy. "v0" or "original" 143 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 144 | # Random erase prob 145 | _C.AUG.REPROB = 0.25 146 | # Random erase mode 147 | _C.AUG.REMODE = 'pixel' 148 | # Random erase count 149 | _C.AUG.RECOUNT = 1 150 | # Mixup alpha, mixup enabled if > 0 151 | _C.AUG.MIXUP = 0.8 152 | # Cutmix alpha, cutmix enabled if > 0 153 | _C.AUG.CUTMIX = 1.0 154 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 155 | _C.AUG.CUTMIX_MINMAX = None 156 | # Probability of performing mixup or cutmix when either/both is enabled 157 | _C.AUG.MIXUP_PROB = 1.0 158 | # Probability of switching to cutmix when both mixup and cutmix enabled 159 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 160 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 161 | _C.AUG.MIXUP_MODE = 'batch' 162 | 163 | # ----------------------------------------------------------------------------- 164 | # Testing settings 165 | # ----------------------------------------------------------------------------- 166 | _C.TEST = CN() 167 | # Whether to use center crop when testing 168 | _C.TEST.CROP = True 169 | 170 | # ----------------------------------------------------------------------------- 171 | # Misc 172 | # ----------------------------------------------------------------------------- 173 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') 174 | # overwritten by command line argument 175 | _C.AMP_OPT_LEVEL = '' 176 | # Path to output folder, overwritten by command line argument 177 | _C.OUTPUT = '' 178 | # Tag of experiment, overwritten by command line argument 179 | _C.TAG = 'default' 180 | # Frequency to save checkpoint 181 | _C.SAVE_FREQ = 1 182 | # Frequency to logging info 183 | _C.PRINT_FREQ = 10 184 | # Fixed random seed 185 | _C.SEED = 0 186 | # Perform evaluation only, overwritten by command line argument 187 | _C.EVAL_MODE = False 188 | # Test throughput only, overwritten by command line argument 189 | _C.THROUGHPUT_MODE = False 190 | # local rank for DistributedDataParallel, given by command line argument 191 | _C.LOCAL_RANK = 0 192 | 193 | # [SimMIM] path to pre-trained model 194 | _C.PRETRAINED = '' 195 | 196 | 197 | def _update_config_from_file(config, cfg_file): 198 | config.defrost() 199 | with open(cfg_file, 'r') as f: 200 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 201 | 202 | for cfg in yaml_cfg.setdefault('BASE', ['']): 203 | if cfg: 204 | _update_config_from_file( 205 | config, os.path.join(os.path.dirname(cfg_file), cfg) 206 | ) 207 | print('=> merge config from {}'.format(cfg_file)) 208 | config.merge_from_file(cfg_file) 209 | config.freeze() 210 | 211 | 212 | def update_config(config, args): 213 | _update_config_from_file(config, args.cfg) 214 | 215 | config.defrost() 216 | if args.opts: 217 | config.merge_from_list(args.opts) 218 | 219 | def _check_args(name): 220 | if hasattr(args, name) and eval(f'args.{name}'): 221 | return True 222 | return False 223 | 224 | # merge from specific arguments 225 | if _check_args('batch_size'): 226 | config.DATA.BATCH_SIZE = args.batch_size 227 | if _check_args('data_path'): 228 | config.DATA.DATA_PATH = args.data_path 229 | if _check_args('resume'): 230 | config.MODEL.RESUME = 
args.resume 231 | if _check_args('pretrained'): 232 | config.PRETRAINED = args.pretrained 233 | if _check_args('accumulation_steps'): 234 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 235 | if _check_args('use_checkpoint'): 236 | config.TRAIN.USE_CHECKPOINT = True 237 | if _check_args('amp_opt_level'): 238 | config.AMP_OPT_LEVEL = args.amp_opt_level 239 | if _check_args('output'): 240 | config.OUTPUT = args.output 241 | if _check_args('tag'): 242 | config.TAG = args.tag 243 | if _check_args('eval'): 244 | config.EVAL_MODE = True 245 | if _check_args('throughput'): 246 | config.THROUGHPUT_MODE = True 247 | 248 | # set local rank for distributed training 249 | config.LOCAL_RANK = args.local_rank 250 | 251 | # output folder 252 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 253 | 254 | config.freeze() 255 | 256 | 257 | def get_config(args): 258 | """Get a yacs CfgNode object with default values.""" 259 | # Return a clone so that the defaults will not be altered 260 | # This is for the "local variable" use pattern 261 | config = _C.clone() 262 | update_config(config, args) 263 | 264 | return config 265 | -------------------------------------------------------------------------------- /configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 6 10 | DATA: 11 | IMG_SIZE: 192 12 | TRAIN: 13 | EPOCHS: 100 14 | WARMUP_EPOCHS: 20 15 | BASE_LR: 1.25e-3 16 | WARMUP_LR: 2.5e-7 17 | MIN_LR: 2.5e-7 18 | WEIGHT_DECAY: 0.05 19 | LAYER_DECAY: 0.9 20 | PRINT_FREQ: 100 21 | SAVE_FREQ: 5 22 | TAG: simmim_finetune__swin_base__img192_window6__100ep -------------------------------------------------------------------------------- /configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | IMG_SIZE: 224 12 | TRAIN: 13 | EPOCHS: 100 14 | WARMUP_EPOCHS: 20 15 | BASE_LR: 1.25e-3 16 | WARMUP_LR: 2.5e-7 17 | MIN_LR: 2.5e-7 18 | WEIGHT_DECAY: 0.05 19 | LAYER_DECAY: 0.9 20 | PRINT_FREQ: 100 21 | SAVE_FREQ: 5 22 | TAG: simmim_finetune__swin_base__img224_window7__100ep -------------------------------------------------------------------------------- /configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_pretrain 4 | DROP_PATH_RATE: 0.0 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 6 10 | DATA: 11 | IMG_SIZE: 192 12 | MASK_PATCH_SIZE: 32 13 | MASK_RATIO: 0.6 14 | TRAIN: 15 | EPOCHS: 100 16 | WARMUP_EPOCHS: 10 17 | BASE_LR: 2e-4 18 | WARMUP_LR: 1e-6 19 | MIN_LR: 1e-5 20 | WEIGHT_DECAY: 0.05 21 | PRINT_FREQ: 100 22 | SAVE_FREQ: 5 23 | TAG: simmim_pretrain__swin_base__img192_window6__100ep -------------------------------------------------------------------------------- /configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | IMG_SIZE: 224 12 | TRAIN: 13 | EPOCHS: 100 14 | WARMUP_EPOCHS: 20 15 | BASE_LR: 1.25e-3 16 | WARMUP_LR: 2.5e-7 17 | MIN_LR: 2.5e-7 18 | WEIGHT_DECAY: 0.05 19 | LAYER_DECAY: 0.8 20 | PRINT_FREQ: 100 21 | SAVE_FREQ: 5 22 | TAG: simmim_finetune__swin_base__img224_window7__800ep -------------------------------------------------------------------------------- /configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_pretrain 4 | DROP_PATH_RATE: 0.0 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 6 10 | DATA: 11 | IMG_SIZE: 192 12 | MASK_PATCH_SIZE: 32 13 | MASK_RATIO: 0.6 14 | TRAIN: 15 | EPOCHS: 800 16 | WARMUP_EPOCHS: 10 17 | BASE_LR: 1e-4 18 | WARMUP_LR: 5e-7 19 | WEIGHT_DECAY: 0.05 20 | LR_SCHEDULER: 21 | NAME: 'multistep' 22 | GAMMA: 0.1 23 | MULTISTEPS: [700,] 24 | PRINT_FREQ: 100 25 | SAVE_FREQ: 5 26 | TAG: simmim_pretrain__swin_base__img192_window6__800ep -------------------------------------------------------------------------------- /configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 14 10 | DATA: 11 | IMG_SIZE: 224 12 | TRAIN: 13 | EPOCHS: 100 14 | WARMUP_EPOCHS: 20 15 | BASE_LR: 1.25e-3 16 | WARMUP_LR: 2.5e-7 17 | MIN_LR: 2.5e-7 18 | WEIGHT_DECAY: 0.05 19 | LAYER_DECAY: 0.7 20 | PRINT_FREQ: 100 21 | SAVE_FREQ: 5 22 | TAG: simmim_finetune__swin_large__img224_window14__800ep -------------------------------------------------------------------------------- /configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_pretrain 4 | DROP_PATH_RATE: 0.0 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 12 10 | DATA: 11 | IMG_SIZE: 192 12 | MASK_PATCH_SIZE: 32 13 | MASK_RATIO: 0.6 14 | TRAIN: 15 | EPOCHS: 800 16 | WARMUP_EPOCHS: 10 17 | BASE_LR: 1e-4 18 | WARMUP_LR: 5e-7 19 | WEIGHT_DECAY: 0.05 20 | LR_SCHEDULER: 21 | NAME: 'multistep' 22 | GAMMA: 0.1 23 | MULTISTEPS: [700,] 24 | PRINT_FREQ: 100 25 | SAVE_FREQ: 5 26 | TAG: simmim_pretrain__swin_large__img192_window12__800ep -------------------------------------------------------------------------------- /configs/vit_base__800ep/simmim_finetune__vit_base__img224__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | DATA: 14 | IMG_SIZE: 224 15 | TRAIN: 16 | EPOCHS: 100 17 | WARMUP_EPOCHS: 20 18 | BASE_LR: 1.25e-3 19 | WARMUP_LR: 2.5e-7 20 | MIN_LR: 2.5e-7 21 | WEIGHT_DECAY: 0.05 22 | LAYER_DECAY: 0.65 23 | 
PRINT_FREQ: 100 24 | SAVE_FREQ: 5 25 | TAG: simmim_finetune__vit_base__img224__800ep 26 | -------------------------------------------------------------------------------- /configs/vit_base__800ep/simmim_pretrain__vit_base__img224__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: simmim_pretrain 4 | DROP_PATH_RATE: 0.1 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: False 11 | USE_SHARED_RPB: True 12 | USE_MEAN_POOLING: False 13 | DATA: 14 | IMG_SIZE: 224 15 | MASK_PATCH_SIZE: 32 16 | MASK_RATIO: 0.6 17 | TRAIN: 18 | EPOCHS: 800 19 | WARMUP_EPOCHS: 10 20 | BASE_LR: 1e-4 21 | WARMUP_LR: 5e-7 22 | WEIGHT_DECAY: 0.05 23 | LR_SCHEDULER: 24 | NAME: 'multistep' 25 | GAMMA: 0.1 26 | MULTISTEPS: [700,] 27 | PRINT_FREQ: 100 28 | SAVE_FREQ: 5 29 | TAG: simmim_pretrain__vit_base__img224__800ep 30 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_simmim import build_loader_simmim 2 | from .data_finetune import build_loader_finetune 3 | 4 | def build_loader(config, logger, is_pretrain): 5 | if is_pretrain: 6 | return build_loader_simmim(config, logger) 7 | else: 8 | return build_loader_finetune(config, logger) -------------------------------------------------------------------------------- /data/data_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, DistributedSampler 11 | from torchvision import datasets, transforms 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import Mixup 14 | from timm.data import create_transform 15 | from timm.data.transforms import _pil_interp 16 | 17 | 18 | def build_loader_finetune(config, logger): 19 | config.defrost() 20 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger) 21 | config.freeze() 22 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger) 23 | logger.info(f"Build dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}") 24 | 25 | num_tasks = dist.get_world_size() 26 | global_rank = dist.get_rank() 27 | sampler_train = DistributedSampler( 28 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 29 | ) 30 | sampler_val = DistributedSampler( 31 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 32 | ) 33 | 34 | data_loader_train = DataLoader( 35 | dataset_train, sampler=sampler_train, 36 | batch_size=config.DATA.BATCH_SIZE, 37 | num_workers=config.DATA.NUM_WORKERS, 38 | pin_memory=config.DATA.PIN_MEMORY, 39 | drop_last=True, 40 | ) 41 | 42 | data_loader_val = DataLoader( 43 | dataset_val, sampler=sampler_val, 44 | batch_size=config.DATA.BATCH_SIZE, 45 | num_workers=config.DATA.NUM_WORKERS, 46 | pin_memory=config.DATA.PIN_MEMORY, 47 | drop_last=False, 48 | ) 49 | 50 | # setup mixup / cutmix 51 | mixup_fn = None 52 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 53 | if mixup_active: 54 | mixup_fn = Mixup( 55 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 56 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 57 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 58 | 59 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 60 | 61 | 62 | def build_dataset(is_train, config, logger): 63 | transform = build_transform(is_train, config) 64 | logger.info(f'Fine-tune data transform, is_train={is_train}:\n{transform}') 65 | 66 | if config.DATA.DATASET == 'imagenet': 67 | prefix = 'train' if is_train else 'val' 68 | root = os.path.join(config.DATA.DATA_PATH, prefix) 69 | dataset = datasets.ImageFolder(root, transform=transform) 70 | nb_classes = 1000 71 | else: 72 | raise NotImplementedError("We only support ImageNet Now.") 73 | 74 | return dataset, nb_classes 75 | 76 | 77 | def build_transform(is_train, config): 78 | resize_im = config.DATA.IMG_SIZE > 32 79 | if is_train: 80 | # this should always dispatch to transforms_imagenet_train 81 | transform = create_transform( 82 | input_size=config.DATA.IMG_SIZE, 83 | is_training=True, 84 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 85 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 86 | re_prob=config.AUG.REPROB, 87 | re_mode=config.AUG.REMODE, 88 | re_count=config.AUG.RECOUNT, 89 | interpolation=config.DATA.INTERPOLATION, 90 | ) 91 | if not resize_im: 92 | # replace RandomResizedCropAndInterpolation with 93 | # RandomCrop 94 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 95 | return transform 96 | 97 | t = [] 98 | if resize_im: 99 | if config.TEST.CROP: 100 | size = int((256 / 224) * config.DATA.IMG_SIZE) 101 | t.append( 102 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 103 | # to maintain same ratio w.r.t. 
224 images 104 | ) 105 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 106 | else: 107 | t.append( 108 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 109 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 110 | ) 111 | 112 | t.append(transforms.ToTensor()) 113 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 114 | return transforms.Compose(t) -------------------------------------------------------------------------------- /data/data_simmim.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torchvision.transforms as T 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from torch.utils.data._utils.collate import default_collate 17 | from torchvision.datasets import ImageFolder 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | 21 | class MaskGenerator: 22 | def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): 23 | self.input_size = input_size 24 | self.mask_patch_size = mask_patch_size 25 | self.model_patch_size = model_patch_size 26 | self.mask_ratio = mask_ratio 27 | 28 | assert self.input_size % self.mask_patch_size == 0 29 | assert self.mask_patch_size % self.model_patch_size == 0 30 | 31 | self.rand_size = self.input_size // self.mask_patch_size 32 | self.scale = self.mask_patch_size // self.model_patch_size 33 | 34 | self.token_count = self.rand_size ** 2 35 | self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) 36 | 37 | def __call__(self): 38 | mask_idx = np.random.permutation(self.token_count)[:self.mask_count] 39 | mask = np.zeros(self.token_count, dtype=int) 40 | mask[mask_idx] = 1 41 | 42 | mask = mask.reshape((self.rand_size, self.rand_size)) 43 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) 44 | 45 | return mask 46 | 47 | 48 | class SimMIMTransform: 49 | def __init__(self, config): 50 | self.transform_img = T.Compose([ 51 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 52 | T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. 
/ 3.)), 53 | T.RandomHorizontalFlip(), 54 | T.ToTensor(), 55 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), 56 | ]) 57 | 58 | if config.MODEL.TYPE == 'swin': 59 | model_patch_size=config.MODEL.SWIN.PATCH_SIZE 60 | elif config.MODEL.TYPE == 'vit': 61 | model_patch_size=config.MODEL.VIT.PATCH_SIZE 62 | else: 63 | raise NotImplementedError 64 | 65 | self.mask_generator = MaskGenerator( 66 | input_size=config.DATA.IMG_SIZE, 67 | mask_patch_size=config.DATA.MASK_PATCH_SIZE, 68 | model_patch_size=model_patch_size, 69 | mask_ratio=config.DATA.MASK_RATIO, 70 | ) 71 | 72 | def __call__(self, img): 73 | img = self.transform_img(img) 74 | mask = self.mask_generator() 75 | 76 | return img, mask 77 | 78 | 79 | def collate_fn(batch): 80 | if not isinstance(batch[0][0], tuple): 81 | return default_collate(batch) 82 | else: 83 | batch_num = len(batch) 84 | ret = [] 85 | for item_idx in range(len(batch[0][0])): 86 | if batch[0][0][item_idx] is None: 87 | ret.append(None) 88 | else: 89 | ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) 90 | ret.append(default_collate([batch[i][1] for i in range(batch_num)])) 91 | return ret 92 | 93 | 94 | def build_loader_simmim(config, logger): 95 | transform = SimMIMTransform(config) 96 | logger.info(f'Pre-train data transform:\n{transform}') 97 | 98 | dataset = ImageFolder(config.DATA.DATA_PATH, transform) 99 | logger.info(f'Build dataset: train images = {len(dataset)}') 100 | 101 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) 102 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) 103 | 104 | return dataloader -------------------------------------------------------------------------------- /figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/SimMIM/d3e29bcac950b83edc34ca33fe4404f38309052c/figures/teaser.jpg -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import sys 11 | import logging 12 | import functools 13 | from termcolor import colored 14 | 15 | 16 | @functools.lru_cache() 17 | def create_logger(output_dir, dist_rank=0, name=''): 18 | # create logger 19 | logger = logging.getLogger(name) 20 | logger.setLevel(logging.DEBUG) 21 | logger.propagate = False 22 | 23 | # create formatter 24 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 25 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 26 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 27 | 28 | # create console handlers for master process 29 | if dist_rank == 0: 30 | console_handler = logging.StreamHandler(sys.stdout) 31 | console_handler.setLevel(logging.DEBUG) 32 | console_handler.setFormatter( 33 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 34 | logger.addHandler(console_handler) 35 | 36 | # create file handlers 37 | file_handler = 
logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 38 | file_handler.setLevel(logging.DEBUG) 39 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 40 | logger.addHandler(file_handler) 41 | 42 | return logger 43 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | from collections import Counter 10 | from bisect import bisect_right 11 | 12 | import torch 13 | from timm.scheduler.cosine_lr import CosineLRScheduler 14 | from timm.scheduler.step_lr import StepLRScheduler 15 | from timm.scheduler.scheduler import Scheduler 16 | 17 | 18 | def build_scheduler(config, optimizer, n_iter_per_epoch): 19 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 20 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 21 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 22 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] 23 | 24 | lr_scheduler = None 25 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 26 | lr_scheduler = CosineLRScheduler( 27 | optimizer, 28 | t_initial=num_steps, 29 | t_mul=1., 30 | lr_min=config.TRAIN.MIN_LR, 31 | warmup_lr_init=config.TRAIN.WARMUP_LR, 32 | warmup_t=warmup_steps, 33 | cycle_limit=1, 34 | t_in_epochs=False, 35 | ) 36 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 37 | lr_scheduler = LinearLRScheduler( 38 | optimizer, 39 | t_initial=num_steps, 40 | lr_min_rate=0.01, 41 | warmup_lr_init=config.TRAIN.WARMUP_LR, 42 | warmup_t=warmup_steps, 43 | t_in_epochs=False, 44 | ) 45 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 46 | lr_scheduler = StepLRScheduler( 47 | optimizer, 48 | decay_t=decay_steps, 49 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 50 | warmup_lr_init=config.TRAIN.WARMUP_LR, 51 | warmup_t=warmup_steps, 52 | t_in_epochs=False, 53 | ) 54 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': 55 | lr_scheduler = MultiStepLRScheduler( 56 | optimizer, 57 | milestones=multi_steps, 58 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA, 59 | warmup_lr_init=config.TRAIN.WARMUP_LR, 60 | warmup_t=warmup_steps, 61 | t_in_epochs=False, 62 | ) 63 | 64 | return lr_scheduler 65 | 66 | 67 | class LinearLRScheduler(Scheduler): 68 | def __init__(self, 69 | optimizer: torch.optim.Optimizer, 70 | t_initial: int, 71 | lr_min_rate: float, 72 | warmup_t=0, 73 | warmup_lr_init=0., 74 | t_in_epochs=True, 75 | noise_range_t=None, 76 | noise_pct=0.67, 77 | noise_std=1.0, 78 | noise_seed=42, 79 | initialize=True, 80 | ) -> None: 81 | super().__init__( 82 | optimizer, param_group_field="lr", 83 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 84 | initialize=initialize) 85 | 86 | self.t_initial = t_initial 87 | self.lr_min_rate = lr_min_rate 88 | self.warmup_t = warmup_t 89 | self.warmup_lr_init = warmup_lr_init 90 | self.t_in_epochs = t_in_epochs 91 | if self.warmup_t: 92 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 93 | super().update_groups(self.warmup_lr_init) 94 | else: 95 | self.warmup_steps = [1 for _ in self.base_values] 96 | 97 | 
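    # Note: after a linear warmup from `warmup_lr_init` over `warmup_t` steps,
    # the schedule below decays each base LR `v` linearly from `v` down to
    # `v * lr_min_rate` over the remaining `t_initial - warmup_t` steps.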
def _get_lr(self, t): 98 | if t < self.warmup_t: 99 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 100 | else: 101 | t = t - self.warmup_t 102 | total_t = self.t_initial - self.warmup_t 103 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 104 | return lrs 105 | 106 | def get_epoch_values(self, epoch: int): 107 | if self.t_in_epochs: 108 | return self._get_lr(epoch) 109 | else: 110 | return None 111 | 112 | def get_update_values(self, num_updates: int): 113 | if not self.t_in_epochs: 114 | return self._get_lr(num_updates) 115 | else: 116 | return None 117 | 118 | 119 | class MultiStepLRScheduler(Scheduler): 120 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: 121 | super().__init__(optimizer, param_group_field="lr") 122 | 123 | self.milestones = milestones 124 | self.gamma = gamma 125 | self.warmup_t = warmup_t 126 | self.warmup_lr_init = warmup_lr_init 127 | self.t_in_epochs = t_in_epochs 128 | if self.warmup_t: 129 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 130 | super().update_groups(self.warmup_lr_init) 131 | else: 132 | self.warmup_steps = [1 for _ in self.base_values] 133 | 134 | assert self.warmup_t <= min(self.milestones) 135 | 136 | def _get_lr(self, t): 137 | if t < self.warmup_t: 138 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 139 | else: 140 | lrs = [v * (self.gamma ** bisect_right(self.milestones, t)) for v in self.base_values] 141 | return lrs 142 | 143 | def get_epoch_values(self, epoch: int): 144 | if self.t_in_epochs: 145 | return self._get_lr(epoch) 146 | else: 147 | return None 148 | 149 | def get_update_values(self, num_updates: int): 150 | if not self.t_in_epochs: 151 | return self._get_lr(num_updates) 152 | else: 153 | return None -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | 19 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 20 | from timm.utils import accuracy, AverageMeter 21 | 22 | from config import get_config 23 | from models import build_model 24 | from data import build_loader 25 | from lr_scheduler import build_scheduler 26 | from optimizer import build_optimizer 27 | from logger import create_logger 28 | from utils import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 29 | 30 | try: 31 | # noinspection PyUnresolvedReferences 32 | from apex import amp 33 | except ImportError: 34 | amp = None 35 | 36 | 37 | def parse_option(): 38 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False) 39 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 40 | parser.add_argument( 41 | "--opts", 42 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 43 | default=None, 44 | nargs='+', 45 | ) 46 | 47 | # easy config modification 48 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 49 | parser.add_argument('--data-path', type=str, help='path to dataset') 50 | parser.add_argument('--pretrained', type=str, help='path to pre-trained model') 51 | parser.add_argument('--resume', help='resume from checkpoint') 52 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 53 | parser.add_argument('--use-checkpoint', action='store_true', 54 | help="whether to use gradient checkpointing to save memory") 55 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 56 | help='mixed precision opt level, if O0, no amp is used') 57 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 58 | help='root of output folder, the full path is // (default: output)') 59 | parser.add_argument('--tag', help='tag of experiment') 60 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 61 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 62 | 63 | # distributed training 64 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 65 | 66 | args = parser.parse_args() 67 | 68 | config = get_config(args) 69 | 70 | return args, config 71 | 72 | 73 | def main(config): 74 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, logger, is_pretrain=False) 75 | 76 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 77 | model = build_model(config, is_pretrain=False) 78 | model.cuda() 79 | logger.info(str(model)) 80 | 81 | optimizer = build_optimizer(config, model, logger, is_pretrain=False) 82 | if config.AMP_OPT_LEVEL != "O0": 83 | model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 84 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 85 | model_without_ddp = model.module 86 | 87 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 88 | logger.info(f"number of params: {n_parameters}") 89 | if hasattr(model_without_ddp, 'flops'): 90 | flops = model_without_ddp.flops() 91 | logger.info(f"number of GFLOPs: {flops / 1e9}") 92 | 93 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 94 | 95 | if config.AUG.MIXUP > 0.: 96 | # smoothing is handled with mixup label transform 97 | criterion = SoftTargetCrossEntropy() 98 | elif config.MODEL.LABEL_SMOOTHING > 0.: 99 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 100 | else: 101 | criterion = torch.nn.CrossEntropyLoss() 102 | 103 | max_accuracy = 0.0 104 | 105 | if config.TRAIN.AUTO_RESUME: 106 | resume_file = auto_resume_helper(config.OUTPUT, logger) 107 | if resume_file: 108 | if config.MODEL.RESUME: 109 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 110 | config.defrost() 111 | config.MODEL.RESUME = resume_file 112 | config.freeze() 113 | logger.info(f'auto resuming from {resume_file}') 114 | else: 115 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 116 | 117 | if config.MODEL.RESUME: 118 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 119 | acc1, acc5, loss = validate(config, data_loader_val, model) 120 | 
logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 121 | if config.EVAL_MODE: 122 | return 123 | elif config.PRETRAINED: 124 | load_pretrained(config, model_without_ddp, logger) 125 | 126 | if config.THROUGHPUT_MODE: 127 | throughput(data_loader_val, model, logger) 128 | return 129 | 130 | logger.info("Start training") 131 | start_time = time.time() 132 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 133 | data_loader_train.sampler.set_epoch(epoch) 134 | 135 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler) 136 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 137 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger) 138 | 139 | acc1, acc5, loss = validate(config, data_loader_val, model) 140 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 141 | max_accuracy = max(max_accuracy, acc1) 142 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 143 | 144 | total_time = time.time() - start_time 145 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 146 | logger.info('Training time {}'.format(total_time_str)) 147 | 148 | 149 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler): 150 | model.train() 151 | optimizer.zero_grad() 152 | 153 | logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') 154 | 155 | num_steps = len(data_loader) 156 | batch_time = AverageMeter() 157 | loss_meter = AverageMeter() 158 | norm_meter = AverageMeter() 159 | 160 | start = time.time() 161 | end = time.time() 162 | for idx, (samples, targets) in enumerate(data_loader): 163 | samples = samples.cuda(non_blocking=True) 164 | targets = targets.cuda(non_blocking=True) 165 | 166 | if mixup_fn is not None: 167 | samples, targets = mixup_fn(samples, targets) 168 | 169 | outputs = model(samples) 170 | 171 | if config.TRAIN.ACCUMULATION_STEPS > 1: 172 | loss = criterion(outputs, targets) 173 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 174 | if config.AMP_OPT_LEVEL != "O0": 175 | with amp.scale_loss(loss, optimizer) as scaled_loss: 176 | scaled_loss.backward() 177 | if config.TRAIN.CLIP_GRAD: 178 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 179 | else: 180 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 181 | else: 182 | loss.backward() 183 | if config.TRAIN.CLIP_GRAD: 184 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 185 | else: 186 | grad_norm = get_grad_norm(model.parameters()) 187 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 188 | optimizer.step() 189 | optimizer.zero_grad() 190 | lr_scheduler.step_update(epoch * num_steps + idx) 191 | else: 192 | loss = criterion(outputs, targets) 193 | optimizer.zero_grad() 194 | if config.AMP_OPT_LEVEL != "O0": 195 | with amp.scale_loss(loss, optimizer) as scaled_loss: 196 | scaled_loss.backward() 197 | if config.TRAIN.CLIP_GRAD: 198 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 199 | else: 200 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 201 | else: 202 | loss.backward() 203 | if config.TRAIN.CLIP_GRAD: 204 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 205 | else: 206 | grad_norm = 
get_grad_norm(model.parameters()) 207 | optimizer.step() 208 | lr_scheduler.step_update(epoch * num_steps + idx) 209 | 210 | torch.cuda.synchronize() 211 | 212 | loss_meter.update(loss.item(), targets.size(0)) 213 | norm_meter.update(grad_norm) 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | if idx % config.PRINT_FREQ == 0: 218 | lr = optimizer.param_groups[-1]['lr'] 219 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 220 | etas = batch_time.avg * (num_steps - idx) 221 | logger.info( 222 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 223 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 224 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 225 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 226 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 227 | f'mem {memory_used:.0f}MB') 228 | epoch_time = time.time() - start 229 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 230 | 231 | 232 | @torch.no_grad() 233 | def validate(config, data_loader, model): 234 | criterion = torch.nn.CrossEntropyLoss() 235 | model.eval() 236 | 237 | batch_time = AverageMeter() 238 | loss_meter = AverageMeter() 239 | acc1_meter = AverageMeter() 240 | acc5_meter = AverageMeter() 241 | 242 | end = time.time() 243 | for idx, (images, target) in enumerate(data_loader): 244 | images = images.cuda(non_blocking=True) 245 | target = target.cuda(non_blocking=True) 246 | 247 | # compute output 248 | output = model(images) 249 | 250 | # measure accuracy and record loss 251 | loss = criterion(output, target) 252 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 253 | 254 | acc1 = reduce_tensor(acc1) 255 | acc5 = reduce_tensor(acc5) 256 | loss = reduce_tensor(loss) 257 | 258 | loss_meter.update(loss.item(), target.size(0)) 259 | acc1_meter.update(acc1.item(), target.size(0)) 260 | acc5_meter.update(acc5.item(), target.size(0)) 261 | 262 | # measure elapsed time 263 | batch_time.update(time.time() - end) 264 | end = time.time() 265 | 266 | if idx % config.PRINT_FREQ == 0: 267 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 268 | logger.info( 269 | f'Test: [{idx}/{len(data_loader)}]\t' 270 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 271 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 272 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 273 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 274 | f'Mem {memory_used:.0f}MB') 275 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 276 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 277 | 278 | 279 | @torch.no_grad() 280 | def throughput(data_loader, model, logger): 281 | model.eval() 282 | 283 | for idx, (images, _) in enumerate(data_loader): 284 | images = images.cuda(non_blocking=True) 285 | batch_size = images.shape[0] 286 | for i in range(50): 287 | model(images) 288 | torch.cuda.synchronize() 289 | logger.info(f"throughput averaged with 30 times") 290 | tic1 = time.time() 291 | for i in range(30): 292 | model(images) 293 | torch.cuda.synchronize() 294 | tic2 = time.time() 295 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 296 | return 297 | 298 | 299 | if __name__ == '__main__': 300 | _, config = parse_option() 301 | 302 | if config.AMP_OPT_LEVEL != "O0": 303 | assert amp is not None, "amp not installed!" 
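    # The distributed setup below expects RANK and WORLD_SIZE in the environment, as set by the
    # PyTorch launcher, which also supplies the required --local_rank argument. Illustrative
    # invocation (paths, GPU count and batch size are placeholders, not taken from this repo):
    #   python -m torch.distributed.launch --nproc_per_node 8 main_finetune.py \
    #       --cfg <finetune_config>.yaml --data-path <imagenet_dir> \
    #       --pretrained <simmim_pretrained>.pth --batch-size 128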
304 | 305 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 306 | rank = int(os.environ["RANK"]) 307 | world_size = int(os.environ['WORLD_SIZE']) 308 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 309 | else: 310 | rank = -1 311 | world_size = -1 312 | torch.cuda.set_device(config.LOCAL_RANK) 313 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 314 | torch.distributed.barrier() 315 | 316 | seed = config.SEED + dist.get_rank() 317 | torch.manual_seed(seed) 318 | np.random.seed(seed) 319 | cudnn.benchmark = True 320 | 321 | # linear scale the learning rate according to total batch size, may not be optimal 322 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 323 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 324 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 325 | # gradient accumulation also need to scale the learning rate 326 | if config.TRAIN.ACCUMULATION_STEPS > 1: 327 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 328 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 329 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 330 | config.defrost() 331 | config.TRAIN.BASE_LR = linear_scaled_lr 332 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 333 | config.TRAIN.MIN_LR = linear_scaled_min_lr 334 | config.freeze() 335 | 336 | os.makedirs(config.OUTPUT, exist_ok=True) 337 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 338 | 339 | if dist.get_rank() == 0: 340 | path = os.path.join(config.OUTPUT, "config.json") 341 | with open(path, "w") as f: 342 | f.write(config.dump()) 343 | logger.info(f"Full config saved to {path}") 344 | 345 | # print config 346 | logger.info(config.dump()) 347 | 348 | main(config) 349 | -------------------------------------------------------------------------------- /main_simmim.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | from timm.utils import AverageMeter 19 | 20 | from config import get_config 21 | from models import build_model 22 | from data import build_loader 23 | from lr_scheduler import build_scheduler 24 | from optimizer import build_optimizer 25 | from logger import create_logger 26 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper 27 | 28 | try: 29 | # noinspection PyUnresolvedReferences 30 | from apex import amp 31 | except ImportError: 32 | amp = None 33 | 34 | 35 | def parse_option(): 36 | parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False) 37 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 38 | parser.add_argument( 39 | "--opts", 40 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 41 | default=None, 42 | nargs='+', 43 | ) 44 | 45 | # easy config modification 46 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 47 | parser.add_argument('--data-path', type=str, help='path to dataset') 48 | parser.add_argument('--resume', help='resume from checkpoint') 49 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 50 | parser.add_argument('--use-checkpoint', action='store_true', 51 | help="whether to use gradient checkpointing to save memory") 52 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'], 53 | help='mixed precision opt level, if O0, no amp is used') 54 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 55 | help='root of output folder, the full path is // (default: output)') 56 | parser.add_argument('--tag', help='tag of experiment') 57 | 58 | # distributed training 59 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 60 | 61 | args = parser.parse_args() 62 | 63 | config = get_config(args) 64 | 65 | return args, config 66 | 67 | 68 | def main(config): 69 | data_loader_train = build_loader(config, logger, is_pretrain=True) 70 | 71 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 72 | model = build_model(config, is_pretrain=True) 73 | model.cuda() 74 | logger.info(str(model)) 75 | 76 | optimizer = build_optimizer(config, model, logger, is_pretrain=True) 77 | if config.AMP_OPT_LEVEL != "O0": 78 | model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) 79 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 80 | model_without_ddp = model.module 81 | 82 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 83 | logger.info(f"number of params: {n_parameters}") 84 | if hasattr(model_without_ddp, 'flops'): 85 | flops = model_without_ddp.flops() 86 | logger.info(f"number of GFLOPs: {flops / 1e9}") 87 | 88 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 89 | 90 | if config.TRAIN.AUTO_RESUME: 91 | resume_file = auto_resume_helper(config.OUTPUT, logger) 92 | if resume_file: 93 | if config.MODEL.RESUME: 94 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 95 | config.defrost() 96 | config.MODEL.RESUME = resume_file 97 | config.freeze() 98 | logger.info(f'auto resuming from {resume_file}') 99 | else: 100 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 101 | 102 | if config.MODEL.RESUME: 103 | load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger) 104 | 105 | logger.info("Start training") 106 | start_time = time.time() 107 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 108 | data_loader_train.sampler.set_epoch(epoch) 109 | 110 | train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler) 111 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 112 | save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, logger) 113 | 114 | total_time = time.time() - start_time 115 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 116 | logger.info('Training time {}'.format(total_time_str)) 117 | 118 | 119 | def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler): 120 | model.train() 
121 | optimizer.zero_grad() 122 | 123 | num_steps = len(data_loader) 124 | batch_time = AverageMeter() 125 | loss_meter = AverageMeter() 126 | norm_meter = AverageMeter() 127 | 128 | start = time.time() 129 | end = time.time() 130 | for idx, (img, mask, _) in enumerate(data_loader): 131 | img = img.cuda(non_blocking=True) 132 | mask = mask.cuda(non_blocking=True) 133 | 134 | loss = model(img, mask) 135 | 136 | if config.TRAIN.ACCUMULATION_STEPS > 1: 137 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 138 | if config.AMP_OPT_LEVEL != "O0": 139 | with amp.scale_loss(loss, optimizer) as scaled_loss: 140 | scaled_loss.backward() 141 | if config.TRAIN.CLIP_GRAD: 142 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 143 | else: 144 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 145 | else: 146 | loss.backward() 147 | if config.TRAIN.CLIP_GRAD: 148 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 149 | else: 150 | grad_norm = get_grad_norm(model.parameters()) 151 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 152 | optimizer.step() 153 | optimizer.zero_grad() 154 | lr_scheduler.step_update(epoch * num_steps + idx) 155 | else: 156 | optimizer.zero_grad() 157 | if config.AMP_OPT_LEVEL != "O0": 158 | with amp.scale_loss(loss, optimizer) as scaled_loss: 159 | scaled_loss.backward() 160 | if config.TRAIN.CLIP_GRAD: 161 | grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) 162 | else: 163 | grad_norm = get_grad_norm(amp.master_params(optimizer)) 164 | else: 165 | loss.backward() 166 | if config.TRAIN.CLIP_GRAD: 167 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 168 | else: 169 | grad_norm = get_grad_norm(model.parameters()) 170 | optimizer.step() 171 | lr_scheduler.step_update(epoch * num_steps + idx) 172 | 173 | torch.cuda.synchronize() 174 | 175 | loss_meter.update(loss.item(), img.size(0)) 176 | norm_meter.update(grad_norm) 177 | batch_time.update(time.time() - end) 178 | end = time.time() 179 | 180 | if idx % config.PRINT_FREQ == 0: 181 | lr = optimizer.param_groups[0]['lr'] 182 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 183 | etas = batch_time.avg * (num_steps - idx) 184 | logger.info( 185 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 186 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 187 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 188 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 189 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 190 | f'mem {memory_used:.0f}MB') 191 | epoch_time = time.time() - start 192 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 193 | 194 | 195 | if __name__ == '__main__': 196 | _, config = parse_option() 197 | 198 | if config.AMP_OPT_LEVEL != "O0": 199 | assert amp is not None, "amp not installed!" 
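    # The learning rates below are scaled linearly with the total batch size relative to 512.
    # Illustrative numbers (not taken from any config in this repo): with BASE_LR = 1e-4,
    # 128 images per GPU and 16 GPUs, the scaled rate is 1e-4 * (128 * 16) / 512 = 4e-4.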
200 | 201 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 202 | rank = int(os.environ["RANK"]) 203 | world_size = int(os.environ['WORLD_SIZE']) 204 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 205 | else: 206 | rank = -1 207 | world_size = -1 208 | torch.cuda.set_device(config.LOCAL_RANK) 209 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 210 | torch.distributed.barrier() 211 | 212 | seed = config.SEED + dist.get_rank() 213 | torch.manual_seed(seed) 214 | np.random.seed(seed) 215 | cudnn.benchmark = True 216 | 217 | # linear scale the learning rate according to total batch size, may not be optimal 218 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 219 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 220 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 221 | # gradient accumulation also need to scale the learning rate 222 | if config.TRAIN.ACCUMULATION_STEPS > 1: 223 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 224 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 225 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 226 | config.defrost() 227 | config.TRAIN.BASE_LR = linear_scaled_lr 228 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 229 | config.TRAIN.MIN_LR = linear_scaled_min_lr 230 | config.freeze() 231 | 232 | os.makedirs(config.OUTPUT, exist_ok=True) 233 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 234 | 235 | if dist.get_rank() == 0: 236 | path = os.path.join(config.OUTPUT, "config.json") 237 | with open(path, "w") as f: 238 | f.write(config.dump()) 239 | logger.info(f"Full config saved to {path}") 240 | 241 | # print config 242 | logger.info(config.dump()) 243 | 244 | main(config) 245 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | from .swin_transformer import build_swin 10 | from .vision_transformer import build_vit 11 | from .simmim import build_simmim 12 | 13 | 14 | def build_model(config, is_pretrain=True): 15 | if is_pretrain: 16 | model = build_simmim(config) 17 | else: 18 | model_type = config.MODEL.TYPE 19 | if model_type == 'swin': 20 | model = build_swin(config) 21 | elif model_type == 'vit': 22 | model = build_vit(config) 23 | else: 24 | raise NotImplementedError(f"Unknown fine-tune model: {model_type}") 25 | 26 | return model 27 | -------------------------------------------------------------------------------- /models/simmim.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # 
Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from timm.models.layers import trunc_normal_ 14 | 15 | from .swin_transformer import SwinTransformer 16 | from .vision_transformer import VisionTransformer 17 | 18 | 19 | class SwinTransformerForSimMIM(SwinTransformer): 20 | def __init__(self, **kwargs): 21 | super().__init__(**kwargs) 22 | 23 | assert self.num_classes == 0 24 | 25 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 26 | trunc_normal_(self.mask_token, mean=0., std=.02) 27 | 28 | def forward(self, x, mask): 29 | x = self.patch_embed(x) 30 | 31 | assert mask is not None 32 | B, L, _ = x.shape 33 | 34 | mask_tokens = self.mask_token.expand(B, L, -1) 35 | w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) 36 | x = x * (1. - w) + mask_tokens * w 37 | 38 | if self.ape: 39 | x = x + self.absolute_pos_embed 40 | x = self.pos_drop(x) 41 | 42 | for layer in self.layers: 43 | x = layer(x) 44 | x = self.norm(x) 45 | 46 | x = x.transpose(1, 2) 47 | B, C, L = x.shape 48 | H = W = int(L ** 0.5) 49 | x = x.reshape(B, C, H, W) 50 | return x 51 | 52 | @torch.jit.ignore 53 | def no_weight_decay(self): 54 | return super().no_weight_decay() | {'mask_token'} 55 | 56 | 57 | class VisionTransformerForSimMIM(VisionTransformer): 58 | def __init__(self, **kwargs): 59 | super().__init__(**kwargs) 60 | 61 | assert self.num_classes == 0 62 | 63 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 64 | self._trunc_normal_(self.mask_token, std=.02) 65 | 66 | def _trunc_normal_(self, tensor, mean=0., std=1.): 67 | trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 68 | 69 | def forward(self, x, mask): 70 | x = self.patch_embed(x) 71 | 72 | assert mask is not None 73 | B, L, _ = x.shape 74 | 75 | mask_token = self.mask_token.expand(B, L, -1) 76 | w = mask.flatten(1).unsqueeze(-1).type_as(mask_token) 77 | x = x * (1 - w) + mask_token * w 78 | 79 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 80 | x = torch.cat((cls_tokens, x), dim=1) 81 | 82 | if self.pos_embed is not None: 83 | x = x + self.pos_embed 84 | x = self.pos_drop(x) 85 | 86 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 87 | for blk in self.blocks: 88 | x = blk(x, rel_pos_bias=rel_pos_bias) 89 | x = self.norm(x) 90 | 91 | x = x[:, 1:] 92 | B, L, C = x.shape 93 | H = W = int(L ** 0.5) 94 | x = x.permute(0, 2, 1).reshape(B, C, H, W) 95 | return x 96 | 97 | 98 | class SimMIM(nn.Module): 99 | def __init__(self, encoder, encoder_stride): 100 | super().__init__() 101 | self.encoder = encoder 102 | self.encoder_stride = encoder_stride 103 | 104 | self.decoder = nn.Sequential( 105 | nn.Conv2d( 106 | in_channels=self.encoder.num_features, 107 | out_channels=self.encoder_stride ** 2 * 3, kernel_size=1), 108 | nn.PixelShuffle(self.encoder_stride), 109 | ) 110 | 111 | self.in_chans = self.encoder.in_chans 112 | self.patch_size = self.encoder.patch_size 113 | 114 | def forward(self, x, mask): 115 | z = self.encoder(x, mask) 116 | x_rec = self.decoder(z) 117 | 118 | mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous() 119 | loss_recon = F.l1_loss(x, x_rec, reduction='none') 120 | loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans 121 | 
return loss 122 | 123 | @torch.jit.ignore 124 | def no_weight_decay(self): 125 | if hasattr(self.encoder, 'no_weight_decay'): 126 | return {'encoder.' + i for i in self.encoder.no_weight_decay()} 127 | return {} 128 | 129 | @torch.jit.ignore 130 | def no_weight_decay_keywords(self): 131 | if hasattr(self.encoder, 'no_weight_decay_keywords'): 132 | return {'encoder.' + i for i in self.encoder.no_weight_decay_keywords()} 133 | return {} 134 | 135 | 136 | def build_simmim(config): 137 | model_type = config.MODEL.TYPE 138 | if model_type == 'swin': 139 | encoder = SwinTransformerForSimMIM( 140 | img_size=config.DATA.IMG_SIZE, 141 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 142 | in_chans=config.MODEL.SWIN.IN_CHANS, 143 | num_classes=0, 144 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 145 | depths=config.MODEL.SWIN.DEPTHS, 146 | num_heads=config.MODEL.SWIN.NUM_HEADS, 147 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 148 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 149 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 150 | qk_scale=config.MODEL.SWIN.QK_SCALE, 151 | drop_rate=config.MODEL.DROP_RATE, 152 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 153 | ape=config.MODEL.SWIN.APE, 154 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 155 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 156 | encoder_stride = 32 157 | elif model_type == 'vit': 158 | encoder = VisionTransformerForSimMIM( 159 | img_size=config.DATA.IMG_SIZE, 160 | patch_size=config.MODEL.VIT.PATCH_SIZE, 161 | in_chans=config.MODEL.VIT.IN_CHANS, 162 | num_classes=0, 163 | embed_dim=config.MODEL.VIT.EMBED_DIM, 164 | depth=config.MODEL.VIT.DEPTH, 165 | num_heads=config.MODEL.VIT.NUM_HEADS, 166 | mlp_ratio=config.MODEL.VIT.MLP_RATIO, 167 | qkv_bias=config.MODEL.VIT.QKV_BIAS, 168 | drop_rate=config.MODEL.DROP_RATE, 169 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 170 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 171 | init_values=config.MODEL.VIT.INIT_VALUES, 172 | use_abs_pos_emb=config.MODEL.VIT.USE_APE, 173 | use_rel_pos_bias=config.MODEL.VIT.USE_RPB, 174 | use_shared_rel_pos_bias=config.MODEL.VIT.USE_SHARED_RPB, 175 | use_mean_pooling=config.MODEL.VIT.USE_MEAN_POOLING) 176 | encoder_stride = 16 177 | else: 178 | raise NotImplementedError(f"Unknown pre-train model: {model_type}") 179 | 180 | model = SimMIM(encoder=encoder, encoder_stride=encoder_stride) 181 | 182 | return model 183 | -------------------------------------------------------------------------------- /models/swin_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.utils.checkpoint as checkpoint 12 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = 
self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | def window_partition(x, window_size): 35 | """ 36 | Args: 37 | x: (B, H, W, C) 38 | window_size (int): window size 39 | 40 | Returns: 41 | windows: (num_windows*B, window_size, window_size, C) 42 | """ 43 | B, H, W, C = x.shape 44 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 45 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 46 | return windows 47 | 48 | 49 | def window_reverse(windows, window_size, H, W): 50 | """ 51 | Args: 52 | windows: (num_windows*B, window_size, window_size, C) 53 | window_size (int): Window size 54 | H (int): Height of image 55 | W (int): Width of image 56 | 57 | Returns: 58 | x: (B, H, W, C) 59 | """ 60 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 61 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 62 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 63 | return x 64 | 65 | 66 | class WindowAttention(nn.Module): 67 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 68 | It supports both of shifted and non-shifted window. 69 | 70 | Args: 71 | dim (int): Number of input channels. 72 | window_size (tuple[int]): The height and width of the window. 73 | num_heads (int): Number of attention heads. 74 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 75 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 76 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 77 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 78 | """ 79 | 80 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 81 | 82 | super().__init__() 83 | self.dim = dim 84 | self.window_size = window_size # Wh, Ww 85 | self.num_heads = num_heads 86 | head_dim = dim // num_heads 87 | self.scale = qk_scale or head_dim ** -0.5 88 | 89 | # define a parameter table of relative position bias 90 | self.relative_position_bias_table = nn.Parameter( 91 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 92 | 93 | # get pair-wise relative position index for each token inside the window 94 | coords_h = torch.arange(self.window_size[0]) 95 | coords_w = torch.arange(self.window_size[1]) 96 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 97 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 98 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 99 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 100 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 101 | relative_coords[:, :, 1] += self.window_size[1] - 1 102 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 103 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 104 | self.register_buffer("relative_position_index", relative_position_index) 105 | 106 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 107 | self.attn_drop = nn.Dropout(attn_drop) 108 | self.proj = nn.Linear(dim, dim) 109 | self.proj_drop = nn.Dropout(proj_drop) 110 | 111 | trunc_normal_(self.relative_position_bias_table, std=.02) 112 | self.softmax = nn.Softmax(dim=-1) 113 | 114 | def forward(self, x, mask=None): 115 | """ 116 | Args: 117 | x: input features with shape of 
(num_windows*B, N, C) 118 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 119 | """ 120 | B_, N, C = x.shape 121 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 122 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 123 | 124 | q = q * self.scale 125 | attn = (q @ k.transpose(-2, -1)) 126 | 127 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 128 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 129 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 130 | attn = attn + relative_position_bias.unsqueeze(0) 131 | 132 | if mask is not None: 133 | nW = mask.shape[0] 134 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 135 | attn = attn.view(-1, self.num_heads, N, N) 136 | attn = self.softmax(attn) 137 | else: 138 | attn = self.softmax(attn) 139 | 140 | attn = self.attn_drop(attn) 141 | 142 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 143 | x = self.proj(x) 144 | x = self.proj_drop(x) 145 | return x 146 | 147 | def extra_repr(self) -> str: 148 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 149 | 150 | def flops(self, N): 151 | # calculate flops for 1 window with token length of N 152 | flops = 0 153 | # qkv = self.qkv(x) 154 | flops += N * self.dim * 3 * self.dim 155 | # attn = (q @ k.transpose(-2, -1)) 156 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 157 | # x = (attn @ v) 158 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 159 | # x = self.proj(x) 160 | flops += N * self.dim * self.dim 161 | return flops 162 | 163 | 164 | class SwinTransformerBlock(nn.Module): 165 | r""" Swin Transformer Block. 166 | 167 | Args: 168 | dim (int): Number of input channels. 169 | input_resolution (tuple[int]): Input resulotion. 170 | num_heads (int): Number of attention heads. 171 | window_size (int): Window size. 172 | shift_size (int): Shift size for SW-MSA. 173 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 174 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 175 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 176 | drop (float, optional): Dropout rate. Default: 0.0 177 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 178 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 179 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 180 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 181 | """ 182 | 183 | def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, 184 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 185 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 186 | super().__init__() 187 | self.dim = dim 188 | self.input_resolution = input_resolution 189 | self.num_heads = num_heads 190 | self.window_size = window_size 191 | self.shift_size = shift_size 192 | self.mlp_ratio = mlp_ratio 193 | if min(self.input_resolution) <= self.window_size: 194 | # if window size is larger than input resolution, we don't partition windows 195 | self.shift_size = 0 196 | self.window_size = min(self.input_resolution) 197 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 198 | 199 | self.norm1 = norm_layer(dim) 200 | self.attn = WindowAttention( 201 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 202 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 203 | 204 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 205 | self.norm2 = norm_layer(dim) 206 | mlp_hidden_dim = int(dim * mlp_ratio) 207 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 208 | 209 | if self.shift_size > 0: 210 | # calculate attention mask for SW-MSA 211 | H, W = self.input_resolution 212 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 213 | h_slices = (slice(0, -self.window_size), 214 | slice(-self.window_size, -self.shift_size), 215 | slice(-self.shift_size, None)) 216 | w_slices = (slice(0, -self.window_size), 217 | slice(-self.window_size, -self.shift_size), 218 | slice(-self.shift_size, None)) 219 | cnt = 0 220 | for h in h_slices: 221 | for w in w_slices: 222 | img_mask[:, h, w, :] = cnt 223 | cnt += 1 224 | 225 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 226 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 227 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 228 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 229 | else: 230 | attn_mask = None 231 | 232 | self.register_buffer("attn_mask", attn_mask) 233 | 234 | def forward(self, x): 235 | H, W = self.input_resolution 236 | B, L, C = x.shape 237 | assert L == H * W, "input feature has wrong size" 238 | 239 | shortcut = x 240 | x = self.norm1(x) 241 | x = x.view(B, H, W, C) 242 | 243 | # cyclic shift 244 | if self.shift_size > 0: 245 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 246 | else: 247 | shifted_x = x 248 | 249 | # partition windows 250 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 251 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 252 | 253 | # W-MSA/SW-MSA 254 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 255 | 256 | # merge windows 257 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 258 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 259 | 260 | # reverse cyclic shift 261 | if self.shift_size > 0: 262 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 263 | else: 264 | x = shifted_x 265 | x = x.view(B, H * W, C) 266 | 267 | # FFN 268 | x = shortcut + self.drop_path(x) 269 | 
x = x + self.drop_path(self.mlp(self.norm2(x))) 270 | 271 | return x 272 | 273 | def extra_repr(self) -> str: 274 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 275 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 276 | 277 | def flops(self): 278 | flops = 0 279 | H, W = self.input_resolution 280 | # norm1 281 | flops += self.dim * H * W 282 | # W-MSA/SW-MSA 283 | nW = H * W / self.window_size / self.window_size 284 | flops += nW * self.attn.flops(self.window_size * self.window_size) 285 | # mlp 286 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 287 | # norm2 288 | flops += self.dim * H * W 289 | return flops 290 | 291 | 292 | class PatchMerging(nn.Module): 293 | r""" Patch Merging Layer. 294 | 295 | Args: 296 | input_resolution (tuple[int]): Resolution of input feature. 297 | dim (int): Number of input channels. 298 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 299 | """ 300 | 301 | def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): 302 | super().__init__() 303 | self.input_resolution = input_resolution 304 | self.dim = dim 305 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 306 | self.norm = norm_layer(4 * dim) 307 | 308 | def forward(self, x): 309 | """ 310 | x: B, H*W, C 311 | """ 312 | H, W = self.input_resolution 313 | B, L, C = x.shape 314 | assert L == H * W, "input feature has wrong size" 315 | assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 316 | 317 | x = x.view(B, H, W, C) 318 | 319 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 320 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 321 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 322 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 323 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 324 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 325 | 326 | x = self.norm(x) 327 | x = self.reduction(x) 328 | 329 | return x 330 | 331 | def extra_repr(self) -> str: 332 | return f"input_resolution={self.input_resolution}, dim={self.dim}" 333 | 334 | def flops(self): 335 | H, W = self.input_resolution 336 | flops = H * W * self.dim 337 | flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim 338 | return flops 339 | 340 | 341 | class BasicLayer(nn.Module): 342 | """ A basic Swin Transformer layer for one stage. 343 | 344 | Args: 345 | dim (int): Number of input channels. 346 | input_resolution (tuple[int]): Input resolution. 347 | depth (int): Number of blocks. 348 | num_heads (int): Number of attention heads. 349 | window_size (int): Local window size. 350 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 351 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 352 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 353 | drop (float, optional): Dropout rate. Default: 0.0 354 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 355 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 356 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 357 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 358 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
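        Note: blocks in this stage alternate between regular window attention (shift_size = 0
        for even-indexed blocks) and shifted window attention (shift_size = window_size // 2).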
359 | """ 360 | 361 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 362 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 363 | drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): 364 | 365 | super().__init__() 366 | self.dim = dim 367 | self.input_resolution = input_resolution 368 | self.depth = depth 369 | self.use_checkpoint = use_checkpoint 370 | 371 | # build blocks 372 | self.blocks = nn.ModuleList([ 373 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 374 | num_heads=num_heads, window_size=window_size, 375 | shift_size=0 if (i % 2 == 0) else window_size // 2, 376 | mlp_ratio=mlp_ratio, 377 | qkv_bias=qkv_bias, qk_scale=qk_scale, 378 | drop=drop, attn_drop=attn_drop, 379 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 380 | norm_layer=norm_layer) 381 | for i in range(depth)]) 382 | 383 | # patch merging layer 384 | if downsample is not None: 385 | self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) 386 | else: 387 | self.downsample = None 388 | 389 | def forward(self, x): 390 | for blk in self.blocks: 391 | if self.use_checkpoint: 392 | x = checkpoint.checkpoint(blk, x) 393 | else: 394 | x = blk(x) 395 | if self.downsample is not None: 396 | x = self.downsample(x) 397 | return x 398 | 399 | def extra_repr(self) -> str: 400 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 401 | 402 | def flops(self): 403 | flops = 0 404 | for blk in self.blocks: 405 | flops += blk.flops() 406 | if self.downsample is not None: 407 | flops += self.downsample.flops() 408 | return flops 409 | 410 | 411 | class PatchEmbed(nn.Module): 412 | r""" Image to Patch Embedding 413 | 414 | Args: 415 | img_size (int): Image size. Default: 224. 416 | patch_size (int): Patch token size. Default: 4. 417 | in_chans (int): Number of input image channels. Default: 3. 418 | embed_dim (int): Number of linear projection output channels. Default: 96. 419 | norm_layer (nn.Module, optional): Normalization layer. Default: None 420 | """ 421 | 422 | def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 423 | super().__init__() 424 | img_size = to_2tuple(img_size) 425 | patch_size = to_2tuple(patch_size) 426 | patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] 427 | self.img_size = img_size 428 | self.patch_size = patch_size 429 | self.patches_resolution = patches_resolution 430 | self.num_patches = patches_resolution[0] * patches_resolution[1] 431 | 432 | self.in_chans = in_chans 433 | self.embed_dim = embed_dim 434 | 435 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 436 | if norm_layer is not None: 437 | self.norm = norm_layer(embed_dim) 438 | else: 439 | self.norm = None 440 | 441 | def forward(self, x): 442 | B, C, H, W = x.shape 443 | # FIXME look at relaxing size constraints 444 | assert H == self.img_size[0] and W == self.img_size[1], \ 445 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
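        # proj is a Conv2d with kernel_size == stride == patch_size, so each non-overlapping patch
        # is linearly projected to embed_dim; flatten + transpose then yields (B, Ph*Pw, C) tokens.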
446 | x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C 447 | if self.norm is not None: 448 | x = self.norm(x) 449 | return x 450 | 451 | def flops(self): 452 | Ho, Wo = self.patches_resolution 453 | flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) 454 | if self.norm is not None: 455 | flops += Ho * Wo * self.embed_dim 456 | return flops 457 | 458 | 459 | class SwinTransformer(nn.Module): 460 | r""" Swin Transformer 461 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 462 | https://arxiv.org/pdf/2103.14030 463 | 464 | Args: 465 | img_size (int | tuple(int)): Input image size. Default 224 466 | patch_size (int | tuple(int)): Patch size. Default: 4 467 | in_chans (int): Number of input image channels. Default: 3 468 | num_classes (int): Number of classes for classification head. Default: 1000 469 | embed_dim (int): Patch embedding dimension. Default: 96 470 | depths (tuple(int)): Depth of each Swin Transformer layer. 471 | num_heads (tuple(int)): Number of attention heads in different layers. 472 | window_size (int): Window size. Default: 7 473 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 474 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 475 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None 476 | drop_rate (float): Dropout rate. Default: 0 477 | attn_drop_rate (float): Attention dropout rate. Default: 0 478 | drop_path_rate (float): Stochastic depth rate. Default: 0.1 479 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 480 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False 481 | patch_norm (bool): If True, add normalization after patch embedding. Default: True 482 | use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False 483 | """ 484 | 485 | def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, 486 | embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], 487 | window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, 488 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 489 | norm_layer=nn.LayerNorm, ape=False, patch_norm=True, 490 | use_checkpoint=False, **kwargs): 491 | super().__init__() 492 | 493 | self.img_size = img_size 494 | self.patch_size = patch_size 495 | self.in_chans = in_chans 496 | 497 | self.num_classes = num_classes 498 | self.num_layers = len(depths) 499 | self.embed_dim = embed_dim 500 | self.ape = ape 501 | self.patch_norm = patch_norm 502 | self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) 503 | self.mlp_ratio = mlp_ratio 504 | 505 | # split image into non-overlapping patches 506 | self.patch_embed = PatchEmbed( 507 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 508 | norm_layer=norm_layer if self.patch_norm else None) 509 | num_patches = self.patch_embed.num_patches 510 | patches_resolution = self.patch_embed.patches_resolution 511 | self.patches_resolution = patches_resolution 512 | 513 | # absolute position embedding 514 | if self.ape: 515 | self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 516 | trunc_normal_(self.absolute_pos_embed, std=.02) 517 | 518 | self.pos_drop = nn.Dropout(p=drop_rate) 519 | 520 | # stochastic depth 521 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 522 | 523 | # build layers 524 | self.layers = nn.ModuleList() 525 | for i_layer in range(self.num_layers): 526 | layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), 527 | input_resolution=(patches_resolution[0] // (2 ** i_layer), 528 | patches_resolution[1] // (2 ** i_layer)), 529 | depth=depths[i_layer], 530 | num_heads=num_heads[i_layer], 531 | window_size=window_size, 532 | mlp_ratio=self.mlp_ratio, 533 | qkv_bias=qkv_bias, qk_scale=qk_scale, 534 | drop=drop_rate, attn_drop=attn_drop_rate, 535 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 536 | norm_layer=norm_layer, 537 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 538 | use_checkpoint=use_checkpoint) 539 | self.layers.append(layer) 540 | 541 | self.norm = norm_layer(self.num_features) 542 | self.avgpool = nn.AdaptiveAvgPool1d(1) 543 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 544 | 545 | self.apply(self._init_weights) 546 | 547 | def _init_weights(self, m): 548 | if isinstance(m, nn.Linear): 549 | trunc_normal_(m.weight, std=.02) 550 | if isinstance(m, nn.Linear) and m.bias is not None: 551 | nn.init.constant_(m.bias, 0) 552 | elif isinstance(m, nn.LayerNorm): 553 | nn.init.constant_(m.bias, 0) 554 | nn.init.constant_(m.weight, 1.0) 555 | 556 | @torch.jit.ignore 557 | def no_weight_decay(self): 558 | return {'absolute_pos_embed'} 559 | 560 | @torch.jit.ignore 561 | def no_weight_decay_keywords(self): 562 | return {'relative_position_bias_table'} 563 | 564 | def forward_features(self, x): 565 | x = self.patch_embed(x) 566 | if self.ape: 567 | x = x + self.absolute_pos_embed 568 | x = self.pos_drop(x) 569 | 570 | for layer in self.layers: 571 | x = layer(x) 572 | 573 | x = self.norm(x) # B L C 574 | x = self.avgpool(x.transpose(1, 2)) # B C 1 575 | x = torch.flatten(x, 1) 576 | return x 577 | 578 | def forward(self, x): 579 | x = self.forward_features(x) 580 | x = 
self.head(x) 581 | return x 582 | 583 | def flops(self): 584 | flops = 0 585 | flops += self.patch_embed.flops() 586 | for i, layer in enumerate(self.layers): 587 | flops += layer.flops() 588 | flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers) 589 | flops += self.num_features * self.num_classes 590 | return flops 591 | 592 | 593 | def build_swin(config): 594 | model = SwinTransformer( 595 | img_size=config.DATA.IMG_SIZE, 596 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 597 | in_chans=config.MODEL.SWIN.IN_CHANS, 598 | num_classes=config.MODEL.NUM_CLASSES, 599 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 600 | depths=config.MODEL.SWIN.DEPTHS, 601 | num_heads=config.MODEL.SWIN.NUM_HEADS, 602 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 603 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 604 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 605 | qk_scale=config.MODEL.SWIN.QK_SCALE, 606 | drop_rate=config.MODEL.DROP_RATE, 607 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 608 | ape=config.MODEL.SWIN.APE, 609 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 610 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 611 | 612 | return model -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Based on BEIT code bases (https://github.com/microsoft/unilm/tree/master/beit) 6 | # Written by Yutong Lin, Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import math 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | 17 | 18 | class Mlp(nn.Module): 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | # x = self.drop(x) 32 | # comment out this for the orignal BERT implement 33 | x = self.fc2(x) 34 | x = self.drop(x) 35 | return x 36 | 37 | 38 | class Attention(nn.Module): 39 | def __init__( 40 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 41 | proj_drop=0., window_size=None, attn_head_dim=None): 42 | super().__init__() 43 | self.num_heads = num_heads 44 | head_dim = dim // num_heads 45 | if attn_head_dim is not None: 46 | head_dim = attn_head_dim 47 | all_head_dim = head_dim * self.num_heads 48 | self.scale = qk_scale or head_dim ** -0.5 49 | 50 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 51 | if qkv_bias: 52 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 53 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 54 | else: 55 | self.q_bias = None 56 | self.v_bias = None 57 | 58 | if window_size: 59 | self.window_size = window_size 60 | # cls to token & token to cls & cls to cls 61 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 62 | self.relative_position_bias_table = nn.Parameter( 
63 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 64 | 65 | # get pair-wise relative position index for each token inside the window 66 | coords_h = torch.arange(window_size[0]) 67 | coords_w = torch.arange(window_size[1]) 68 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 69 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 70 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 71 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 72 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 73 | relative_coords[:, :, 1] += window_size[1] - 1 74 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 75 | relative_position_index = \ 76 | torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype) 77 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 78 | relative_position_index[0, 0:] = self.num_relative_distance - 3 79 | relative_position_index[0:, 0] = self.num_relative_distance - 2 80 | relative_position_index[0, 0] = self.num_relative_distance - 1 81 | 82 | self.register_buffer("relative_position_index", relative_position_index) 83 | else: 84 | self.window_size = None 85 | self.relative_position_bias_table = None 86 | self.relative_position_index = None 87 | 88 | self.attn_drop = nn.Dropout(attn_drop) 89 | self.proj = nn.Linear(all_head_dim, dim) 90 | self.proj_drop = nn.Dropout(proj_drop) 91 | 92 | def forward(self, x, rel_pos_bias=None): 93 | B, N, C = x.shape 94 | qkv_bias = None 95 | if self.q_bias is not None: 96 | qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 97 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 98 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 99 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 100 | 101 | q = q * self.scale 102 | attn = (q @ k.transpose(-2, -1)) 103 | 104 | if self.relative_position_bias_table is not None: 105 | relative_position_bias = \ 106 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 107 | self.window_size[0] * self.window_size[1] + 1, 108 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 109 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 110 | attn = attn + relative_position_bias.unsqueeze(0) 111 | 112 | if rel_pos_bias is not None: 113 | attn = attn + rel_pos_bias 114 | 115 | attn = attn.softmax(dim=-1) 116 | attn = self.attn_drop(attn) 117 | 118 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 119 | x = self.proj(x) 120 | x = self.proj_drop(x) 121 | return x 122 | 123 | 124 | class Block(nn.Module): 125 | 126 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 127 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 128 | window_size=None, attn_head_dim=None): 129 | super().__init__() 130 | self.norm1 = norm_layer(dim) 131 | self.attn = Attention( 132 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 133 | attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim) 134 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 135 | self.norm2 = norm_layer(dim) 136 | mlp_hidden_dim = int(dim * mlp_ratio) 137 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 138 | 139 | if init_values is not None: 140 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 141 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 142 | else: 143 | self.gamma_1, self.gamma_2 = None, None 144 | 145 | def forward(self, x, rel_pos_bias=None): 146 | if self.gamma_1 is None: 147 | x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 148 | x = x + self.drop_path(self.mlp(self.norm2(x))) 149 | else: 150 | x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)) 151 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 152 | return x 153 | 154 | 155 | class PatchEmbed(nn.Module): 156 | """ Image to Patch Embedding 157 | """ 158 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 159 | super().__init__() 160 | img_size = to_2tuple(img_size) 161 | patch_size = to_2tuple(patch_size) 162 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 163 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 164 | self.img_size = img_size 165 | self.patch_size = patch_size 166 | self.num_patches = num_patches 167 | 168 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 169 | 170 | def forward(self, x, **kwargs): 171 | B, C, H, W = x.shape 172 | # FIXME look at relaxing size constraints 173 | assert H == self.img_size[0] and W == self.img_size[1], \ 174 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
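        # With the defaults (img_size=224, patch_size=16) this yields 14 * 14 = 196 patch tokens,
        # each projected to embed_dim by the strided convolution below.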
175 | x = self.proj(x).flatten(2).transpose(1, 2) 176 | return x 177 | 178 | 179 | class RelativePositionBias(nn.Module): 180 | 181 | def __init__(self, window_size, num_heads): 182 | super().__init__() 183 | self.window_size = window_size 184 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 185 | self.relative_position_bias_table = nn.Parameter( 186 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 187 | # cls to token & token 2 cls & cls to cls 188 | 189 | # get pair-wise relative position index for each token inside the window 190 | coords_h = torch.arange(window_size[0]) 191 | coords_w = torch.arange(window_size[1]) 192 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 193 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 194 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 195 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 196 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 197 | relative_coords[:, :, 1] += window_size[1] - 1 198 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 199 | relative_position_index = \ 200 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 201 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 202 | relative_position_index[0, 0:] = self.num_relative_distance - 3 203 | relative_position_index[0:, 0] = self.num_relative_distance - 2 204 | relative_position_index[0, 0] = self.num_relative_distance - 1 205 | 206 | self.register_buffer("relative_position_index", relative_position_index) 207 | 208 | def forward(self): 209 | relative_position_bias = \ 210 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 211 | self.window_size[0] * self.window_size[1] + 1, 212 | self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH 213 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 214 | 215 | 216 | class VisionTransformer(nn.Module): 217 | """ Vision Transformer with support for patch or hybrid CNN input stage 218 | """ 219 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 220 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 221 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, 222 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, 223 | use_mean_pooling=True, init_scale=0.001): 224 | super().__init__() 225 | self.num_classes = num_classes 226 | self.num_features = self.embed_dim = embed_dim 227 | self.patch_size = patch_size 228 | self.in_chans = in_chans 229 | 230 | self.patch_embed = PatchEmbed( 231 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 232 | num_patches = self.patch_embed.num_patches 233 | 234 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 235 | if use_abs_pos_emb: 236 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 237 | else: 238 | self.pos_embed = None 239 | self.pos_drop = nn.Dropout(p=drop_rate) 240 | 241 | if use_shared_rel_pos_bias: 242 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads) 243 | else: 244 | self.rel_pos_bias = None 245 | 246 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 247 | 
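        # Added note: dpr above implements the stochastic depth decay rule: the drop_path
        # probability grows linearly from 0 in the first block to drop_path_rate in the last,
        # so deeper blocks are dropped more often during training. For example,
        # torch.linspace(0, 0.1, 12) gives block i a drop probability of 0.1 * i / 11,
        # i.e. 0.0 for block 0 and 0.1 for block 11.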
self.use_rel_pos_bias = use_rel_pos_bias 248 | self.blocks = nn.ModuleList([ 249 | Block( 250 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 251 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 252 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) 253 | for i in range(depth)]) 254 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 255 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 256 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 257 | 258 | if self.pos_embed is not None: 259 | self._trunc_normal_(self.pos_embed, std=.02) 260 | self._trunc_normal_(self.cls_token, std=.02) 261 | if num_classes > 0: 262 | self._trunc_normal_(self.head.weight, std=.02) 263 | self.apply(self._init_weights) 264 | self.fix_init_weight() 265 | 266 | if num_classes > 0: 267 | self.head.weight.data.mul_(init_scale) 268 | self.head.bias.data.mul_(init_scale) 269 | 270 | def _trunc_normal_(self, tensor, mean=0., std=1.): 271 | trunc_normal_(tensor, mean=mean, std=std) 272 | 273 | def fix_init_weight(self): 274 | def rescale(param, layer_id): 275 | param.div_(math.sqrt(2.0 * layer_id)) 276 | 277 | for layer_id, layer in enumerate(self.blocks): 278 | rescale(layer.attn.proj.weight.data, layer_id + 1) 279 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 280 | 281 | def _init_weights(self, m): 282 | if isinstance(m, nn.Linear): 283 | self._trunc_normal_(m.weight, std=.02) 284 | if isinstance(m, nn.Linear) and m.bias is not None: 285 | nn.init.constant_(m.bias, 0) 286 | elif isinstance(m, nn.LayerNorm): 287 | nn.init.constant_(m.bias, 0) 288 | nn.init.constant_(m.weight, 1.0) 289 | elif isinstance(m, nn.Conv2d): 290 | self._trunc_normal_(m.weight, std=.02) 291 | if m.bias is not None: 292 | nn.init.constant_(m.bias, 0) 293 | 294 | def get_num_layers(self): 295 | return len(self.blocks) 296 | 297 | @torch.jit.ignore 298 | def no_weight_decay(self): 299 | return {'pos_embed', 'cls_token'} 300 | 301 | def get_classifier(self): 302 | return self.head 303 | 304 | def reset_classifier(self, num_classes, global_pool=''): 305 | self.num_classes = num_classes 306 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 307 | 308 | def forward_features(self, x): 309 | x = self.patch_embed(x) 310 | batch_size, seq_len, _ = x.size() 311 | 312 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 313 | x = torch.cat((cls_tokens, x), dim=1) 314 | if self.pos_embed is not None: 315 | x = x + self.pos_embed 316 | x = self.pos_drop(x) 317 | 318 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 319 | for blk in self.blocks: 320 | x = blk(x, rel_pos_bias=rel_pos_bias) 321 | 322 | x = self.norm(x) 323 | if self.fc_norm is not None: 324 | t = x[:, 1:, :] 325 | return self.fc_norm(t.mean(1)) 326 | else: 327 | return x[:, 0] 328 | 329 | def forward(self, x): 330 | x = self.forward_features(x) 331 | x = self.head(x) 332 | return x 333 | 334 | 335 | def build_vit(config): 336 | model = VisionTransformer( 337 | img_size=config.DATA.IMG_SIZE, 338 | patch_size=config.MODEL.VIT.PATCH_SIZE, 339 | in_chans=config.MODEL.VIT.IN_CHANS, 340 | num_classes=config.MODEL.NUM_CLASSES, 341 | embed_dim=config.MODEL.VIT.EMBED_DIM, 342 | depth=config.MODEL.VIT.DEPTH, 343 | num_heads=config.MODEL.VIT.NUM_HEADS, 344 | 
mlp_ratio=config.MODEL.VIT.MLP_RATIO, 345 | qkv_bias=config.MODEL.VIT.QKV_BIAS, 346 | drop_rate=config.MODEL.DROP_RATE, 347 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 348 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 349 | init_values=config.MODEL.VIT.INIT_VALUES, 350 | use_abs_pos_emb=config.MODEL.VIT.USE_APE, 351 | use_rel_pos_bias=config.MODEL.VIT.USE_RPB, 352 | use_shared_rel_pos_bias=config.MODEL.VIT.USE_SHARED_RPB, 353 | use_mean_pooling=config.MODEL.VIT.USE_MEAN_POOLING) 354 | 355 | return model -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import json 10 | from functools import partial 11 | from torch import optim as optim 12 | 13 | 14 | def build_optimizer(config, model, logger, is_pretrain): 15 | if is_pretrain: 16 | return build_pretrain_optimizer(config, model, logger) 17 | else: 18 | return build_finetune_optimizer(config, model, logger) 19 | 20 | 21 | def build_pretrain_optimizer(config, model, logger): 22 | logger.info('>>>>>>>>>> Build Optimizer for Pre-training Stage') 23 | skip = {} 24 | skip_keywords = {} 25 | if hasattr(model, 'no_weight_decay'): 26 | skip = model.no_weight_decay() 27 | logger.info(f'No weight decay: {skip}') 28 | if hasattr(model, 'no_weight_decay_keywords'): 29 | skip_keywords = model.no_weight_decay_keywords() 30 | logger.info(f'No weight decay keywords: {skip_keywords}') 31 | 32 | parameters = get_pretrain_param_groups(model, logger, skip, skip_keywords) 33 | 34 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 35 | optimizer = None 36 | if opt_lower == 'sgd': 37 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 38 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 39 | elif opt_lower == 'adamw': 40 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 41 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 42 | 43 | logger.info(optimizer) 44 | return optimizer 45 | 46 | 47 | def get_pretrain_param_groups(model, logger, skip_list=(), skip_keywords=()): 48 | has_decay = [] 49 | no_decay = [] 50 | has_decay_name = [] 51 | no_decay_name = [] 52 | 53 | for name, param in model.named_parameters(): 54 | if not param.requires_grad: 55 | continue 56 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 57 | check_keywords_in_name(name, skip_keywords): 58 | no_decay.append(param) 59 | no_decay_name.append(name) 60 | else: 61 | has_decay.append(param) 62 | has_decay_name.append(name) 63 | logger.info(f'No decay params: {no_decay_name}') 64 | logger.info(f'Has decay params: {has_decay_name}') 65 | return [{'params': has_decay}, 66 | {'params': no_decay, 'weight_decay': 0.}] 67 | 68 | 69 | def build_finetune_optimizer(config, model, logger): 70 | logger.info('>>>>>>>>>> Build Optimizer for Fine-tuning Stage') 71 | if config.MODEL.TYPE == 'swin': 72 | depths = config.MODEL.SWIN.DEPTHS 73 | num_layers = sum(depths) 74 | get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) 75 | elif config.MODEL.TYPE == 'vit': 76 | num_layers = config.MODEL.VIT.DEPTH 77 | 
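        # Added note: num_layers + 2 layer ids are used for layer-wise lr decay: the embeddings
        # get id 0, the DEPTH transformer blocks get ids 1..DEPTH, and remaining parameters such
        # as the classification head get id DEPTH + 1. The scales computed below assign each id
        # the factor LAYER_DECAY ** (DEPTH + 1 - id), so the head keeps the full base learning
        # rate while earlier layers are scaled down geometrically.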
get_layer_func = partial(get_vit_layer, num_layers=num_layers + 2) 78 | else: 79 | raise NotImplementedError 80 | 81 | scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) 82 | 83 | skip = {} 84 | skip_keywords = {} 85 | if hasattr(model, 'no_weight_decay'): 86 | skip = model.no_weight_decay() 87 | logger.info(f'No weight decay: {skip}') 88 | if hasattr(model, 'no_weight_decay_keywords'): 89 | skip_keywords = model.no_weight_decay_keywords() 90 | logger.info(f'No weight decay keywords: {skip_keywords}') 91 | 92 | parameters = get_finetune_param_groups( 93 | model, logger, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, 94 | get_layer_func, scales, skip, skip_keywords) 95 | 96 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 97 | optimizer = None 98 | if opt_lower == 'sgd': 99 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 100 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 101 | elif opt_lower == 'adamw': 102 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 103 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 104 | 105 | logger.info(optimizer) 106 | return optimizer 107 | 108 | 109 | def get_vit_layer(name, num_layers): 110 | if name in ("cls_token", "mask_token", "pos_embed"): 111 | return 0 112 | elif name.startswith("patch_embed"): 113 | return 0 114 | elif name.startswith("rel_pos_bias"): 115 | return num_layers - 1 116 | elif name.startswith("blocks"): 117 | layer_id = int(name.split('.')[1]) 118 | return layer_id + 1 119 | else: 120 | return num_layers - 1 121 | 122 | 123 | def get_swin_layer(name, num_layers, depths): 124 | if name in ("mask_token"): 125 | return 0 126 | elif name.startswith("patch_embed"): 127 | return 0 128 | elif name.startswith("layers"): 129 | layer_id = int(name.split('.')[1]) 130 | block_id = name.split('.')[3] 131 | if block_id == 'reduction' or block_id == 'norm': 132 | return sum(depths[:layer_id + 1]) 133 | layer_id = sum(depths[:layer_id]) + int(block_id) 134 | return layer_id + 1 135 | else: 136 | return num_layers - 1 137 | 138 | 139 | def get_finetune_param_groups(model, logger, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 140 | parameter_group_names = {} 141 | parameter_group_vars = {} 142 | 143 | for name, param in model.named_parameters(): 144 | if not param.requires_grad: 145 | continue 146 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 147 | check_keywords_in_name(name, skip_keywords): 148 | group_name = "no_decay" 149 | this_weight_decay = 0. 150 | else: 151 | group_name = "decay" 152 | this_weight_decay = weight_decay 153 | if get_layer_func is not None: 154 | layer_id = get_layer_func(name) 155 | group_name = "layer_%d_%s" % (layer_id, group_name) 156 | else: 157 | layer_id = None 158 | 159 | if group_name not in parameter_group_names: 160 | if scales is not None: 161 | scale = scales[layer_id] 162 | else: 163 | scale = 1. 
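            # Added note: parameters are bucketed per (layer id, decay status), producing group
            # names such as "layer_0_no_decay" or "layer_13_decay" for a depth-12 ViT. Each
            # group is created below with lr = BASE_LR * scale, where scale comes from the
            # layer-decay schedule, and 1-D tensors, biases and explicitly skipped names are
            # placed in the no_decay groups with weight_decay = 0.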
164 | 165 | parameter_group_names[group_name] = { 166 | "group_name": group_name, 167 | "weight_decay": this_weight_decay, 168 | "params": [], 169 | "lr": lr * scale, 170 | "lr_scale": scale, 171 | } 172 | parameter_group_vars[group_name] = { 173 | "group_name": group_name, 174 | "weight_decay": this_weight_decay, 175 | "params": [], 176 | "lr": lr * scale, 177 | "lr_scale": scale 178 | } 179 | 180 | parameter_group_vars[group_name]["params"].append(param) 181 | parameter_group_names[group_name]["params"].append(name) 182 | logger.info("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 183 | return list(parameter_group_vars.values()) 184 | 185 | 186 | def check_keywords_in_name(name, keywords=()): 187 | isin = False 188 | for keyword in keywords: 189 | if keyword in name: 190 | isin = True 191 | return isin -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | scipy 3 | termcolor 4 | timm 5 | yacs -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import torch 11 | import torch.distributed as dist 12 | import numpy as np 13 | from scipy import interpolate 14 | 15 | try: 16 | # noinspection PyUnresolvedReferences 17 | from apex import amp 18 | except ImportError: 19 | amp = None 20 | 21 | 22 | def load_checkpoint(config, model, optimizer, lr_scheduler, logger): 23 | logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") 24 | if config.MODEL.RESUME.startswith('https'): 25 | checkpoint = torch.hub.load_state_dict_from_url( 26 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 27 | else: 28 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 29 | msg = model.load_state_dict(checkpoint['model'], strict=False) 30 | logger.info(msg) 31 | max_accuracy = 0.0 32 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 33 | optimizer.load_state_dict(checkpoint['optimizer']) 34 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 35 | config.defrost() 36 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 37 | config.freeze() 38 | if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0": 39 | amp.load_state_dict(checkpoint['amp']) 40 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 41 | if 'max_accuracy' in checkpoint: 42 | max_accuracy = checkpoint['max_accuracy'] 43 | 44 | del checkpoint 45 | torch.cuda.empty_cache() 46 | return max_accuracy 47 | 48 | 49 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger): 50 | save_state = {'model': model.state_dict(), 51 | 'optimizer': optimizer.state_dict(), 52 | 'lr_scheduler': lr_scheduler.state_dict(), 53 | 'max_accuracy': max_accuracy, 54 | 'epoch': epoch, 55 | 'config': config} 56 | if config.AMP_OPT_LEVEL != "O0": 57 | save_state['amp'] = amp.state_dict() 58 | 59 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 60 | 
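    # Added note: each epoch is written to config.OUTPUT as ckpt_epoch_{epoch}.pth and bundles
    # the model, optimizer and lr_scheduler states, the best accuracy so far, the epoch index
    # and the config (plus the apex amp state when AMP_OPT_LEVEL != "O0"); load_checkpoint
    # above resumes from exactly this layout, and auto_resume_helper below picks the newest
    # such file in the output directory.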
logger.info(f"{save_path} saving......") 61 | torch.save(save_state, save_path) 62 | logger.info(f"{save_path} saved !!!") 63 | 64 | 65 | def get_grad_norm(parameters, norm_type=2): 66 | if isinstance(parameters, torch.Tensor): 67 | parameters = [parameters] 68 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 69 | norm_type = float(norm_type) 70 | total_norm = 0 71 | for p in parameters: 72 | param_norm = p.grad.data.norm(norm_type) 73 | total_norm += param_norm.item() ** norm_type 74 | total_norm = total_norm ** (1. / norm_type) 75 | return total_norm 76 | 77 | 78 | def auto_resume_helper(output_dir, logger): 79 | checkpoints = os.listdir(output_dir) 80 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 81 | logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") 82 | if len(checkpoints) > 0: 83 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 84 | logger.info(f"The latest checkpoint founded: {latest_checkpoint}") 85 | resume_file = latest_checkpoint 86 | else: 87 | resume_file = None 88 | return resume_file 89 | 90 | 91 | def reduce_tensor(tensor): 92 | rt = tensor.clone() 93 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 94 | rt /= dist.get_world_size() 95 | return rt 96 | 97 | 98 | def load_pretrained(config, model, logger): 99 | logger.info(f">>>>>>>>>> Fine-tuned from {config.PRETRAINED} ..........") 100 | checkpoint = torch.load(config.PRETRAINED, map_location='cpu') 101 | checkpoint_model = checkpoint['model'] 102 | 103 | if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]): 104 | checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')} 105 | logger.info('Detect pre-trained model, remove [encoder.] 
prefix.') 106 | else: 107 | logger.info('Detect non-pre-trained model, pass without doing anything.') 108 | 109 | if config.MODEL.TYPE == 'swin': 110 | logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") 111 | checkpoint = remap_pretrained_keys_swin(model, checkpoint_model, logger) 112 | elif config.MODEL.TYPE == 'vit': 113 | logger.info(f">>>>>>>>>> Remapping pre-trained keys for VIT ..........") 114 | checkpoint = remap_pretrained_keys_vit(model, checkpoint_model, logger) 115 | else: 116 | raise NotImplementedError 117 | 118 | msg = model.load_state_dict(checkpoint_model, strict=False) 119 | logger.info(msg) 120 | 121 | del checkpoint 122 | torch.cuda.empty_cache() 123 | logger.info(f">>>>>>>>>> loaded successfully '{config.PRETRAINED}'") 124 | 125 | 126 | def remap_pretrained_keys_swin(model, checkpoint_model, logger): 127 | state_dict = model.state_dict() 128 | 129 | # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size 130 | all_keys = list(checkpoint_model.keys()) 131 | for key in all_keys: 132 | if "relative_position_bias_table" in key: 133 | relative_position_bias_table_pretrained = checkpoint_model[key] 134 | relative_position_bias_table_current = state_dict[key] 135 | L1, nH1 = relative_position_bias_table_pretrained.size() 136 | L2, nH2 = relative_position_bias_table_current.size() 137 | if nH1 != nH2: 138 | logger.info(f"Error in loading {key}, passing......") 139 | else: 140 | if L1 != L2: 141 | logger.info(f"{key}: Interpolate relative_position_bias_table using geo.") 142 | src_size = int(L1 ** 0.5) 143 | dst_size = int(L2 ** 0.5) 144 | 145 | def geometric_progression(a, r, n): 146 | return a * (1.0 - r ** n) / (1.0 - r) 147 | 148 | left, right = 1.01, 1.5 149 | while right - left > 1e-6: 150 | q = (left + right) / 2.0 151 | gp = geometric_progression(1, q, src_size // 2) 152 | if gp > dst_size // 2: 153 | right = q 154 | else: 155 | left = q 156 | 157 | # if q > 1.090307: 158 | # q = 1.090307 159 | 160 | dis = [] 161 | cur = 1 162 | for i in range(src_size // 2): 163 | dis.append(cur) 164 | cur += q ** (i + 1) 165 | 166 | r_ids = [-_ for _ in reversed(dis)] 167 | 168 | x = r_ids + [0] + dis 169 | y = r_ids + [0] + dis 170 | 171 | t = dst_size // 2.0 172 | dx = np.arange(-t, t + 0.1, 1.0) 173 | dy = np.arange(-t, t + 0.1, 1.0) 174 | 175 | logger.info("Original positions = %s" % str(x)) 176 | logger.info("Target positions = %s" % str(dx)) 177 | 178 | all_rel_pos_bias = [] 179 | 180 | for i in range(nH1): 181 | z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy() 182 | f_cubic = interpolate.interp2d(x, y, z, kind='cubic') 183 | all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to( 184 | relative_position_bias_table_pretrained.device)) 185 | 186 | new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 187 | checkpoint_model[key] = new_rel_pos_bias 188 | 189 | # delete relative_position_index since we always re-init it 190 | relative_position_index_keys = [k for k in checkpoint_model.keys() if "relative_position_index" in k] 191 | for k in relative_position_index_keys: 192 | del checkpoint_model[k] 193 | 194 | # delete relative_coords_table since we always re-init it 195 | relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k] 196 | for k in relative_coords_table_keys: 197 | del checkpoint_model[k] 198 | 199 | # delete attn_mask since we always re-init it 200 | attn_mask_keys = [k for k in checkpoint_model.keys() 
if "attn_mask" in k] 201 | for k in attn_mask_keys: 202 | del checkpoint_model[k] 203 | 204 | return checkpoint_model 205 | 206 | 207 | def remap_pretrained_keys_vit(model, checkpoint_model, logger): 208 | # Duplicate shared rel_pos_bias to each layer 209 | if getattr(model, 'use_rel_pos_bias', False) and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: 210 | logger.info("Expand the shared relative position embedding to each transformer block.") 211 | num_layers = model.get_num_layers() 212 | rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] 213 | for i in range(num_layers): 214 | checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() 215 | checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") 216 | 217 | # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size 218 | all_keys = list(checkpoint_model.keys()) 219 | for key in all_keys: 220 | if "relative_position_index" in key: 221 | checkpoint_model.pop(key) 222 | 223 | if "relative_position_bias_table" in key: 224 | rel_pos_bias = checkpoint_model[key] 225 | src_num_pos, num_attn_heads = rel_pos_bias.size() 226 | dst_num_pos, _ = model.state_dict()[key].size() 227 | dst_patch_shape = model.patch_embed.patch_shape 228 | if dst_patch_shape[0] != dst_patch_shape[1]: 229 | raise NotImplementedError() 230 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) 231 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 232 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 233 | if src_size != dst_size: 234 | logger.info("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size)) 235 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 236 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 237 | 238 | def geometric_progression(a, r, n): 239 | return a * (1.0 - r ** n) / (1.0 - r) 240 | 241 | left, right = 1.01, 1.5 242 | while right - left > 1e-6: 243 | q = (left + right) / 2.0 244 | gp = geometric_progression(1, q, src_size // 2) 245 | if gp > dst_size // 2: 246 | right = q 247 | else: 248 | left = q 249 | 250 | # if q > 1.090307: 251 | # q = 1.090307 252 | 253 | dis = [] 254 | cur = 1 255 | for i in range(src_size // 2): 256 | dis.append(cur) 257 | cur += q ** (i + 1) 258 | 259 | r_ids = [-_ for _ in reversed(dis)] 260 | 261 | x = r_ids + [0] + dis 262 | y = r_ids + [0] + dis 263 | 264 | t = dst_size // 2.0 265 | dx = np.arange(-t, t + 0.1, 1.0) 266 | dy = np.arange(-t, t + 0.1, 1.0) 267 | 268 | logger.info("Original positions = %s" % str(x)) 269 | logger.info("Target positions = %s" % str(dx)) 270 | 271 | all_rel_pos_bias = [] 272 | 273 | for i in range(num_attn_heads): 274 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 275 | f = interpolate.interp2d(x, y, z, kind='cubic') 276 | all_rel_pos_bias.append( 277 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) 278 | 279 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 280 | 281 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 282 | checkpoint_model[key] = new_rel_pos_bias 283 | 284 | return checkpoint_model --------------------------------------------------------------------------------