├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── README.md
├── SECURITY.md
├── SUPPORT.md
├── config.py
├── configs
│   ├── swin_base__100ep
│   │   ├── simmim_finetune__swin_base__img192_window6__100ep.yaml
│   │   ├── simmim_finetune__swin_base__img224_window7__100ep.yaml
│   │   └── simmim_pretrain__swin_base__img192_window6__100ep.yaml
│   ├── swin_base__800ep
│   │   ├── simmim_finetune__swin_base__img224_window7__800ep.yaml
│   │   └── simmim_pretrain__swin_base__img192_window6__800ep.yaml
│   ├── swin_large__800ep
│   │   ├── simmim_finetune__swin_large__img224_window14__800ep.yaml
│   │   └── simmim_pretrain__swin_large__img192_window12__800ep.yaml
│   └── vit_base__800ep
│       ├── simmim_finetune__vit_base__img224__800ep.yaml
│       └── simmim_pretrain__vit_base__img224__800ep.yaml
├── data
│   ├── __init__.py
│   ├── data_finetune.py
│   └── data_simmim.py
├── figures
│   └── teaser.jpg
├── logger.py
├── lr_scheduler.py
├── main_finetune.py
├── main_simmim.py
├── models
│   ├── __init__.py
│   ├── build.py
│   ├── simmim.py
│   ├── swin_transformer.py
│   └── vision_transformer.py
├── optimizer.py
├── requirements.txt
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SimMIM
2 |
3 | By [Zhenda Xie](https://zdaxie.github.io)\*, [Zheng Zhang](https://stupidzz.github.io/)\*, [Yue Cao](http://yue-cao.me)\*, [Yutong Lin](https://github.com/impiga), [Jianmin Bao](https://jianminbao.github.io/), [Zhuliang Yao](https://github.com/Howal), [Qi Dai](https://www.microsoft.com/en-us/research/people/qid/) and [Han Hu](https://ancientmooner.github.io/)\*.
4 |
5 | This repo is the official implementation of ["SimMIM: A Simple Framework for Masked Image Modeling"](https://arxiv.org/abs/2111.09886).
6 |
7 | ## Updates
8 |
9 | ***09/29/2022***
10 |
11 | SimMIM has been merged into the [Swin Transformer repo on GitHub](https://github.com/microsoft/Swin-Transformer).
12 |
13 | ***03/02/2022***
14 |
15 | SimMIM was accepted to CVPR 2022. SimMIM was used in ["Swin Transformer V2"](https://github.com/microsoft/Swin-Transformer) to alleviate the data-hungry problem in large-scale vision model training.
16 |
17 | ***12/09/2021***
18 |
19 | Initial commits:
20 |
21 | 1. Pre-trained and fine-tuned models on ImageNet-1K (`Swin Base`, `Swin Large`, and `ViT Base`) are provided.
22 | 2. Code for ImageNet-1K pre-training and fine-tuning is provided.
23 |
24 | ## Introduction
25 |
26 | **SimMIM**, initially described in [this arXiv paper](https://arxiv.org/abs/2111.09886), is a
27 | simple framework for masked image modeling. Through a systematic study, we find that simple designs for each component yield very strong representation learning performance: 1) random masking of the input image with a moderately large masked patch size (e.g., 32) makes a strong pretext task; 2) predicting raw pixels of RGB values by direct regression performs no worse than patch classification approaches with complex designs; 3) the prediction head can be as light as a linear layer, with no worse performance than heavier ones.
28 |
29 |
30 | ![SimMIM teaser](figures/teaser.jpg)
31 |
32 |
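The masking strategy above is easy to reproduce. Below is a minimal NumPy sketch of the random patch masking (the same logic as `MaskGenerator` in `data/data_simmim.py`), using the default pre-training settings of a 192x192 input, 32x32 mask patches, and a 0.6 mask ratio:

```python
import numpy as np

def random_patch_mask(input_size=192, mask_patch_size=32, mask_ratio=0.6):
    """Binary mask over mask patches: 1 = masked, 0 = visible."""
    grid = input_size // mask_patch_size                 # 6 patches per side
    token_count = grid ** 2                              # 36 candidate patches
    mask_count = int(np.ceil(token_count * mask_ratio))  # 22 patches masked
    mask = np.zeros(token_count, dtype=int)
    mask[np.random.permutation(token_count)[:mask_count]] = 1
    return mask.reshape(grid, grid)

print(random_patch_mask())  # a 6x6 0/1 grid with ~60% of entries masked
```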
33 | ## Main Results on ImageNet
34 |
35 | ### Swin Transformer
36 |
37 | **ImageNet-1K Pre-trained and Fine-tuned Models**
38 |
39 | | name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model |
40 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
41 | | Swin-Base | 100 | 192x192 | 192x192 | 82.8 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1RsgHfjB4B1ZYblXEQVT-FPX3WSvBrxcs/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml) |
42 | | Swin-Base | 100 | 192x192 | 224x224 | 83.5 | [google](https://drive.google.com/file/d/1Wcbr66JL26FF30Kip9fZa_0lXrDAKP-d/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml) | [google](https://drive.google.com/file/d/1mb43BkW56F5smwiX-g7QUUD7f1Rftq8u/view?usp=sharing)/[config](configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml) |
43 | | Swin-Base | 800 | 192x192 | 224x224 | 84.0 | [google](https://drive.google.com/file/d/15zENvGjHlM71uKQ3d2FbljWPubtrPtjl/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml) | [google](https://drive.google.com/file/d/1xEKyfMTsdh6TfnYhk5vbw0Yz7a-viZ0w/view?usp=sharing)/[config](configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml) |
44 | | Swin-Large | 800 | 192x192 | 224x224 | 85.4 | [google](https://drive.google.com/file/d/1qDxrTl2YUDB0505_4QrU5LU2R1kKmcBP/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml) | [google](https://drive.google.com/file/d/1mf0ZpXttEvFsH87Www4oQ-t8Kwr0x485/view?usp=sharing)/[config](configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml) |
45 | | SwinV2-Huge | 800 | 192x192 | 224x224 | 85.7 | / | / |
46 | | SwinV2-Huge | 800 | 192x192 | 512x512 | 87.1 | / | / |
47 |
48 | ### Vision Transformer
49 |
50 | **ImageNet-1K Pre-trained and Fine-tuned Models**
51 |
52 | | name | pre-train epochs | pre-train resolution | fine-tune resolution | acc@1 | pre-trained model | fine-tuned model |
53 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
54 | | ViT-Base | 800 | 224x224 | 224x224 | 83.8 | [google](https://drive.google.com/file/d/1dJn6GYkwMIcoP3zqOEyW1_iQfpBi8UOw/view?usp=sharing)/[config](configs/vit_base__800ep/simmim_pretrain__vit_base__img224__800ep.yaml) | [google](https://drive.google.com/file/d/1fKgDYd0tRgyHyTnyB1CleYxjo0Gn5tEB/view?usp=sharing)/[config](configs/vit_base__800ep/simmim_finetune__vit_base__img224__800ep.yaml) |
55 |
56 | ## Citing SimMIM
57 |
58 | ```
59 | @inproceedings{xie2021simmim,
60 | title={SimMIM: A Simple Framework for Masked Image Modeling},
61 | author={Xie, Zhenda and Zhang, Zheng and Cao, Yue and Lin, Yutong and Bao, Jianmin and Yao, Zhuliang and Dai, Qi and Hu, Han},
62 | booktitle={International Conference on Computer Vision and Pattern Recognition (CVPR)},
63 | year={2022}
64 | }
65 | ```
66 |
67 | ## Getting Started
68 |
69 | ### Installation
70 |
71 | - Install `CUDA 11.3` with `cuDNN 8` following the official installation guide of [CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [cuDNN](https://developer.nvidia.com/rdp/cudnn-archive).
72 |
73 | - Setup conda environment:
74 | ```bash
75 | # Create environment
76 | conda create -n SimMIM python=3.8 -y
77 | conda activate SimMIM
78 |
79 | # Install requirements
80 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -y
81 |
82 | # Install apex
83 | git clone https://github.com/NVIDIA/apex
84 | cd apex
85 | pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
86 | cd ..
87 |
88 | # Clone SimMIM
89 | git clone https://github.com/microsoft/SimMIM
90 | cd SimMIM
91 |
92 | # Install other requirements
93 | pip install -r requirements.txt
94 | ```
95 |
96 | ### Evaluating provided models
97 |
98 | To evaluate a provided model on ImageNet validation set, run:
99 | ```bash
100 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_finetune.py \
101 | --eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path>
102 | ```
103 |
104 | For example, to evaluate the `Swin Base` model on a single GPU, run:
105 | ```bash
106 | python -m torch.distributed.launch --nproc_per_node 1 main_finetune.py \
--eval --cfg configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path>
108 | ```
109 |
110 | ### Pre-training with SimMIM
111 | To pre-train models with `SimMIM`, run:
112 | ```bash
113 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_simmim.py \
114 | --cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
115 | ```
116 |
117 | For example, to pre-train `Swin Base` for 800 epochs on one DGX-2 server, run:
118 | ```bash
119 | python -m torch.distributed.launch --nproc_per_node 16 main_simmim.py \
--cfg configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>]
121 | ```
122 |
123 | ### Fine-tuning pre-trained models
124 | To fine-tune models pre-trained by `SimMIM`, run:
125 | ```bash
126 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_finetune.py \
127 | --cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]
128 | ```
129 |
130 | For example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run:
131 | ```bash
132 | python -m torch.distributed.launch --nproc_per_node 16 main_finetune.py \
--cfg configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>]
134 | ```
135 |
136 | ## Contributing
137 |
138 | This project welcomes contributions and suggestions. Most contributions require you to agree to a
139 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
140 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
141 |
142 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide
143 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
144 | provided by the bot. You will only need to do this once across all repos using our CLA.
145 |
146 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
147 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
148 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
149 |
150 | ## Trademarks
151 |
152 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
153 | trademarks or logos is subject to and must follow
154 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
155 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
156 | Any use of third-party trademarks or logos is subject to those third parties' policies.
157 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, including all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [others](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by Zhenda Xie
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import yaml
11 | from yacs.config import CfgNode as CN
12 |
13 | _C = CN()
14 |
15 | # Base config files
16 | _C.BASE = ['']
17 |
18 | # -----------------------------------------------------------------------------
19 | # Data settings
20 | # -----------------------------------------------------------------------------
21 | _C.DATA = CN()
22 | # Batch size for a single GPU, could be overwritten by command line argument
23 | _C.DATA.BATCH_SIZE = 128
24 | # Path to dataset, could be overwritten by command line argument
25 | _C.DATA.DATA_PATH = ''
26 | # Dataset name
27 | _C.DATA.DATASET = 'imagenet'
28 | # Input image size
29 | _C.DATA.IMG_SIZE = 224
30 | # Interpolation to resize image (random, bilinear, bicubic)
31 | _C.DATA.INTERPOLATION = 'bicubic'
32 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
33 | _C.DATA.PIN_MEMORY = True
34 | # Number of data loading threads
35 | _C.DATA.NUM_WORKERS = 8
36 |
37 | # [SimMIM] Mask patch size for MaskGenerator
38 | _C.DATA.MASK_PATCH_SIZE = 32
39 | # [SimMIM] Mask ratio for MaskGenerator
40 | _C.DATA.MASK_RATIO = 0.6
41 |
42 | # -----------------------------------------------------------------------------
43 | # Model settings
44 | # -----------------------------------------------------------------------------
45 | _C.MODEL = CN()
46 | # Model type
47 | _C.MODEL.TYPE = 'swin'
48 | # Model name
49 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
50 | # Checkpoint to resume, could be overwritten by command line argument
51 | _C.MODEL.RESUME = ''
52 | # Number of classes, overwritten in data preparation
53 | _C.MODEL.NUM_CLASSES = 1000
54 | # Dropout rate
55 | _C.MODEL.DROP_RATE = 0.0
56 | # Drop path rate
57 | _C.MODEL.DROP_PATH_RATE = 0.1
58 | # Label Smoothing
59 | _C.MODEL.LABEL_SMOOTHING = 0.1
60 |
61 | # Swin Transformer parameters
62 | _C.MODEL.SWIN = CN()
63 | _C.MODEL.SWIN.PATCH_SIZE = 4
64 | _C.MODEL.SWIN.IN_CHANS = 3
65 | _C.MODEL.SWIN.EMBED_DIM = 96
66 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
67 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
68 | _C.MODEL.SWIN.WINDOW_SIZE = 7
69 | _C.MODEL.SWIN.MLP_RATIO = 4.
70 | _C.MODEL.SWIN.QKV_BIAS = True
71 | _C.MODEL.SWIN.QK_SCALE = None
72 | _C.MODEL.SWIN.APE = False
73 | _C.MODEL.SWIN.PATCH_NORM = True
74 |
75 | # Vision Transformer parameters
76 | _C.MODEL.VIT = CN()
77 | _C.MODEL.VIT.PATCH_SIZE = 16
78 | _C.MODEL.VIT.IN_CHANS = 3
79 | _C.MODEL.VIT.EMBED_DIM = 768
80 | _C.MODEL.VIT.DEPTH = 12
81 | _C.MODEL.VIT.NUM_HEADS = 12
82 | _C.MODEL.VIT.MLP_RATIO = 4
83 | _C.MODEL.VIT.QKV_BIAS = True
84 | _C.MODEL.VIT.INIT_VALUES = 0.1
85 | _C.MODEL.VIT.USE_APE = False
86 | _C.MODEL.VIT.USE_RPB = False
87 | _C.MODEL.VIT.USE_SHARED_RPB = True
88 | _C.MODEL.VIT.USE_MEAN_POOLING = False
89 |
90 | # -----------------------------------------------------------------------------
91 | # Training settings
92 | # -----------------------------------------------------------------------------
93 | _C.TRAIN = CN()
94 | _C.TRAIN.START_EPOCH = 0
95 | _C.TRAIN.EPOCHS = 300
96 | _C.TRAIN.WARMUP_EPOCHS = 20
97 | _C.TRAIN.WEIGHT_DECAY = 0.05
98 | _C.TRAIN.BASE_LR = 5e-4
99 | _C.TRAIN.WARMUP_LR = 5e-7
100 | _C.TRAIN.MIN_LR = 5e-6
101 | # Clip gradient norm
102 | _C.TRAIN.CLIP_GRAD = 5.0
103 | # Auto resume from latest checkpoint
104 | _C.TRAIN.AUTO_RESUME = True
105 | # Gradient accumulation steps
106 | # could be overwritten by command line argument
107 | _C.TRAIN.ACCUMULATION_STEPS = 0
108 | # Whether to use gradient checkpointing to save memory
109 | # could be overwritten by command line argument
110 | _C.TRAIN.USE_CHECKPOINT = False
111 |
112 | # LR scheduler
113 | _C.TRAIN.LR_SCHEDULER = CN()
114 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
115 | # Epoch interval to decay LR, used in StepLRScheduler
116 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
117 | # LR decay rate, used in StepLRScheduler
118 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
119 | # Gamma / Multi steps value, used in MultiStepLRScheduler
120 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
121 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
122 |
123 | # Optimizer
124 | _C.TRAIN.OPTIMIZER = CN()
125 | _C.TRAIN.OPTIMIZER.NAME = 'adamw'
126 | # Optimizer Epsilon
127 | _C.TRAIN.OPTIMIZER.EPS = 1e-8
128 | # Optimizer Betas
129 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
130 | # SGD momentum
131 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
132 |
133 | # [SimMIM] Layer decay for fine-tuning
134 | _C.TRAIN.LAYER_DECAY = 1.0
135 |
136 | # -----------------------------------------------------------------------------
137 | # Augmentation settings
138 | # -----------------------------------------------------------------------------
139 | _C.AUG = CN()
140 | # Color jitter factor
141 | _C.AUG.COLOR_JITTER = 0.4
142 | # Use AutoAugment policy. "v0" or "original"
143 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
144 | # Random erase prob
145 | _C.AUG.REPROB = 0.25
146 | # Random erase mode
147 | _C.AUG.REMODE = 'pixel'
148 | # Random erase count
149 | _C.AUG.RECOUNT = 1
150 | # Mixup alpha, mixup enabled if > 0
151 | _C.AUG.MIXUP = 0.8
152 | # Cutmix alpha, cutmix enabled if > 0
153 | _C.AUG.CUTMIX = 1.0
154 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set
155 | _C.AUG.CUTMIX_MINMAX = None
156 | # Probability of performing mixup or cutmix when either/both is enabled
157 | _C.AUG.MIXUP_PROB = 1.0
158 | # Probability of switching to cutmix when both mixup and cutmix enabled
159 | _C.AUG.MIXUP_SWITCH_PROB = 0.5
160 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
161 | _C.AUG.MIXUP_MODE = 'batch'
162 |
163 | # -----------------------------------------------------------------------------
164 | # Testing settings
165 | # -----------------------------------------------------------------------------
166 | _C.TEST = CN()
167 | # Whether to use center crop when testing
168 | _C.TEST.CROP = True
169 |
170 | # -----------------------------------------------------------------------------
171 | # Misc
172 | # -----------------------------------------------------------------------------
173 | # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
174 | # overwritten by command line argument
175 | _C.AMP_OPT_LEVEL = ''
176 | # Path to output folder, overwritten by command line argument
177 | _C.OUTPUT = ''
178 | # Tag of experiment, overwritten by command line argument
179 | _C.TAG = 'default'
180 | # Frequency to save checkpoint
181 | _C.SAVE_FREQ = 1
183 | # Frequency to log info
183 | _C.PRINT_FREQ = 10
184 | # Fixed random seed
185 | _C.SEED = 0
186 | # Perform evaluation only, overwritten by command line argument
187 | _C.EVAL_MODE = False
188 | # Test throughput only, overwritten by command line argument
189 | _C.THROUGHPUT_MODE = False
190 | # local rank for DistributedDataParallel, given by command line argument
191 | _C.LOCAL_RANK = 0
192 |
193 | # [SimMIM] path to pre-trained model
194 | _C.PRETRAINED = ''
195 |
196 |
197 | def _update_config_from_file(config, cfg_file):
198 | config.defrost()
199 | with open(cfg_file, 'r') as f:
200 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
201 |
202 | for cfg in yaml_cfg.setdefault('BASE', ['']):
203 | if cfg:
204 | _update_config_from_file(
205 | config, os.path.join(os.path.dirname(cfg_file), cfg)
206 | )
207 | print('=> merge config from {}'.format(cfg_file))
208 | config.merge_from_file(cfg_file)
209 | config.freeze()
210 |
211 |
212 | def update_config(config, args):
213 | _update_config_from_file(config, args.cfg)
214 |
215 | config.defrost()
216 | if args.opts:
217 | config.merge_from_list(args.opts)
218 |
219 |     def _check_args(name):
220 |         if hasattr(args, name) and getattr(args, name):
221 |             return True
222 |         return False
223 |
224 | # merge from specific arguments
225 | if _check_args('batch_size'):
226 | config.DATA.BATCH_SIZE = args.batch_size
227 | if _check_args('data_path'):
228 | config.DATA.DATA_PATH = args.data_path
229 | if _check_args('resume'):
230 | config.MODEL.RESUME = args.resume
231 | if _check_args('pretrained'):
232 | config.PRETRAINED = args.pretrained
233 | if _check_args('accumulation_steps'):
234 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
235 | if _check_args('use_checkpoint'):
236 | config.TRAIN.USE_CHECKPOINT = True
237 | if _check_args('amp_opt_level'):
238 | config.AMP_OPT_LEVEL = args.amp_opt_level
239 | if _check_args('output'):
240 | config.OUTPUT = args.output
241 | if _check_args('tag'):
242 | config.TAG = args.tag
243 | if _check_args('eval'):
244 | config.EVAL_MODE = True
245 | if _check_args('throughput'):
246 | config.THROUGHPUT_MODE = True
247 |
248 | # set local rank for distributed training
249 | config.LOCAL_RANK = args.local_rank
250 |
251 | # output folder
252 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)
253 |
254 | config.freeze()
255 |
256 |
257 | def get_config(args):
258 | """Get a yacs CfgNode object with default values."""
259 | # Return a clone so that the defaults will not be altered
260 | # This is for the "local variable" use pattern
261 | config = _C.clone()
262 | update_config(config, args)
263 |
264 | return config
265 |
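
# A minimal usage sketch (the training entry points build fuller parsers in
# main_simmim.py / main_finetune.py): update_config() always reads args.cfg,
# args.opts, and args.local_rank, so a namespace needs at least those three.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser('config sketch')
    parser.add_argument('--cfg', type=str, required=True, help='path to a YAML config file')
    parser.add_argument('--opts', nargs='+', default=None, help="'KEY VALUE' override pairs")
    parser.add_argument('--local_rank', type=int, default=0)
    cfg = get_config(parser.parse_args())
    print(cfg.MODEL.TYPE, cfg.DATA.IMG_SIZE, cfg.TRAIN.EPOCHS)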
--------------------------------------------------------------------------------
/configs/swin_base__100ep/simmim_finetune__swin_base__img192_window6__100ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_finetune
4 | DROP_PATH_RATE: 0.1
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 6
10 | DATA:
11 | IMG_SIZE: 192
12 | TRAIN:
13 | EPOCHS: 100
14 | WARMUP_EPOCHS: 20
15 | BASE_LR: 1.25e-3
16 | WARMUP_LR: 2.5e-7
17 | MIN_LR: 2.5e-7
18 | WEIGHT_DECAY: 0.05
19 | LAYER_DECAY: 0.9
20 | PRINT_FREQ: 100
21 | SAVE_FREQ: 5
22 | TAG: simmim_finetune__swin_base__img192_window6__100ep
--------------------------------------------------------------------------------
/configs/swin_base__100ep/simmim_finetune__swin_base__img224_window7__100ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_finetune
4 | DROP_PATH_RATE: 0.1
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | IMG_SIZE: 224
12 | TRAIN:
13 | EPOCHS: 100
14 | WARMUP_EPOCHS: 20
15 | BASE_LR: 1.25e-3
16 | WARMUP_LR: 2.5e-7
17 | MIN_LR: 2.5e-7
18 | WEIGHT_DECAY: 0.05
19 | LAYER_DECAY: 0.9
20 | PRINT_FREQ: 100
21 | SAVE_FREQ: 5
22 | TAG: simmim_finetune__swin_base__img224_window7__100ep
--------------------------------------------------------------------------------
/configs/swin_base__100ep/simmim_pretrain__swin_base__img192_window6__100ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_pretrain
4 | DROP_PATH_RATE: 0.0
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 6
10 | DATA:
11 | IMG_SIZE: 192
12 | MASK_PATCH_SIZE: 32
13 | MASK_RATIO: 0.6
14 | TRAIN:
15 | EPOCHS: 100
16 | WARMUP_EPOCHS: 10
17 | BASE_LR: 2e-4
18 | WARMUP_LR: 1e-6
19 | MIN_LR: 1e-5
20 | WEIGHT_DECAY: 0.05
21 | PRINT_FREQ: 100
22 | SAVE_FREQ: 5
23 | TAG: simmim_pretrain__swin_base__img192_window6__100ep
--------------------------------------------------------------------------------
/configs/swin_base__800ep/simmim_finetune__swin_base__img224_window7__800ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_finetune
4 | DROP_PATH_RATE: 0.1
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 7
10 | DATA:
11 | IMG_SIZE: 224
12 | TRAIN:
13 | EPOCHS: 100
14 | WARMUP_EPOCHS: 20
15 | BASE_LR: 1.25e-3
16 | WARMUP_LR: 2.5e-7
17 | MIN_LR: 2.5e-7
18 | WEIGHT_DECAY: 0.05
19 | LAYER_DECAY: 0.8
20 | PRINT_FREQ: 100
21 | SAVE_FREQ: 5
22 | TAG: simmim_finetune__swin_base__img224_window7__800ep
--------------------------------------------------------------------------------
/configs/swin_base__800ep/simmim_pretrain__swin_base__img192_window6__800ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_pretrain
4 | DROP_PATH_RATE: 0.0
5 | SWIN:
6 | EMBED_DIM: 128
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 4, 8, 16, 32 ]
9 | WINDOW_SIZE: 6
10 | DATA:
11 | IMG_SIZE: 192
12 | MASK_PATCH_SIZE: 32
13 | MASK_RATIO: 0.6
14 | TRAIN:
15 | EPOCHS: 800
16 | WARMUP_EPOCHS: 10
17 | BASE_LR: 1e-4
18 | WARMUP_LR: 5e-7
19 | WEIGHT_DECAY: 0.05
20 | LR_SCHEDULER:
21 | NAME: 'multistep'
22 | GAMMA: 0.1
23 | MULTISTEPS: [700,]
24 | PRINT_FREQ: 100
25 | SAVE_FREQ: 5
26 | TAG: simmim_pretrain__swin_base__img192_window6__800ep
--------------------------------------------------------------------------------
/configs/swin_large__800ep/simmim_finetune__swin_large__img224_window14__800ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_finetune
4 | DROP_PATH_RATE: 0.2
5 | SWIN:
6 | EMBED_DIM: 192
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 6, 12, 24, 48 ]
9 | WINDOW_SIZE: 14
10 | DATA:
11 | IMG_SIZE: 224
12 | TRAIN:
13 | EPOCHS: 100
14 | WARMUP_EPOCHS: 20
15 | BASE_LR: 1.25e-3
16 | WARMUP_LR: 2.5e-7
17 | MIN_LR: 2.5e-7
18 | WEIGHT_DECAY: 0.05
19 | LAYER_DECAY: 0.7
20 | PRINT_FREQ: 100
21 | SAVE_FREQ: 5
22 | TAG: simmim_finetune__swin_large__img224_window14__800ep
--------------------------------------------------------------------------------
/configs/swin_large__800ep/simmim_pretrain__swin_large__img192_window12__800ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: swin
3 | NAME: simmim_pretrain
4 | DROP_PATH_RATE: 0.0
5 | SWIN:
6 | EMBED_DIM: 192
7 | DEPTHS: [ 2, 2, 18, 2 ]
8 | NUM_HEADS: [ 6, 12, 24, 48 ]
9 | WINDOW_SIZE: 12
10 | DATA:
11 | IMG_SIZE: 192
12 | MASK_PATCH_SIZE: 32
13 | MASK_RATIO: 0.6
14 | TRAIN:
15 | EPOCHS: 800
16 | WARMUP_EPOCHS: 10
17 | BASE_LR: 1e-4
18 | WARMUP_LR: 5e-7
19 | WEIGHT_DECAY: 0.05
20 | LR_SCHEDULER:
21 | NAME: 'multistep'
22 | GAMMA: 0.1
23 | MULTISTEPS: [700,]
24 | PRINT_FREQ: 100
25 | SAVE_FREQ: 5
26 | TAG: simmim_pretrain__swin_large__img192_window12__800ep
--------------------------------------------------------------------------------
/configs/vit_base__800ep/simmim_finetune__vit_base__img224__800ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: simmim_finetune
4 | DROP_PATH_RATE: 0.1
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: True
11 | USE_SHARED_RPB: False
12 | USE_MEAN_POOLING: True
13 | DATA:
14 | IMG_SIZE: 224
15 | TRAIN:
16 | EPOCHS: 100
17 | WARMUP_EPOCHS: 20
18 | BASE_LR: 1.25e-3
19 | WARMUP_LR: 2.5e-7
20 | MIN_LR: 2.5e-7
21 | WEIGHT_DECAY: 0.05
22 | LAYER_DECAY: 0.65
23 | PRINT_FREQ: 100
24 | SAVE_FREQ: 5
25 | TAG: simmim_finetune__vit_base__img224__800ep
26 |
--------------------------------------------------------------------------------
/configs/vit_base__800ep/simmim_pretrain__vit_base__img224__800ep.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | TYPE: vit
3 | NAME: simmim_pretrain
4 | DROP_PATH_RATE: 0.1
5 | VIT:
6 | EMBED_DIM: 768
7 | DEPTH: 12
8 | NUM_HEADS: 12
9 | USE_APE: False
10 | USE_RPB: False
11 | USE_SHARED_RPB: True
12 | USE_MEAN_POOLING: False
13 | DATA:
14 | IMG_SIZE: 224
15 | MASK_PATCH_SIZE: 32
16 | MASK_RATIO: 0.6
17 | TRAIN:
18 | EPOCHS: 800
19 | WARMUP_EPOCHS: 10
20 | BASE_LR: 1e-4
21 | WARMUP_LR: 5e-7
22 | WEIGHT_DECAY: 0.05
23 | LR_SCHEDULER:
24 | NAME: 'multistep'
25 | GAMMA: 0.1
26 | MULTISTEPS: [700,]
27 | PRINT_FREQ: 100
28 | SAVE_FREQ: 5
29 | TAG: simmim_pretrain__vit_base__img224__800ep
30 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_simmim import build_loader_simmim
2 | from .data_finetune import build_loader_finetune
3 |
4 | def build_loader(config, logger, is_pretrain):
5 | if is_pretrain:
6 | return build_loader_simmim(config, logger)
7 | else:
8 | return build_loader_finetune(config, logger)
--------------------------------------------------------------------------------
/data/data_finetune.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Zhenda Xie
6 | # --------------------------------------------------------
7 |
8 | import os
9 | import torch.distributed as dist
10 | from torch.utils.data import DataLoader, DistributedSampler
11 | from torchvision import datasets, transforms
12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
13 | from timm.data import Mixup
14 | from timm.data import create_transform
15 | from timm.data.transforms import _pil_interp
16 |
17 |
18 | def build_loader_finetune(config, logger):
19 | config.defrost()
20 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger)
21 | config.freeze()
22 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger)
23 | logger.info(f"Build dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}")
24 |
25 | num_tasks = dist.get_world_size()
26 | global_rank = dist.get_rank()
27 | sampler_train = DistributedSampler(
28 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
29 | )
30 | sampler_val = DistributedSampler(
31 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
32 | )
33 |
34 | data_loader_train = DataLoader(
35 | dataset_train, sampler=sampler_train,
36 | batch_size=config.DATA.BATCH_SIZE,
37 | num_workers=config.DATA.NUM_WORKERS,
38 | pin_memory=config.DATA.PIN_MEMORY,
39 | drop_last=True,
40 | )
41 |
42 | data_loader_val = DataLoader(
43 | dataset_val, sampler=sampler_val,
44 | batch_size=config.DATA.BATCH_SIZE,
45 | num_workers=config.DATA.NUM_WORKERS,
46 | pin_memory=config.DATA.PIN_MEMORY,
47 | drop_last=False,
48 | )
49 |
50 | # setup mixup / cutmix
51 | mixup_fn = None
52 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. or config.AUG.CUTMIX_MINMAX is not None
53 | if mixup_active:
54 | mixup_fn = Mixup(
55 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX,
56 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE,
57 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES)
58 |
59 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
60 |
61 |
62 | def build_dataset(is_train, config, logger):
63 | transform = build_transform(is_train, config)
64 | logger.info(f'Fine-tune data transform, is_train={is_train}:\n{transform}')
65 |
66 | if config.DATA.DATASET == 'imagenet':
67 | prefix = 'train' if is_train else 'val'
68 | root = os.path.join(config.DATA.DATA_PATH, prefix)
69 | dataset = datasets.ImageFolder(root, transform=transform)
70 | nb_classes = 1000
71 | else:
72 |         raise NotImplementedError("We only support ImageNet now.")
73 |
74 | return dataset, nb_classes
75 |
76 |
77 | def build_transform(is_train, config):
78 | resize_im = config.DATA.IMG_SIZE > 32
79 | if is_train:
80 | # this should always dispatch to transforms_imagenet_train
81 | transform = create_transform(
82 | input_size=config.DATA.IMG_SIZE,
83 | is_training=True,
84 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
85 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
86 | re_prob=config.AUG.REPROB,
87 | re_mode=config.AUG.REMODE,
88 | re_count=config.AUG.RECOUNT,
89 | interpolation=config.DATA.INTERPOLATION,
90 | )
91 | if not resize_im:
92 | # replace RandomResizedCropAndInterpolation with
93 | # RandomCrop
94 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
95 | return transform
96 |
97 | t = []
98 | if resize_im:
99 | if config.TEST.CROP:
100 | size = int((256 / 224) * config.DATA.IMG_SIZE)
101 | t.append(
102 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)),
103 | # to maintain same ratio w.r.t. 224 images
104 | )
105 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE))
106 | else:
107 | t.append(
108 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
109 | interpolation=_pil_interp(config.DATA.INTERPOLATION))
110 | )
111 |
112 | t.append(transforms.ToTensor())
113 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
114 | return transforms.Compose(t)
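

# A quick check of the eval-time resize math (a sketch): with TEST.CROP enabled,
# the shorter edge is resized by 256/224 before center-cropping, keeping the same
# resize/crop ratio used for 224px models at other image sizes.
if __name__ == '__main__':
    for img_size in (192, 224):
        print(img_size, '->', int((256 / 224) * img_size))  # 192 -> 219, 224 -> 256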
--------------------------------------------------------------------------------
/data/data_simmim.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Zhenda Xie
6 | # --------------------------------------------------------
7 |
8 | import math
9 | import random
10 | import numpy as np
11 |
12 | import torch
13 | import torch.distributed as dist
14 | import torchvision.transforms as T
15 | from torch.utils.data import DataLoader, DistributedSampler
16 | from torch.utils.data._utils.collate import default_collate
17 | from torchvision.datasets import ImageFolder
18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
19 |
20 |
21 | class MaskGenerator:
22 | def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6):
23 | self.input_size = input_size
24 | self.mask_patch_size = mask_patch_size
25 | self.model_patch_size = model_patch_size
26 | self.mask_ratio = mask_ratio
27 |
28 | assert self.input_size % self.mask_patch_size == 0
29 | assert self.mask_patch_size % self.model_patch_size == 0
30 |
31 | self.rand_size = self.input_size // self.mask_patch_size
32 | self.scale = self.mask_patch_size // self.model_patch_size
33 |
34 | self.token_count = self.rand_size ** 2
35 | self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))
36 |
37 | def __call__(self):
38 | mask_idx = np.random.permutation(self.token_count)[:self.mask_count]
39 | mask = np.zeros(self.token_count, dtype=int)
40 | mask[mask_idx] = 1
41 |
42 | mask = mask.reshape((self.rand_size, self.rand_size))
43 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1)
44 |
45 | return mask
46 |
47 |
48 | class SimMIMTransform:
49 | def __init__(self, config):
50 | self.transform_img = T.Compose([
51 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
52 | T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)),
53 | T.RandomHorizontalFlip(),
54 | T.ToTensor(),
55 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)),
56 | ])
57 |
58 | if config.MODEL.TYPE == 'swin':
59 | model_patch_size=config.MODEL.SWIN.PATCH_SIZE
60 | elif config.MODEL.TYPE == 'vit':
61 | model_patch_size=config.MODEL.VIT.PATCH_SIZE
62 | else:
63 | raise NotImplementedError
64 |
65 | self.mask_generator = MaskGenerator(
66 | input_size=config.DATA.IMG_SIZE,
67 | mask_patch_size=config.DATA.MASK_PATCH_SIZE,
68 | model_patch_size=model_patch_size,
69 | mask_ratio=config.DATA.MASK_RATIO,
70 | )
71 |
72 | def __call__(self, img):
73 | img = self.transform_img(img)
74 | mask = self.mask_generator()
75 |
76 | return img, mask
77 |
78 |
79 | def collate_fn(batch):
80 | if not isinstance(batch[0][0], tuple):
81 | return default_collate(batch)
82 | else:
83 | batch_num = len(batch)
84 | ret = []
85 | for item_idx in range(len(batch[0][0])):
86 | if batch[0][0][item_idx] is None:
87 | ret.append(None)
88 | else:
89 | ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
90 | ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
91 | return ret
92 |
93 |
94 | def build_loader_simmim(config, logger):
95 | transform = SimMIMTransform(config)
96 | logger.info(f'Pre-train data transform:\n{transform}')
97 |
98 | dataset = ImageFolder(config.DATA.DATA_PATH, transform)
99 | logger.info(f'Build dataset: train images = {len(dataset)}')
100 |
101 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
102 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
103 |
104 | return dataloader
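

# A quick shape sanity check for MaskGenerator (a sketch; the values mirror the
# Swin-Base pre-training config: 192px input, 32px mask patches, 4px model
# patches, 0.6 mask ratio):
if __name__ == '__main__':
    gen = MaskGenerator(input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6)
    mask = gen()
    # 192/32 = 6 mask patches per side, each repeated 32/4 = 8x -> a 48x48 token grid
    print(mask.shape, mask.mean())  # (48, 48), ~0.61 (22 of 36 patches masked)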
--------------------------------------------------------------------------------
/figures/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/SimMIM/d3e29bcac950b83edc34ca33fe4404f38309052c/figures/teaser.jpg
--------------------------------------------------------------------------------
/logger.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by Zhenda Xie
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import sys
11 | import logging
12 | import functools
13 | from termcolor import colored
14 |
15 |
16 | @functools.lru_cache()
17 | def create_logger(output_dir, dist_rank=0, name=''):
18 | # create logger
19 | logger = logging.getLogger(name)
20 | logger.setLevel(logging.DEBUG)
21 | logger.propagate = False
22 |
23 | # create formatter
24 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
25 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \
26 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s'
27 |
28 | # create console handlers for master process
29 | if dist_rank == 0:
30 | console_handler = logging.StreamHandler(sys.stdout)
31 | console_handler.setLevel(logging.DEBUG)
32 | console_handler.setFormatter(
33 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S'))
34 | logger.addHandler(console_handler)
35 |
36 | # create file handlers
37 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a')
38 | file_handler.setLevel(logging.DEBUG)
39 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S'))
40 | logger.addHandler(file_handler)
41 |
42 | return logger
43 |
--------------------------------------------------------------------------------
/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by Zhenda Xie
7 | # --------------------------------------------------------
8 |
9 | from collections import Counter
10 | from bisect import bisect_right
11 |
12 | import torch
13 | from timm.scheduler.cosine_lr import CosineLRScheduler
14 | from timm.scheduler.step_lr import StepLRScheduler
15 | from timm.scheduler.scheduler import Scheduler
16 |
17 |
18 | def build_scheduler(config, optimizer, n_iter_per_epoch):
19 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch)
20 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch)
21 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch)
22 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS]
23 |
24 | lr_scheduler = None
25 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine':
26 | lr_scheduler = CosineLRScheduler(
27 | optimizer,
28 | t_initial=num_steps,
29 | t_mul=1.,
30 | lr_min=config.TRAIN.MIN_LR,
31 | warmup_lr_init=config.TRAIN.WARMUP_LR,
32 | warmup_t=warmup_steps,
33 | cycle_limit=1,
34 | t_in_epochs=False,
35 | )
36 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear':
37 | lr_scheduler = LinearLRScheduler(
38 | optimizer,
39 | t_initial=num_steps,
40 | lr_min_rate=0.01,
41 | warmup_lr_init=config.TRAIN.WARMUP_LR,
42 | warmup_t=warmup_steps,
43 | t_in_epochs=False,
44 | )
45 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step':
46 | lr_scheduler = StepLRScheduler(
47 | optimizer,
48 | decay_t=decay_steps,
49 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE,
50 | warmup_lr_init=config.TRAIN.WARMUP_LR,
51 | warmup_t=warmup_steps,
52 | t_in_epochs=False,
53 | )
54 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep':
55 | lr_scheduler = MultiStepLRScheduler(
56 | optimizer,
57 | milestones=multi_steps,
58 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA,
59 | warmup_lr_init=config.TRAIN.WARMUP_LR,
60 | warmup_t=warmup_steps,
61 | t_in_epochs=False,
62 | )
63 |
64 | return lr_scheduler
65 |
66 |
67 | class LinearLRScheduler(Scheduler):
68 | def __init__(self,
69 | optimizer: torch.optim.Optimizer,
70 | t_initial: int,
71 | lr_min_rate: float,
72 | warmup_t=0,
73 | warmup_lr_init=0.,
74 | t_in_epochs=True,
75 | noise_range_t=None,
76 | noise_pct=0.67,
77 | noise_std=1.0,
78 | noise_seed=42,
79 | initialize=True,
80 | ) -> None:
81 | super().__init__(
82 | optimizer, param_group_field="lr",
83 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
84 | initialize=initialize)
85 |
86 | self.t_initial = t_initial
87 | self.lr_min_rate = lr_min_rate
88 | self.warmup_t = warmup_t
89 | self.warmup_lr_init = warmup_lr_init
90 | self.t_in_epochs = t_in_epochs
91 | if self.warmup_t:
92 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
93 | super().update_groups(self.warmup_lr_init)
94 | else:
95 | self.warmup_steps = [1 for _ in self.base_values]
96 |
97 | def _get_lr(self, t):
98 | if t < self.warmup_t:
99 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
100 | else:
101 | t = t - self.warmup_t
102 | total_t = self.t_initial - self.warmup_t
103 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values]
104 | return lrs
105 |
106 | def get_epoch_values(self, epoch: int):
107 | if self.t_in_epochs:
108 | return self._get_lr(epoch)
109 | else:
110 | return None
111 |
112 | def get_update_values(self, num_updates: int):
113 | if not self.t_in_epochs:
114 | return self._get_lr(num_updates)
115 | else:
116 | return None
117 |
118 |
119 | class MultiStepLRScheduler(Scheduler):
120 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None:
121 | super().__init__(optimizer, param_group_field="lr")
122 |
123 | self.milestones = milestones
124 | self.gamma = gamma
125 | self.warmup_t = warmup_t
126 | self.warmup_lr_init = warmup_lr_init
127 | self.t_in_epochs = t_in_epochs
128 | if self.warmup_t:
129 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
130 | super().update_groups(self.warmup_lr_init)
131 | else:
132 | self.warmup_steps = [1 for _ in self.base_values]
133 |
134 | assert self.warmup_t <= min(self.milestones)
135 |
136 | def _get_lr(self, t):
137 | if t < self.warmup_t:
138 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
139 | else:
140 | lrs = [v * (self.gamma ** bisect_right(self.milestones, t)) for v in self.base_values]
141 | return lrs
142 |
143 | def get_epoch_values(self, epoch: int):
144 | if self.t_in_epochs:
145 | return self._get_lr(epoch)
146 | else:
147 | return None
148 |
149 | def get_update_values(self, num_updates: int):
150 | if not self.t_in_epochs:
151 | return self._get_lr(num_updates)
152 | else:
153 | return None
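

# A small schedule sanity sketch (mirrors the 800-epoch pre-training configs:
# multistep with one milestone at epoch 700, gamma 0.1, and warmup):
if __name__ == '__main__':
    net = torch.nn.Linear(4, 4)
    opt = torch.optim.AdamW(net.parameters(), lr=1e-4)
    sched = MultiStepLRScheduler(opt, milestones=[700], gamma=0.1,
                                 warmup_t=10, warmup_lr_init=5e-7, t_in_epochs=True)
    for epoch in (0, 5, 10, 400, 700, 799):
        print(epoch, sched.get_epoch_values(epoch))  # LR drops 10x from epoch 700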
--------------------------------------------------------------------------------
/main_finetune.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by Zhenda Xie
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import time
11 | import argparse
12 | import datetime
13 | import numpy as np
14 |
15 | import torch
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 |
19 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
20 | from timm.utils import accuracy, AverageMeter
21 |
22 | from config import get_config
23 | from models import build_model
24 | from data import build_loader
25 | from lr_scheduler import build_scheduler
26 | from optimizer import build_optimizer
27 | from logger import create_logger
28 | from utils import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor
29 |
30 | try:
31 | # noinspection PyUnresolvedReferences
32 | from apex import amp
33 | except ImportError:
34 | amp = None
35 |
36 |
37 | def parse_option():
38 | parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
39 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
40 | parser.add_argument(
41 | "--opts",
42 | help="Modify config options by adding 'KEY VALUE' pairs. ",
43 | default=None,
44 | nargs='+',
45 | )
46 |
47 | # easy config modification
48 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
49 | parser.add_argument('--data-path', type=str, help='path to dataset')
50 | parser.add_argument('--pretrained', type=str, help='path to pre-trained model')
51 | parser.add_argument('--resume', help='resume from checkpoint')
52 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
53 | parser.add_argument('--use-checkpoint', action='store_true',
54 | help="whether to use gradient checkpointing to save memory")
55 | parser.add_argument('--amp-opt-level', type=str, default='O1', choices=['O0', 'O1', 'O2'],
56 | help='mixed precision opt level, if O0, no amp is used')
57 | parser.add_argument('--output', default='output', type=str, metavar='PATH',
58 | help='root of output folder, the full path is