├── .gitignore ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MODELHUB.md ├── README.md ├── SECURITY.md ├── SUPPORT.md ├── config.py ├── configs ├── simmim │ ├── simmim_finetune__swin_base__img224_window7__800ep.yaml │ ├── simmim_finetune__swinv2_base__img224_window14__800ep.yaml │ ├── simmim_pretrain__swin_base__img192_window6__800ep.yaml │ └── simmim_pretrain__swinv2_base__img192_window12__800ep.yaml ├── swin │ ├── swin_base_patch4_window12_384_22kto1k_finetune.yaml │ ├── swin_base_patch4_window12_384_finetune.yaml │ ├── swin_base_patch4_window7_224.yaml │ ├── swin_base_patch4_window7_224_22k.yaml │ ├── swin_base_patch4_window7_224_22kto1k_finetune.yaml │ ├── swin_large_patch4_window12_384_22kto1k_finetune.yaml │ ├── swin_large_patch4_window7_224_22k.yaml │ ├── swin_large_patch4_window7_224_22kto1k_finetune.yaml │ ├── swin_small_patch4_window7_224.yaml │ ├── swin_small_patch4_window7_224_22k.yaml │ ├── swin_small_patch4_window7_224_22kto1k_finetune.yaml │ ├── swin_tiny_c24_patch4_window8_256.yaml │ ├── swin_tiny_patch4_window7_224.yaml │ ├── swin_tiny_patch4_window7_224_22k.yaml │ └── swin_tiny_patch4_window7_224_22kto1k_finetune.yaml ├── swinmlp │ ├── swin_mlp_base_patch4_window7_224.yaml │ ├── swin_mlp_tiny_c12_patch4_window8_256.yaml │ ├── swin_mlp_tiny_c24_patch4_window8_256.yaml │ └── swin_mlp_tiny_c6_patch4_window8_256.yaml ├── swinmoe │ ├── swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml │ ├── swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml │ ├── swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml │ ├── swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml │ ├── swin_moe_base_patch4_window12_192_densebaseline_22k.yaml │ ├── swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml │ ├── swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml │ ├── swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml │ ├── swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml │ ├── swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml │ └── swin_moe_small_patch4_window12_192_densebaseline_22k.yaml └── swinv2 │ ├── swinv2_base_patch4_window12_192_22k.yaml │ ├── swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml │ ├── swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml │ ├── swinv2_base_patch4_window16_256.yaml │ ├── swinv2_base_patch4_window8_256.yaml │ ├── swinv2_large_patch4_window12_192_22k.yaml │ ├── swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml │ ├── swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml │ ├── swinv2_small_patch4_window16_256.yaml │ ├── swinv2_small_patch4_window8_256.yaml │ ├── swinv2_tiny_patch4_window16_256.yaml │ └── swinv2_tiny_patch4_window8_256.yaml ├── data ├── __init__.py ├── build.py ├── cached_image_folder.py ├── data_simmim_ft.py ├── data_simmim_pt.py ├── imagenet22k_dataset.py ├── map22kto1k.txt ├── samplers.py └── zipreader.py ├── figures └── teaser.png ├── get_started.md ├── kernels └── window_process │ ├── setup.py │ ├── swin_window_process.cpp │ ├── swin_window_process_kernel.cu │ ├── unit_test.py │ └── window_process.py ├── logger.py ├── lr_scheduler.py ├── main.py ├── main_moe.py ├── main_simmim_ft.py ├── main_simmim_pt.py ├── models ├── __init__.py ├── build.py ├── simmim.py ├── swin_mlp.py ├── swin_transformer.py ├── swin_transformer_moe.py └── swin_transformer_v2.py ├── optimizer.py ├── utils.py ├── utils_moe.py └── utils_simmim.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL 
files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # launch bash 7 | *.sh 8 | # nsight system report files 9 | *.nsys-rep 10 | *.sqlite 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 
22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 
26 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # --------------------------------------------------------' 7 | 8 | import os 9 | import torch 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | 13 | # pytorch major version (1.x or 2.x) 14 | PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0]) 15 | 16 | _C = CN() 17 | 18 | # Base config files 19 | _C.BASE = [''] 20 | 21 | # ----------------------------------------------------------------------------- 22 | # Data settings 23 | # ----------------------------------------------------------------------------- 24 | _C.DATA = CN() 25 | # Batch size for a single GPU, could be overwritten by command line argument 26 | _C.DATA.BATCH_SIZE = 128 27 | # Path to dataset, could be overwritten by command line argument 28 | _C.DATA.DATA_PATH = '' 29 | # Dataset name 30 | _C.DATA.DATASET = 'imagenet' 31 | # Input image size 32 | _C.DATA.IMG_SIZE = 224 33 | # Interpolation to resize image (random, bilinear, bicubic) 34 | _C.DATA.INTERPOLATION = 'bicubic' 35 | # Use zipped dataset instead of folder dataset 36 | # could be overwritten by command line argument 37 | _C.DATA.ZIP_MODE = False 38 | # Cache Data in Memory, could be overwritten by command line argument 39 | _C.DATA.CACHE_MODE = 'part' 40 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 41 | _C.DATA.PIN_MEMORY = True 42 | # Number of data loading threads 43 | _C.DATA.NUM_WORKERS = 8 44 | 45 | # [SimMIM] Mask patch size for MaskGenerator 46 | _C.DATA.MASK_PATCH_SIZE = 32 47 | # [SimMIM] Mask ratio for MaskGenerator 48 | _C.DATA.MASK_RATIO = 0.6 49 | 50 | # ----------------------------------------------------------------------------- 51 | # Model settings 52 | # ----------------------------------------------------------------------------- 53 | _C.MODEL = CN() 54 | # Model type 55 | _C.MODEL.TYPE = 'swin' 56 | # Model name 57 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 58 | # Pretrained weight from checkpoint, could be imagenet22k pretrained weight 59 | # could be overwritten by command line argument 60 | _C.MODEL.PRETRAINED = '' 61 | # Checkpoint to resume, could be overwritten by command line argument 62 | _C.MODEL.RESUME = '' 63 | # Number of classes, overwritten in data preparation 64 | _C.MODEL.NUM_CLASSES = 1000 65 | # Dropout rate 66 | _C.MODEL.DROP_RATE = 0.0 67 | # Drop path rate 68 | _C.MODEL.DROP_PATH_RATE = 0.1 69 | # Label Smoothing 70 | _C.MODEL.LABEL_SMOOTHING = 0.1 71 | 72 | # Swin Transformer parameters 73 | _C.MODEL.SWIN = CN() 74 | _C.MODEL.SWIN.PATCH_SIZE = 4 75 | _C.MODEL.SWIN.IN_CHANS = 3 76 | _C.MODEL.SWIN.EMBED_DIM = 96 77 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 78 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 79 | _C.MODEL.SWIN.WINDOW_SIZE = 7 80 | _C.MODEL.SWIN.MLP_RATIO = 4. 
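The yacs tree above is only the first of three layers: a YAML file (which may chain parent files through the BASE key, handled by _update_config_from_file further down in this module) is merged over these defaults, and command-line options are merged last in update_config. A minimal sketch of that flow, using one of this repo's own config files; the batch-size override value is an arbitrary illustration:

    from config import _C

    cfg = _C.clone()      # clone so the module-level defaults stay pristine
    cfg.defrost()
    # YAML values override the defaults; files listed under BASE merge first.
    cfg.merge_from_file('configs/swin/swin_tiny_patch4_window7_224.yaml')
    # Command-line style overrides win over the YAML (cf. update_config below).
    cfg.merge_from_list(['DATA.BATCH_SIZE', 64])
    cfg.freeze()          # freeze so accidental writes raise immediately
    print(cfg.MODEL.NAME, cfg.DATA.BATCH_SIZE)  # swin_tiny_patch4_window7_224 64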
81 | _C.MODEL.SWIN.QKV_BIAS = True 82 | _C.MODEL.SWIN.QK_SCALE = None 83 | _C.MODEL.SWIN.APE = False 84 | _C.MODEL.SWIN.PATCH_NORM = True 85 | 86 | # Swin Transformer V2 parameters 87 | _C.MODEL.SWINV2 = CN() 88 | _C.MODEL.SWINV2.PATCH_SIZE = 4 89 | _C.MODEL.SWINV2.IN_CHANS = 3 90 | _C.MODEL.SWINV2.EMBED_DIM = 96 91 | _C.MODEL.SWINV2.DEPTHS = [2, 2, 6, 2] 92 | _C.MODEL.SWINV2.NUM_HEADS = [3, 6, 12, 24] 93 | _C.MODEL.SWINV2.WINDOW_SIZE = 7 94 | _C.MODEL.SWINV2.MLP_RATIO = 4. 95 | _C.MODEL.SWINV2.QKV_BIAS = True 96 | _C.MODEL.SWINV2.APE = False 97 | _C.MODEL.SWINV2.PATCH_NORM = True 98 | _C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0] 99 | 100 | # Swin Transformer MoE parameters 101 | _C.MODEL.SWIN_MOE = CN() 102 | _C.MODEL.SWIN_MOE.PATCH_SIZE = 4 103 | _C.MODEL.SWIN_MOE.IN_CHANS = 3 104 | _C.MODEL.SWIN_MOE.EMBED_DIM = 96 105 | _C.MODEL.SWIN_MOE.DEPTHS = [2, 2, 6, 2] 106 | _C.MODEL.SWIN_MOE.NUM_HEADS = [3, 6, 12, 24] 107 | _C.MODEL.SWIN_MOE.WINDOW_SIZE = 7 108 | _C.MODEL.SWIN_MOE.MLP_RATIO = 4. 109 | _C.MODEL.SWIN_MOE.QKV_BIAS = True 110 | _C.MODEL.SWIN_MOE.QK_SCALE = None 111 | _C.MODEL.SWIN_MOE.APE = False 112 | _C.MODEL.SWIN_MOE.PATCH_NORM = True 113 | _C.MODEL.SWIN_MOE.MLP_FC2_BIAS = True 114 | _C.MODEL.SWIN_MOE.INIT_STD = 0.02 115 | _C.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0] 116 | _C.MODEL.SWIN_MOE.MOE_BLOCKS = [[-1], [-1], [-1], [-1]] 117 | _C.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS = 1 118 | _C.MODEL.SWIN_MOE.TOP_VALUE = 1 119 | _C.MODEL.SWIN_MOE.CAPACITY_FACTOR = 1.25 120 | _C.MODEL.SWIN_MOE.COSINE_ROUTER = False 121 | _C.MODEL.SWIN_MOE.NORMALIZE_GATE = False 122 | _C.MODEL.SWIN_MOE.USE_BPR = True 123 | _C.MODEL.SWIN_MOE.IS_GSHARD_LOSS = False 124 | _C.MODEL.SWIN_MOE.GATE_NOISE = 1.0 125 | _C.MODEL.SWIN_MOE.COSINE_ROUTER_DIM = 256 126 | _C.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T = 0.5 127 | _C.MODEL.SWIN_MOE.MOE_DROP = 0.0 128 | _C.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT = 0.01 129 | 130 | # Swin MLP parameters 131 | _C.MODEL.SWIN_MLP = CN() 132 | _C.MODEL.SWIN_MLP.PATCH_SIZE = 4 133 | _C.MODEL.SWIN_MLP.IN_CHANS = 3 134 | _C.MODEL.SWIN_MLP.EMBED_DIM = 96 135 | _C.MODEL.SWIN_MLP.DEPTHS = [2, 2, 6, 2] 136 | _C.MODEL.SWIN_MLP.NUM_HEADS = [3, 6, 12, 24] 137 | _C.MODEL.SWIN_MLP.WINDOW_SIZE = 7 138 | _C.MODEL.SWIN_MLP.MLP_RATIO = 4. 
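A note on the SWIN_MOE routing knobs defined just above: TOP_VALUE, CAPACITY_FACTOR and NUM_LOCAL_EXPERTS interact through the usual top-k gating arithmetic. The sketch below assumes the standard Tutel/GShard-style capacity formula (an assumption on my part, not quoted from this repo's code), with a made-up token count:

    import math

    def expert_capacity(num_tokens: int, num_experts: int,
                        top_value: int = 1, capacity_factor: float = 1.25) -> int:
        # Per-expert buffer size: tokens routed beyond this are dropped or
        # re-routed, which is what CAPACITY_FACTOR > 1 leaves headroom for.
        return math.ceil(capacity_factor * top_value * num_tokens / num_experts)

    # e.g. 64 windows of 12x12 tokens spread over 32 experts:
    print(expert_capacity(64 * 144, 32))  # -> 360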
139 | _C.MODEL.SWIN_MLP.APE = False 140 | _C.MODEL.SWIN_MLP.PATCH_NORM = True 141 | 142 | # [SimMIM] Norm target during training 143 | _C.MODEL.SIMMIM = CN() 144 | _C.MODEL.SIMMIM.NORM_TARGET = CN() 145 | _C.MODEL.SIMMIM.NORM_TARGET.ENABLE = False 146 | _C.MODEL.SIMMIM.NORM_TARGET.PATCH_SIZE = 47 147 | 148 | # ----------------------------------------------------------------------------- 149 | # Training settings 150 | # ----------------------------------------------------------------------------- 151 | _C.TRAIN = CN() 152 | _C.TRAIN.START_EPOCH = 0 153 | _C.TRAIN.EPOCHS = 300 154 | _C.TRAIN.WARMUP_EPOCHS = 20 155 | _C.TRAIN.WEIGHT_DECAY = 0.05 156 | _C.TRAIN.BASE_LR = 5e-4 157 | _C.TRAIN.WARMUP_LR = 5e-7 158 | _C.TRAIN.MIN_LR = 5e-6 159 | # Clip gradient norm 160 | _C.TRAIN.CLIP_GRAD = 5.0 161 | # Auto resume from latest checkpoint 162 | _C.TRAIN.AUTO_RESUME = True 163 | # Gradient accumulation steps 164 | # could be overwritten by command line argument 165 | _C.TRAIN.ACCUMULATION_STEPS = 1 166 | # Whether to use gradient checkpointing to save memory 167 | # could be overwritten by command line argument 168 | _C.TRAIN.USE_CHECKPOINT = False 169 | 170 | # LR scheduler 171 | _C.TRAIN.LR_SCHEDULER = CN() 172 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 173 | # Epoch interval to decay LR, used in StepLRScheduler 174 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 175 | # LR decay rate, used in StepLRScheduler 176 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 177 | # warmup_prefix used in CosineLRScheduler 178 | _C.TRAIN.LR_SCHEDULER.WARMUP_PREFIX = True 179 | # [SimMIM] Gamma / Multi steps value, used in MultiStepLRScheduler 180 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 181 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] 182 | 183 | # Optimizer 184 | _C.TRAIN.OPTIMIZER = CN() 185 | _C.TRAIN.OPTIMIZER.NAME = 'adamw' 186 | # Optimizer Epsilon 187 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 188 | # Optimizer Betas 189 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 190 | # SGD momentum 191 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 192 | 193 | # [SimMIM] Layer decay for fine-tuning 194 | _C.TRAIN.LAYER_DECAY = 1.0 195 | 196 | # MoE 197 | _C.TRAIN.MOE = CN() 198 | # Only save model on master device 199 | _C.TRAIN.MOE.SAVE_MASTER = False 200 | # ----------------------------------------------------------------------------- 201 | # Augmentation settings 202 | # ----------------------------------------------------------------------------- 203 | _C.AUG = CN() 204 | # Color jitter factor 205 | _C.AUG.COLOR_JITTER = 0.4 206 | # Use AutoAugment policy. "v0" or "original" 207 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 208 | # Random erase prob 209 | _C.AUG.REPROB = 0.25 210 | # Random erase mode 211 | _C.AUG.REMODE = 'pixel' 212 | # Random erase count 213 | _C.AUG.RECOUNT = 1 214 | # Mixup alpha, mixup enabled if > 0 215 | _C.AUG.MIXUP = 0.8 216 | # Cutmix alpha, cutmix enabled if > 0 217 | _C.AUG.CUTMIX = 1.0 218 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 219 | _C.AUG.CUTMIX_MINMAX = None 220 | # Probability of performing mixup or cutmix when either/both is enabled 221 | _C.AUG.MIXUP_PROB = 1.0 222 | # Probability of switching to cutmix when both mixup and cutmix enabled 223 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 224 | # How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem" 225 | _C.AUG.MIXUP_MODE = 'batch' 226 | 227 | # ----------------------------------------------------------------------------- 228 | # Testing settings 229 | # ----------------------------------------------------------------------------- 230 | _C.TEST = CN() 231 | # Whether to use center crop when testing 232 | _C.TEST.CROP = True 233 | # Whether to use SequentialSampler as validation sampler 234 | _C.TEST.SEQUENTIAL = False 235 | _C.TEST.SHUFFLE = False 236 | 237 | # ----------------------------------------------------------------------------- 238 | # Misc 239 | # ----------------------------------------------------------------------------- 240 | # [SimMIM] Whether to enable pytorch amp, overwritten by command line argument 241 | _C.ENABLE_AMP = False 242 | 243 | # Enable Pytorch automatic mixed precision (amp). 244 | _C.AMP_ENABLE = True 245 | # [Deprecated] Mixed precision opt level of apex, if O0, no apex amp is used ('O0', 'O1', 'O2') 246 | _C.AMP_OPT_LEVEL = '' 247 | # Path to output folder, overwritten by command line argument 248 | _C.OUTPUT = '' 249 | # Tag of experiment, overwritten by command line argument 250 | _C.TAG = 'default' 251 | # Frequency to save checkpoint 252 | _C.SAVE_FREQ = 1 253 | # Frequency to logging info 254 | _C.PRINT_FREQ = 10 255 | # Fixed random seed 256 | _C.SEED = 0 257 | # Perform evaluation only, overwritten by command line argument 258 | _C.EVAL_MODE = False 259 | # Test throughput only, overwritten by command line argument 260 | _C.THROUGHPUT_MODE = False 261 | # local rank for DistributedDataParallel, given by command line argument 262 | _C.LOCAL_RANK = 0 263 | # for acceleration 264 | _C.FUSED_WINDOW_PROCESS = False 265 | _C.FUSED_LAYERNORM = False 266 | 267 | 268 | def _update_config_from_file(config, cfg_file): 269 | config.defrost() 270 | with open(cfg_file, 'r') as f: 271 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 272 | 273 | for cfg in yaml_cfg.setdefault('BASE', ['']): 274 | if cfg: 275 | _update_config_from_file( 276 | config, os.path.join(os.path.dirname(cfg_file), cfg) 277 | ) 278 | print('=> merge config from {}'.format(cfg_file)) 279 | config.merge_from_file(cfg_file) 280 | config.freeze() 281 | 282 | 283 | def update_config(config, args): 284 | _update_config_from_file(config, args.cfg) 285 | 286 | config.defrost() 287 | if args.opts: 288 | config.merge_from_list(args.opts) 289 | 290 | def _check_args(name): 291 | if hasattr(args, name) and eval(f'args.{name}'): 292 | return True 293 | return False 294 | 295 | # merge from specific arguments 296 | if _check_args('batch_size'): 297 | config.DATA.BATCH_SIZE = args.batch_size 298 | if _check_args('data_path'): 299 | config.DATA.DATA_PATH = args.data_path 300 | if _check_args('zip'): 301 | config.DATA.ZIP_MODE = True 302 | if _check_args('cache_mode'): 303 | config.DATA.CACHE_MODE = args.cache_mode 304 | if _check_args('pretrained'): 305 | config.MODEL.PRETRAINED = args.pretrained 306 | if _check_args('resume'): 307 | config.MODEL.RESUME = args.resume 308 | if _check_args('accumulation_steps'): 309 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 310 | if _check_args('use_checkpoint'): 311 | config.TRAIN.USE_CHECKPOINT = True 312 | if _check_args('amp_opt_level'): 313 | print("[warning] Apex amp has been deprecated, please use pytorch amp instead!") 314 | if args.amp_opt_level == 'O0': 315 | config.AMP_ENABLE = False 316 | if _check_args('disable_amp'): 317 | config.AMP_ENABLE = False 318 | if _check_args('output'): 319 | config.OUTPUT = 
args.output 320 | if _check_args('tag'): 321 | config.TAG = args.tag 322 | if _check_args('eval'): 323 | config.EVAL_MODE = True 324 | if _check_args('throughput'): 325 | config.THROUGHPUT_MODE = True 326 | 327 | # [SimMIM] 328 | if _check_args('enable_amp'): 329 | config.ENABLE_AMP = args.enable_amp 330 | 331 | # for acceleration 332 | if _check_args('fused_window_process'): 333 | config.FUSED_WINDOW_PROCESS = True 334 | if _check_args('fused_layernorm'): 335 | config.FUSED_LAYERNORM = True 336 | ## Overwrite optimizer if not None, currently we use it for [fused_adam, fused_lamb] 337 | if _check_args('optim'): 338 | config.TRAIN.OPTIMIZER.NAME = args.optim 339 | 340 | # set local rank for distributed training 341 | if PYTORCH_MAJOR_VERSION == 1: 342 | config.LOCAL_RANK = args.local_rank 343 | else: 344 | config.LOCAL_RANK = int(os.environ['LOCAL_RANK']) 345 | 346 | # output folder 347 | config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) 348 | 349 | config.freeze() 350 | 351 | 352 | def get_config(args): 353 | """Get a yacs CfgNode object with default values.""" 354 | # Return a clone so that the defaults will not be altered 355 | # This is for the "local variable" use pattern 356 | config = _C.clone() 357 | update_config(config, args) 358 | 359 | return config 360 | -------------------------------------------------------------------------------- /configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | DATA: 11 | IMG_SIZE: 224 12 | TRAIN: 13 | EPOCHS: 100 14 | WARMUP_EPOCHS: 20 15 | BASE_LR: 1.25e-3 16 | WARMUP_LR: 2.5e-7 17 | MIN_LR: 2.5e-7 18 | WEIGHT_DECAY: 0.05 19 | LAYER_DECAY: 0.8 20 | PRINT_FREQ: 100 21 | SAVE_FREQ: 5 22 | TAG: simmim_finetune__swin_base__img224_window7__800ep -------------------------------------------------------------------------------- /configs/simmim/simmim_finetune__swinv2_base__img224_window14__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swinv2 3 | NAME: simmim_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWINV2: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 14 10 | PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] 11 | DATA: 12 | IMG_SIZE: 224 13 | TRAIN: 14 | EPOCHS: 100 15 | WARMUP_EPOCHS: 20 16 | BASE_LR: 1.25e-3 17 | WARMUP_LR: 2.5e-7 18 | MIN_LR: 2.5e-7 19 | WEIGHT_DECAY: 0.05 20 | LAYER_DECAY: 0.75 21 | PRINT_FREQ: 100 22 | SAVE_FREQ: 5 23 | TAG: simmim_finetune__swinv2_base__img224_window14__800ep -------------------------------------------------------------------------------- /configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: simmim_pretrain 4 | DROP_PATH_RATE: 0.0 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 6 10 | DATA: 11 | IMG_SIZE: 192 12 | MASK_PATCH_SIZE: 32 13 | MASK_RATIO: 0.6 14 | TRAIN: 15 | EPOCHS: 800 16 | WARMUP_EPOCHS: 10 17 | BASE_LR: 1e-4 18 | WARMUP_LR: 5e-7 19 | WEIGHT_DECAY: 0.05 20 | LR_SCHEDULER: 21 | NAME: 'multistep' 22 | GAMMA: 0.1 23 | MULTISTEPS: [700,] 24 | PRINT_FREQ: 100 25 | SAVE_FREQ: 5 26 | TAG: 
simmim_pretrain__swin_base__img192_window6__800ep -------------------------------------------------------------------------------- /configs/simmim/simmim_pretrain__swinv2_base__img192_window12__800ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swinv2 3 | NAME: simmim_pretrain 4 | DROP_PATH_RATE: 0.1 5 | SIMMIM: 6 | NORM_TARGET: 7 | ENABLE: True 8 | PATCH_SIZE: 47 9 | SWINV2: 10 | EMBED_DIM: 128 11 | DEPTHS: [ 2, 2, 18, 2 ] 12 | NUM_HEADS: [ 4, 8, 16, 32 ] 13 | WINDOW_SIZE: 12 14 | DATA: 15 | IMG_SIZE: 192 16 | MASK_PATCH_SIZE: 32 17 | MASK_RATIO: 0.6 18 | TRAIN: 19 | EPOCHS: 800 20 | WARMUP_EPOCHS: 10 21 | BASE_LR: 1e-4 22 | WARMUP_LR: 5e-7 23 | WEIGHT_DECAY: 0.05 24 | LR_SCHEDULER: 25 | NAME: 'multistep' 26 | GAMMA: 0.1 27 | MULTISTEPS: [700,] 28 | PRINT_FREQ: 100 29 | SAVE_FREQ: 5 30 | TAG: simmim_pretrain__swinv2_base__img192_window12__800ep -------------------------------------------------------------------------------- /configs/swin/swin_base_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /configs/swin/swin_base_patch4_window12_384_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window12_384_finetune 6 | DROP_PATH_RATE: 0.5 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /configs/swin/swin_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /configs/swin/swin_base_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_base_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_base_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | 
EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /configs/swin/swin_large_patch4_window12_384_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window12_384_22kto1k_finetune 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 12 12 | TRAIN: 13 | EPOCHS: 30 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 1e-8 16 | BASE_LR: 2e-05 17 | WARMUP_LR: 2e-08 18 | MIN_LR: 2e-07 19 | TEST: 20 | CROP: False -------------------------------------------------------------------------------- /configs/swin/swin_large_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_large_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /configs/swin/swin_large_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_large_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 192 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 6, 12, 24, 48 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /configs/swin/swin_small_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /configs/swin/swin_small_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_small_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /configs/swin/swin_small_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_small_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 
30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /configs/swin/swin_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swin/swin_tiny_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224 4 | DROP_PATH_RATE: 0.2 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 -------------------------------------------------------------------------------- /configs/swin/swin_tiny_patch4_window7_224_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | MODEL: 4 | TYPE: swin 5 | NAME: swin_tiny_patch4_window7_224_22k 6 | DROP_PATH_RATE: 0.1 7 | SWIN: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 7 12 | TRAIN: 13 | EPOCHS: 90 14 | WARMUP_EPOCHS: 5 15 | WEIGHT_DECAY: 0.05 16 | BASE_LR: 1.25e-4 # 4096 batch-size 17 | WARMUP_LR: 1.25e-7 18 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /configs/swin/swin_tiny_patch4_window7_224_22kto1k_finetune.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin 3 | NAME: swin_tiny_patch4_window7_224_22kto1k_finetune 4 | DROP_PATH_RATE: 0.1 5 | SWIN: 6 | EMBED_DIM: 96 7 | DEPTHS: [ 2, 2, 6, 2 ] 8 | NUM_HEADS: [ 3, 6, 12, 24 ] 9 | WINDOW_SIZE: 7 10 | TRAIN: 11 | EPOCHS: 30 12 | WARMUP_EPOCHS: 5 13 | WEIGHT_DECAY: 1e-8 14 | BASE_LR: 2e-05 15 | WARMUP_LR: 2e-08 16 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /configs/swinmlp/swin_mlp_base_patch4_window7_224.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin_mlp 3 | NAME: swin_mlp_base_patch4_window7_224 4 | DROP_PATH_RATE: 0.5 5 | SWIN_MLP: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 7 10 | -------------------------------------------------------------------------------- /configs/swinmlp/swin_mlp_tiny_c12_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c12_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 8, 16, 32, 64 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swinmlp/swin_mlp_tiny_c24_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c24_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 
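The c6/c12/c24 suffixes in the swin_mlp_tiny_c* config names in this folder record channels per head: EMBED_DIM divided by the stage-1 NUM_HEADS entry. A two-line check against the values in these three configs:

    embed_dim = 96  # shared by all three tiny variants
    for tag, stage1_heads in [('c24', 4), ('c12', 8), ('c6', 16)]:
        print(tag, embed_dim // stage1_heads)  # c24 -> 24, c12 -> 12, c6 -> 6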
-------------------------------------------------------------------------------- /configs/swinmlp/swin_mlp_tiny_c6_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swin_mlp 5 | NAME: swin_mlp_tiny_c6_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWIN_MLP: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 16, 32, 64, 128 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_base_patch4_window12_192_16expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_base_patch4_window12_192_16expert_32gpu_22k 7 | DROP_PATH_RATE: 0.3 8 | SWIN_MOE: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: -2 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_base_patch4_window12_192_32expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_base_patch4_window12_192_32expert_32gpu_22k 7 | DROP_PATH_RATE: 0.3 8 | SWIN_MOE: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: 1 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_base_patch4_window12_192_8expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_base_patch4_window12_192_8expert_32gpu_22k 7 | DROP_PATH_RATE: 0.3 8 | SWIN_MOE: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: -4 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- 
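Reading the NUM_LOCAL_EXPERTS values against the expert/GPU counts in these file names suggests the Tutel convention (an inference from the names, not something stated in the configs): a positive value means that many experts per GPU, while a negative value -n means one expert shared across n GPUs. A sketch of that bookkeeping, with a hypothetical helper name:

    def global_experts(num_local_experts: int, world_size: int) -> int:
        # Hypothetical helper mirroring the apparent convention above.
        if num_local_experts > 0:
            return num_local_experts * world_size  # experts per GPU
        return world_size // -num_local_experts   # GPUs per expert

    for n in (1, -2, -4):  # the values used by the 32-GPU configs here
        print(n, global_experts(n, 32))  # -> 32, 16, 8 experts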
/configs/swinmoe/swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_base_patch4_window12_192_cosine_router_32expert_32gpu_22k 7 | DROP_PATH_RATE: 0.3 8 | SWIN_MOE: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: 1 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | COSINE_ROUTER: True 20 | IS_GSHARD_LOSS: False 21 | MOE_DROP: 0.1 22 | AUX_LOSS_WEIGHT: 0.01 23 | TRAIN: 24 | EPOCHS: 90 25 | WARMUP_EPOCHS: 10 26 | WEIGHT_DECAY: 0.1 27 | BASE_LR: 1.25e-4 # 4096 batch-size 28 | WARMUP_LR: 1.25e-7 29 | MIN_LR: 1.25e-6 30 | CLIP_GRAD: 3.0 31 | TEST: 32 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_base_patch4_window12_192_densebaseline_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_base_patch4_window12_192_densebaseline_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ] 15 | TRAIN: 16 | EPOCHS: 90 17 | WARMUP_EPOCHS: 10 18 | WEIGHT_DECAY: 0.1 19 | BASE_LR: 1.25e-4 # 4096 batch-size 20 | WARMUP_LR: 1.25e-7 21 | MIN_LR: 1.25e-6 22 | CLIP_GRAD: 3.0 23 | MOE: 24 | SAVE_MASTER: True 25 | TEST: 26 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_small_patch4_window12_192_16expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_small_patch4_window12_192_16expert_32gpu_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: -2 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_small_patch4_window12_192_32expert_32gpu_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: 1 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | 
IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_small_patch4_window12_192_64expert_64gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_small_patch4_window12_192_64expert_64gpu_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: 1 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_small_patch4_window12_192_8expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_small_patch4_window12_192_8expert_32gpu_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: -4 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | IS_GSHARD_LOSS: False 20 | MOE_DROP: 0.1 21 | AUX_LOSS_WEIGHT: 0.01 22 | TRAIN: 23 | EPOCHS: 90 24 | WARMUP_EPOCHS: 10 25 | WEIGHT_DECAY: 0.1 26 | BASE_LR: 1.25e-4 # 4096 batch-size 27 | WARMUP_LR: 1.25e-7 28 | MIN_LR: 1.25e-6 29 | CLIP_GRAD: 3.0 30 | TEST: 31 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_small_patch4_window12_192_cosine_router_32expert_32gpu_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | INIT_STD: 0.005 15 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ 1, 3, 5, 7, 9, 11, 13, 15, 17 ], [ 1 ] ] 16 | NUM_LOCAL_EXPERTS: 1 17 | TOP_VALUE: 1 18 | CAPACITY_FACTOR: 1.25 19 | COSINE_ROUTER: True 20 | IS_GSHARD_LOSS: False 21 | MOE_DROP: 0.1 22 | AUX_LOSS_WEIGHT: 0.01 23 | TRAIN: 24 | EPOCHS: 90 25 | WARMUP_EPOCHS: 10 26 | WEIGHT_DECAY: 0.1 27 | BASE_LR: 1.25e-4 # 4096 batch-size 28 | WARMUP_LR: 1.25e-7 29 | MIN_LR: 1.25e-6 30 | CLIP_GRAD: 3.0 31 | TEST: 32 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinmoe/swin_moe_small_patch4_window12_192_densebaseline_22k.yaml: 
-------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swin_moe 6 | NAME: swin_moe_small_patch4_window12_192_densebaseline_22k 7 | DROP_PATH_RATE: 0.2 8 | SWIN_MOE: 9 | EMBED_DIM: 96 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 3, 6, 12, 24 ] 12 | WINDOW_SIZE: 12 13 | MLP_FC2_BIAS: False 14 | MOE_BLOCKS: [ [ -1 ], [ -1 ], [ -1 ], [ -1 ] ] 15 | TRAIN: 16 | EPOCHS: 90 17 | WARMUP_EPOCHS: 10 18 | WEIGHT_DECAY: 0.1 19 | BASE_LR: 1.25e-4 # 4096 batch-size 20 | WARMUP_LR: 1.25e-7 21 | MIN_LR: 1.25e-6 22 | CLIP_GRAD: 3.0 23 | MOE: 24 | SAVE_MASTER: True 25 | TEST: 26 | SHUFFLE: True -------------------------------------------------------------------------------- /configs/swinv2/swinv2_base_patch4_window12_192_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swinv2 6 | NAME: swinv2_base_patch4_window12_192_22k 7 | DROP_PATH_RATE: 0.2 8 | SWINV2: 9 | EMBED_DIM: 128 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 4, 8, 16, 32 ] 12 | WINDOW_SIZE: 12 13 | TRAIN: 14 | EPOCHS: 90 15 | WARMUP_EPOCHS: 5 16 | WEIGHT_DECAY: 0.1 17 | BASE_LR: 1.25e-4 # 4096 batch-size 18 | WARMUP_LR: 1.25e-7 19 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_base_patch4_window12to16_192to256_22kto1k_ft 6 | DROP_PATH_RATE: 0.2 7 | SWINV2: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 16 12 | PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] 13 | TRAIN: 14 | EPOCHS: 30 15 | WARMUP_EPOCHS: 5 16 | WEIGHT_DECAY: 1e-8 17 | BASE_LR: 2e-05 18 | WARMUP_LR: 2e-08 19 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_base_patch4_window12to24_192to384_22kto1k_ft 6 | DROP_PATH_RATE: 0.2 7 | SWINV2: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 24 12 | PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] 13 | TRAIN: 14 | EPOCHS: 30 15 | WARMUP_EPOCHS: 5 16 | WEIGHT_DECAY: 1e-8 17 | BASE_LR: 2e-05 18 | WARMUP_LR: 2e-08 19 | MIN_LR: 2e-07 20 | TEST: 21 | CROP: False -------------------------------------------------------------------------------- /configs/swinv2/swinv2_base_patch4_window16_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_base_patch4_window16_256 6 | DROP_PATH_RATE: 0.5 7 | SWINV2: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 16 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_base_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_base_patch4_window8_256 6 | DROP_PATH_RATE: 0.5 7 | SWINV2: 8 | EMBED_DIM: 128 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | 
NUM_HEADS: [ 4, 8, 16, 32 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_large_patch4_window12_192_22k.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | DATASET: imagenet22K 3 | IMG_SIZE: 192 4 | MODEL: 5 | TYPE: swinv2 6 | NAME: swinv2_large_patch4_window12_192_22k 7 | DROP_PATH_RATE: 0.2 8 | SWINV2: 9 | EMBED_DIM: 192 10 | DEPTHS: [ 2, 2, 18, 2 ] 11 | NUM_HEADS: [ 6, 12, 24, 48 ] 12 | WINDOW_SIZE: 12 13 | TRAIN: 14 | EPOCHS: 90 15 | WARMUP_EPOCHS: 5 16 | WEIGHT_DECAY: 0.1 17 | BASE_LR: 1.25e-4 # 4096 batch-size 18 | WARMUP_LR: 1.25e-7 19 | MIN_LR: 1.25e-6 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_large_patch4_window12to16_192to256_22kto1k_ft 6 | DROP_PATH_RATE: 0.2 7 | SWINV2: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 16 12 | PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] 13 | TRAIN: 14 | EPOCHS: 30 15 | WARMUP_EPOCHS: 5 16 | WEIGHT_DECAY: 1e-8 17 | BASE_LR: 2e-05 18 | WARMUP_LR: 2e-08 19 | MIN_LR: 2e-07 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 384 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_large_patch4_window12to24_192to384_22kto1k_ft 6 | DROP_PATH_RATE: 0.2 7 | SWINV2: 8 | EMBED_DIM: 192 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 6, 12, 24, 48 ] 11 | WINDOW_SIZE: 24 12 | PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] 13 | TRAIN: 14 | EPOCHS: 30 15 | WARMUP_EPOCHS: 5 16 | WEIGHT_DECAY: 1e-8 17 | BASE_LR: 2e-05 18 | WARMUP_LR: 2e-08 19 | MIN_LR: 2e-07 20 | TEST: 21 | CROP: False -------------------------------------------------------------------------------- /configs/swinv2/swinv2_small_patch4_window16_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_small_patch4_window16_256 6 | DROP_PATH_RATE: 0.3 7 | SWINV2: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 16 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_small_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_small_patch4_window8_256 6 | DROP_PATH_RATE: 0.3 7 | SWINV2: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 18, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /configs/swinv2/swinv2_tiny_patch4_window16_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_tiny_patch4_window16_256 6 | DROP_PATH_RATE: 0.2 7 | SWINV2: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 16 --------------------------------------------------------------------------------
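The PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] entries in the fine-tuning configs above fall out of simple geometry: at the 192x192 pre-training resolution, PATCH_SIZE 4 plus a 2x merge after each of the first three stages leaves a 6x6 stage-4 feature map, so the effective window there was capped at 6 even though 12 was configured. SwinV2 uses these values to rescale its continuous relative-position bias when the window grows to 16 or 24 at fine-tuning time. The arithmetic, as plain Python rather than repo code:

    img, patch, window = 192, 4, 12
    for stage in range(4):
        side = img // patch // (2 ** stage)        # 48, 24, 12, 6
        print(stage + 1, side, min(window, side))  # effective windows: 12, 12, 12, 6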
/configs/swinv2/swinv2_tiny_patch4_window8_256.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | IMG_SIZE: 256 3 | MODEL: 4 | TYPE: swinv2 5 | NAME: swinv2_tiny_patch4_window8_256 6 | DROP_PATH_RATE: 0.2 7 | SWINV2: 8 | EMBED_DIM: 96 9 | DEPTHS: [ 2, 2, 6, 2 ] 10 | NUM_HEADS: [ 3, 6, 12, 24 ] 11 | WINDOW_SIZE: 8 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_loader as _build_loader 2 | from .data_simmim_pt import build_loader_simmim 3 | from .data_simmim_ft import build_loader_finetune 4 | 5 | 6 | def build_loader(config, simmim=False, is_pretrain=False): 7 | if not simmim: 8 | return _build_loader(config) 9 | if is_pretrain: 10 | return build_loader_simmim(config) 11 | else: 12 | return build_loader_finetune(config) 13 | -------------------------------------------------------------------------------- /data/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import numpy as np 11 | import torch.distributed as dist 12 | from torchvision import datasets, transforms 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | from timm.data import Mixup 15 | from timm.data import create_transform 16 | 17 | from .cached_image_folder import CachedImageFolder 18 | from .imagenet22k_dataset import IN22KDATASET 19 | from .samplers import SubsetRandomSampler 20 | 21 | try: 22 | from torchvision.transforms import InterpolationMode 23 | 24 | 25 | def _pil_interp(method): 26 | if method == 'bicubic': 27 | return InterpolationMode.BICUBIC 28 | elif method == 'lanczos': 29 | return InterpolationMode.LANCZOS 30 | elif method == 'hamming': 31 | return InterpolationMode.HAMMING 32 | else: 33 | # default bilinear, do we want to allow nearest? 
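# Methods other than the three mapped above (including 'nearest') fall
# through to bilinear. This shim only takes effect when torchvision exposes
# InterpolationMode; otherwise the except branch below re-imports timm's
# original integer-constant _pil_interp.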
34 | return InterpolationMode.BILINEAR 35 | 36 | 37 | import timm.data.transforms as timm_transforms 38 | 39 | timm_transforms._pil_interp = _pil_interp 40 | except: 41 | from timm.data.transforms import _pil_interp 42 | 43 | 44 | def build_loader(config): 45 | config.defrost() 46 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 47 | config.freeze() 48 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") 49 | dataset_val, _ = build_dataset(is_train=False, config=config) 50 | print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") 51 | 52 | num_tasks = dist.get_world_size() 53 | global_rank = dist.get_rank() 54 | if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': 55 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 56 | sampler_train = SubsetRandomSampler(indices) 57 | else: 58 | sampler_train = torch.utils.data.DistributedSampler( 59 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 60 | ) 61 | 62 | if config.TEST.SEQUENTIAL: 63 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 64 | else: 65 | sampler_val = torch.utils.data.distributed.DistributedSampler( 66 | dataset_val, shuffle=config.TEST.SHUFFLE 67 | ) 68 | 69 | data_loader_train = torch.utils.data.DataLoader( 70 | dataset_train, sampler=sampler_train, 71 | batch_size=config.DATA.BATCH_SIZE, 72 | num_workers=config.DATA.NUM_WORKERS, 73 | pin_memory=config.DATA.PIN_MEMORY, 74 | drop_last=True, 75 | ) 76 | 77 | data_loader_val = torch.utils.data.DataLoader( 78 | dataset_val, sampler=sampler_val, 79 | batch_size=config.DATA.BATCH_SIZE, 80 | shuffle=False, 81 | num_workers=config.DATA.NUM_WORKERS, 82 | pin_memory=config.DATA.PIN_MEMORY, 83 | drop_last=False 84 | ) 85 | 86 | # setup mixup / cutmix 87 | mixup_fn = None 88 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 89 | if mixup_active: 90 | mixup_fn = Mixup( 91 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 92 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 93 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 94 | 95 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 96 | 97 | 98 | def build_dataset(is_train, config): 99 | transform = build_transform(is_train, config) 100 | if config.DATA.DATASET == 'imagenet': 101 | prefix = 'train' if is_train else 'val' 102 | if config.DATA.ZIP_MODE: 103 | ann_file = prefix + "_map.txt" 104 | prefix = prefix + ".zip@/" 105 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 106 | cache_mode=config.DATA.CACHE_MODE if is_train else 'part') 107 | else: 108 | root = os.path.join(config.DATA.DATA_PATH, prefix) 109 | dataset = datasets.ImageFolder(root, transform=transform) 110 | nb_classes = 1000 111 | elif config.DATA.DATASET == 'imagenet22K': 112 | prefix = 'ILSVRC2011fall_whole' 113 | if is_train: 114 | ann_file = prefix + "_map_train.txt" 115 | else: 116 | ann_file = prefix + "_map_val.txt" 117 | dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform) 118 | nb_classes = 21841 119 | else: 120 | raise NotImplementedError("We only support ImageNet Now.") 121 | 122 | return dataset, nb_classes 123 | 124 | 125 | def build_transform(is_train, config): 126 | resize_im = config.DATA.IMG_SIZE > 32 127 | if is_train: 128 | # this should always dispatch to transforms_imagenet_train 129 | transform = create_transform( 130 | input_size=config.DATA.IMG_SIZE, 131 | is_training=True, 132 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 133 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 134 | re_prob=config.AUG.REPROB, 135 | re_mode=config.AUG.REMODE, 136 | re_count=config.AUG.RECOUNT, 137 | interpolation=config.DATA.INTERPOLATION, 138 | ) 139 | if not resize_im: 140 | # replace RandomResizedCropAndInterpolation with 141 | # RandomCrop 142 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 143 | return transform 144 | 145 | t = [] 146 | if resize_im: 147 | if config.TEST.CROP: 148 | size = int((256 / 224) * config.DATA.IMG_SIZE) 149 | t.append( 150 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 151 | # to maintain same ratio w.r.t. 
224 images 152 | ) 153 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 154 | else: 155 | t.append( 156 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 157 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 158 | ) 159 | 160 | t.append(transforms.ToTensor()) 161 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 162 | return transforms.Compose(t) 163 | -------------------------------------------------------------------------------- /data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .zipreader import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | return img.convert('RGB') 190 | 191 | 192 | def accimage_loader(path): 193 | import accimage 194 | try: 195 | return accimage.Image(path) 196 | except IOError: 197 | # Potentially a decoding problem, fall back to PIL.Image 198 | return pil_loader(path) 199 | 200 | 201 | def default_img_loader(path): 202 | from torchvision import get_image_backend 203 | if get_image_backend() == 'accimage': 204 | return accimage_loader(path) 205 | else: 206 | return pil_loader(path) 207 | 208 | 209 | class CachedImageFolder(DatasetFolder): 210 | """A generic data loader where the images are arranged in this way: :: 211 | root/dog/xxx.png 212 | root/dog/xxy.png 213 | root/dog/xxz.png 214 | root/cat/123.png 215 | root/cat/nsdf3.png 216 | root/cat/asd932_.png 217 | Args: 218 | root (string): Root directory path. 219 | transform (callable, optional): A function/transform that takes in an PIL image 220 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 221 | target_transform (callable, optional): A function/transform that takes in the 222 | target and transforms it. 223 | loader (callable, optional): A function to load an image given its path. 224 | Attributes: 225 | imgs (list): List of (image path, class_index) tuples 226 | """ 227 | 228 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 229 | loader=default_img_loader, cache_mode="no"): 230 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 231 | ann_file=ann_file, img_prefix=img_prefix, 232 | transform=transform, target_transform=target_transform, 233 | cache_mode=cache_mode) 234 | self.imgs = self.samples 235 | 236 | def __getitem__(self, index): 237 | """ 238 | Args: 239 | index (int): Index 240 | Returns: 241 | tuple: (image, target) where target is class_index of the target class. 
242 | """ 243 | path, target = self.samples[index] 244 | image = self.loader(path) 245 | if self.transform is not None: 246 | img = self.transform(image) 247 | else: 248 | img = image 249 | if self.target_transform is not None: 250 | target = self.target_transform(target) 251 | 252 | return img, target 253 | -------------------------------------------------------------------------------- /data/data_simmim_ft.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch.distributed as dist 10 | from torch.utils.data import DataLoader, DistributedSampler 11 | from torchvision import datasets, transforms 12 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 13 | from timm.data import Mixup 14 | from timm.data import create_transform 15 | from timm.data.transforms import _pil_interp 16 | 17 | 18 | def build_loader_finetune(config): 19 | config.defrost() 20 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) 21 | config.freeze() 22 | dataset_val, _ = build_dataset(is_train=False, config=config) 23 | 24 | num_tasks = dist.get_world_size() 25 | global_rank = dist.get_rank() 26 | sampler_train = DistributedSampler( 27 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 28 | ) 29 | sampler_val = DistributedSampler( 30 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False 31 | ) 32 | 33 | data_loader_train = DataLoader( 34 | dataset_train, sampler=sampler_train, 35 | batch_size=config.DATA.BATCH_SIZE, 36 | num_workers=config.DATA.NUM_WORKERS, 37 | pin_memory=config.DATA.PIN_MEMORY, 38 | drop_last=True, 39 | ) 40 | 41 | data_loader_val = DataLoader( 42 | dataset_val, sampler=sampler_val, 43 | batch_size=config.DATA.BATCH_SIZE, 44 | num_workers=config.DATA.NUM_WORKERS, 45 | pin_memory=config.DATA.PIN_MEMORY, 46 | drop_last=False, 47 | ) 48 | 49 | # setup mixup / cutmix 50 | mixup_fn = None 51 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 52 | if mixup_active: 53 | mixup_fn = Mixup( 54 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 55 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 56 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 57 | 58 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 59 | 60 | 61 | def build_dataset(is_train, config): 62 | transform = build_transform(is_train, config) 63 | 64 | if config.DATA.DATASET == 'imagenet': 65 | prefix = 'train' if is_train else 'val' 66 | root = os.path.join(config.DATA.DATA_PATH, prefix) 67 | dataset = datasets.ImageFolder(root, transform=transform) 68 | nb_classes = 1000 69 | else: 70 | raise NotImplementedError("We only support ImageNet Now.") 71 | 72 | return dataset, nb_classes 73 | 74 | 75 | def build_transform(is_train, config): 76 | resize_im = config.DATA.IMG_SIZE > 32 77 | if is_train: 78 | # this should always dispatch to transforms_imagenet_train 79 | transform = create_transform( 80 | input_size=config.DATA.IMG_SIZE, 81 | is_training=True, 82 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 83 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 84 | re_prob=config.AUG.REPROB, 85 | re_mode=config.AUG.REMODE, 86 | re_count=config.AUG.RECOUNT, 87 | interpolation=config.DATA.INTERPOLATION, 88 | ) 89 | if not resize_im: 90 | # replace RandomResizedCropAndInterpolation with 91 | # RandomCrop 92 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 93 | return transform 94 | 95 | t = [] 96 | if resize_im: 97 | if config.TEST.CROP: 98 | size = int((256 / 224) * config.DATA.IMG_SIZE) 99 | t.append( 100 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 101 | # to maintain same ratio w.r.t. 
224 images 102 | ) 103 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 104 | else: 105 | t.append( 106 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 107 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 108 | ) 109 | 110 | t.append(transforms.ToTensor()) 111 | t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) 112 | return transforms.Compose(t) 113 | -------------------------------------------------------------------------------- /data/data_simmim_pt.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import math 9 | import random 10 | import numpy as np 11 | 12 | import torch 13 | import torch.distributed as dist 14 | import torchvision.transforms as T 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from torch.utils.data._utils.collate import default_collate 17 | from torchvision.datasets import ImageFolder 18 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 19 | 20 | 21 | class MaskGenerator: 22 | def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=4, mask_ratio=0.6): 23 | self.input_size = input_size 24 | self.mask_patch_size = mask_patch_size 25 | self.model_patch_size = model_patch_size 26 | self.mask_ratio = mask_ratio 27 | 28 | assert self.input_size % self.mask_patch_size == 0 29 | assert self.mask_patch_size % self.model_patch_size == 0 30 | 31 | self.rand_size = self.input_size // self.mask_patch_size 32 | self.scale = self.mask_patch_size // self.model_patch_size 33 | 34 | self.token_count = self.rand_size ** 2 35 | self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) 36 | 37 | def __call__(self): 38 | mask_idx = np.random.permutation(self.token_count)[:self.mask_count] 39 | mask = np.zeros(self.token_count, dtype=int) 40 | mask[mask_idx] = 1 41 | 42 | mask = mask.reshape((self.rand_size, self.rand_size)) 43 | mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) 44 | 45 | return mask 46 | 47 | 48 | class SimMIMTransform: 49 | def __init__(self, config): 50 | self.transform_img = T.Compose([ 51 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 52 | T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(0.67, 1.), ratio=(3. / 4., 4. 
/ 3.)), 53 | T.RandomHorizontalFlip(), 54 | T.ToTensor(), 55 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), 56 | ]) 57 | 58 | if config.MODEL.TYPE in ['swin', 'swinv2']: 59 | model_patch_size=config.MODEL.SWIN.PATCH_SIZE 60 | else: 61 | raise NotImplementedError 62 | 63 | self.mask_generator = MaskGenerator( 64 | input_size=config.DATA.IMG_SIZE, 65 | mask_patch_size=config.DATA.MASK_PATCH_SIZE, 66 | model_patch_size=model_patch_size, 67 | mask_ratio=config.DATA.MASK_RATIO, 68 | ) 69 | 70 | def __call__(self, img): 71 | img = self.transform_img(img) 72 | mask = self.mask_generator() 73 | 74 | return img, mask 75 | 76 | 77 | def collate_fn(batch): 78 | if not isinstance(batch[0][0], tuple): 79 | return default_collate(batch) 80 | else: 81 | batch_num = len(batch) 82 | ret = [] 83 | for item_idx in range(len(batch[0][0])): 84 | if batch[0][0][item_idx] is None: 85 | ret.append(None) 86 | else: 87 | ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)])) 88 | ret.append(default_collate([batch[i][1] for i in range(batch_num)])) 89 | return ret 90 | 91 | 92 | def build_loader_simmim(config): 93 | transform = SimMIMTransform(config) 94 | dataset = ImageFolder(config.DATA.DATA_PATH, transform) 95 | 96 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) 97 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn) 98 | 99 | return dataloader -------------------------------------------------------------------------------- /data/imagenet22k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch.utils.data as data 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import warnings 8 | 9 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 10 | 11 | 12 | class IN22KDATASET(data.Dataset): 13 | def __init__(self, root, ann_file='', transform=None, target_transform=None): 14 | super(IN22KDATASET, self).__init__() 15 | 16 | self.data_path = root 17 | self.ann_path = os.path.join(self.data_path, ann_file) 18 | self.transform = transform 19 | self.target_transform = target_transform 20 | # id & label: https://github.com/google-research/big_transfer/issues/7 21 | # total: 21843; only 21841 class have images: map 21841->9205; 21842->15027 22 | self.database = json.load(open(self.ann_path)) 23 | 24 | def _load_image(self, path): 25 | try: 26 | im = Image.open(path) 27 | except: 28 | print("ERROR IMG LOADED: ", path) 29 | random_img = np.random.rand(224, 224, 3) * 255 30 | im = Image.fromarray(np.uint8(random_img)) 31 | return im 32 | 33 | def __getitem__(self, index): 34 | """ 35 | Args: 36 | index (int): Index 37 | Returns: 38 | tuple: (image, target) where target is class_index of the target class. 
39 | """ 40 | idb = self.database[index] 41 | 42 | # images 43 | images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB') 44 | if self.transform is not None: 45 | images = self.transform(images) 46 | 47 | # target 48 | target = int(idb[1]) 49 | if self.target_transform is not None: 50 | target = self.target_transform(target) 51 | 52 | return images, target 53 | 54 | def __len__(self): 55 | return len(self.database) 56 | -------------------------------------------------------------------------------- /data/map22kto1k.txt: -------------------------------------------------------------------------------- 1 | 359 2 | 368 3 | 460 4 | 475 5 | 486 6 | 492 7 | 496 8 | 514 9 | 516 10 | 525 11 | 547 12 | 548 13 | 556 14 | 563 15 | 575 16 | 641 17 | 648 18 | 723 19 | 733 20 | 765 21 | 801 22 | 826 23 | 852 24 | 858 25 | 878 26 | 896 27 | 900 28 | 905 29 | 908 30 | 910 31 | 935 32 | 946 33 | 947 34 | 994 35 | 999 36 | 1003 37 | 1005 38 | 1010 39 | 1027 40 | 1029 41 | 1048 42 | 1055 43 | 1064 44 | 1065 45 | 1069 46 | 1075 47 | 1079 48 | 1081 49 | 1085 50 | 1088 51 | 1093 52 | 1106 53 | 1143 54 | 1144 55 | 1145 56 | 1147 57 | 1168 58 | 1171 59 | 1178 60 | 1187 61 | 1190 62 | 1197 63 | 1205 64 | 1216 65 | 1223 66 | 1230 67 | 1236 68 | 1241 69 | 1245 70 | 1257 71 | 1259 72 | 1260 73 | 1267 74 | 1268 75 | 1269 76 | 1271 77 | 1272 78 | 1273 79 | 1277 80 | 1303 81 | 1344 82 | 1349 83 | 1355 84 | 1357 85 | 1384 86 | 1388 87 | 1391 88 | 1427 89 | 1429 90 | 1432 91 | 1437 92 | 1450 93 | 1461 94 | 1462 95 | 1474 96 | 1502 97 | 1503 98 | 1512 99 | 1552 100 | 1555 101 | 1577 102 | 1584 103 | 1587 104 | 1589 105 | 1599 106 | 1615 107 | 1616 108 | 1681 109 | 1692 110 | 1701 111 | 1716 112 | 1729 113 | 1757 114 | 1759 115 | 1764 116 | 1777 117 | 1786 118 | 1822 119 | 1841 120 | 1842 121 | 1848 122 | 1850 123 | 1856 124 | 1860 125 | 1861 126 | 1864 127 | 1876 128 | 1897 129 | 1898 130 | 1910 131 | 1913 132 | 1918 133 | 1922 134 | 1928 135 | 1932 136 | 1935 137 | 1947 138 | 1951 139 | 1953 140 | 1970 141 | 1977 142 | 1979 143 | 2001 144 | 2017 145 | 2067 146 | 2081 147 | 2087 148 | 2112 149 | 2128 150 | 2135 151 | 2147 152 | 2174 153 | 2175 154 | 2176 155 | 2177 156 | 2178 157 | 2181 158 | 2183 159 | 2184 160 | 2187 161 | 2189 162 | 2190 163 | 2191 164 | 2192 165 | 2193 166 | 2197 167 | 2202 168 | 2203 169 | 2206 170 | 2208 171 | 2209 172 | 2211 173 | 2212 174 | 2213 175 | 2214 176 | 2215 177 | 2216 178 | 2217 179 | 2219 180 | 2222 181 | 2223 182 | 2224 183 | 2225 184 | 2226 185 | 2227 186 | 2228 187 | 2229 188 | 2230 189 | 2236 190 | 2238 191 | 2240 192 | 2241 193 | 2242 194 | 2243 195 | 2244 196 | 2245 197 | 2247 198 | 2248 199 | 2249 200 | 2250 201 | 2251 202 | 2252 203 | 2255 204 | 2256 205 | 2257 206 | 2262 207 | 2263 208 | 2264 209 | 2265 210 | 2266 211 | 2268 212 | 2270 213 | 2271 214 | 2272 215 | 2273 216 | 2275 217 | 2276 218 | 2279 219 | 2280 220 | 2281 221 | 2282 222 | 2285 223 | 2289 224 | 2292 225 | 2295 226 | 2296 227 | 2297 228 | 2298 229 | 2299 230 | 2300 231 | 2301 232 | 2302 233 | 2303 234 | 2304 235 | 2305 236 | 2306 237 | 2309 238 | 2310 239 | 2312 240 | 2313 241 | 2314 242 | 2315 243 | 2316 244 | 2318 245 | 2319 246 | 2321 247 | 2322 248 | 2326 249 | 2329 250 | 2330 251 | 2331 252 | 2332 253 | 2334 254 | 2335 255 | 2336 256 | 2337 257 | 2338 258 | 2339 259 | 2341 260 | 2342 261 | 2343 262 | 2344 263 | 2346 264 | 2348 265 | 2349 266 | 2351 267 | 2352 268 | 2353 269 | 2355 270 | 2357 271 | 2358 272 | 2359 273 | 2360 274 | 2364 275 | 2365 276 | 2368 277 | 2369 278 | 2377 279 | 2382 280 | 
2383 281 | 2385 282 | 2397 283 | 2398 284 | 2400 285 | 2402 286 | 2405 287 | 2412 288 | 2421 289 | 2428 290 | 2431 291 | 2432 292 | 2433 293 | 2436 294 | 2441 295 | 2445 296 | 2450 297 | 2453 298 | 2454 299 | 2465 300 | 2469 301 | 2532 302 | 2533 303 | 2538 304 | 2544 305 | 2547 306 | 2557 307 | 2565 308 | 2578 309 | 2612 310 | 2658 311 | 2702 312 | 2722 313 | 2731 314 | 2738 315 | 2741 316 | 2747 317 | 2810 318 | 2818 319 | 2833 320 | 2844 321 | 2845 322 | 2867 323 | 2874 324 | 2882 325 | 2884 326 | 2888 327 | 2889 328 | 3008 329 | 3012 330 | 3019 331 | 3029 332 | 3033 333 | 3042 334 | 3091 335 | 3106 336 | 3138 337 | 3159 338 | 3164 339 | 3169 340 | 3280 341 | 3296 342 | 3311 343 | 3318 344 | 3320 345 | 3324 346 | 3330 347 | 3366 348 | 3375 349 | 3381 350 | 3406 351 | 3419 352 | 3432 353 | 3434 354 | 3435 355 | 3493 356 | 3495 357 | 3503 358 | 3509 359 | 3511 360 | 3513 361 | 3517 362 | 3521 363 | 3526 364 | 3546 365 | 3554 366 | 3600 367 | 3601 368 | 3606 369 | 3612 370 | 3613 371 | 3616 372 | 3622 373 | 3623 374 | 3627 375 | 3632 376 | 3634 377 | 3636 378 | 3638 379 | 3644 380 | 3646 381 | 3649 382 | 3650 383 | 3651 384 | 3656 385 | 3663 386 | 3673 387 | 3674 388 | 3689 389 | 3690 390 | 3702 391 | 3733 392 | 3769 393 | 3971 394 | 3974 395 | 4065 396 | 4068 397 | 4073 398 | 4102 399 | 4136 400 | 4140 401 | 4151 402 | 4159 403 | 4165 404 | 4207 405 | 4219 406 | 4226 407 | 4249 408 | 4256 409 | 4263 410 | 4270 411 | 4313 412 | 4321 413 | 4378 414 | 4386 415 | 4478 416 | 4508 417 | 4512 418 | 4536 419 | 4542 420 | 4550 421 | 4560 422 | 4562 423 | 4570 424 | 4571 425 | 4572 426 | 4583 427 | 4588 428 | 4594 429 | 4604 430 | 4608 431 | 4623 432 | 4634 433 | 4636 434 | 4646 435 | 4651 436 | 4652 437 | 4686 438 | 4688 439 | 4691 440 | 4699 441 | 4724 442 | 4727 443 | 4737 444 | 4770 445 | 4774 446 | 4789 447 | 4802 448 | 4807 449 | 4819 450 | 4880 451 | 4886 452 | 4908 453 | 4927 454 | 4931 455 | 4936 456 | 4964 457 | 4976 458 | 4993 459 | 5028 460 | 5033 461 | 5043 462 | 5046 463 | 5096 464 | 5111 465 | 5114 466 | 5131 467 | 5132 468 | 5183 469 | 5199 470 | 5235 471 | 5275 472 | 5291 473 | 5293 474 | 5294 475 | 5343 476 | 5360 477 | 5362 478 | 5364 479 | 5390 480 | 5402 481 | 5418 482 | 5428 483 | 5430 484 | 5437 485 | 5443 486 | 5473 487 | 5484 488 | 5486 489 | 5505 490 | 5507 491 | 5508 492 | 5510 493 | 5567 494 | 5578 495 | 5580 496 | 5584 497 | 5606 498 | 5613 499 | 5629 500 | 5672 501 | 5676 502 | 5692 503 | 5701 504 | 5760 505 | 5769 506 | 5770 507 | 5779 508 | 5814 509 | 5850 510 | 5871 511 | 5893 512 | 5911 513 | 5949 514 | 5954 515 | 6005 516 | 6006 517 | 6012 518 | 6017 519 | 6023 520 | 6024 521 | 6040 522 | 6050 523 | 6054 524 | 6087 525 | 6105 526 | 6157 527 | 6235 528 | 6237 529 | 6256 530 | 6259 531 | 6286 532 | 6291 533 | 6306 534 | 6339 535 | 6341 536 | 6343 537 | 6379 538 | 6383 539 | 6393 540 | 6405 541 | 6479 542 | 6511 543 | 6517 544 | 6541 545 | 6561 546 | 6608 547 | 6611 548 | 6615 549 | 6678 550 | 6682 551 | 6707 552 | 6752 553 | 6798 554 | 6850 555 | 6880 556 | 6885 557 | 6890 558 | 6920 559 | 6981 560 | 7000 561 | 7009 562 | 7038 563 | 7049 564 | 7050 565 | 7052 566 | 7073 567 | 7078 568 | 7098 569 | 7111 570 | 7165 571 | 7198 572 | 7204 573 | 7280 574 | 7283 575 | 7286 576 | 7287 577 | 7293 578 | 7294 579 | 7305 580 | 7318 581 | 7341 582 | 7346 583 | 7354 584 | 7382 585 | 7427 586 | 7428 587 | 7435 588 | 7445 589 | 7450 590 | 7455 591 | 7467 592 | 7469 593 | 7497 594 | 7502 595 | 7506 596 | 7514 597 | 7523 598 | 7651 599 | 7661 600 | 7664 601 | 7672 602 | 7679 603 | 
7685 604 | 7696 605 | 7730 606 | 7871 607 | 7873 608 | 7895 609 | 7914 610 | 7915 611 | 7920 612 | 7934 613 | 7935 614 | 7949 615 | 8009 616 | 8036 617 | 8051 618 | 8065 619 | 8074 620 | 8090 621 | 8112 622 | 8140 623 | 8164 624 | 8168 625 | 8178 626 | 8182 627 | 8198 628 | 8212 629 | 8216 630 | 8230 631 | 8242 632 | 8288 633 | 8289 634 | 8295 635 | 8318 636 | 8352 637 | 8368 638 | 8371 639 | 8375 640 | 8376 641 | 8401 642 | 8416 643 | 8419 644 | 8436 645 | 8460 646 | 8477 647 | 8478 648 | 8482 649 | 8498 650 | 8500 651 | 8539 652 | 8543 653 | 8552 654 | 8555 655 | 8580 656 | 8584 657 | 8586 658 | 8594 659 | 8598 660 | 8601 661 | 8606 662 | 8610 663 | 8611 664 | 8622 665 | 8627 666 | 8639 667 | 8649 668 | 8650 669 | 8653 670 | 8654 671 | 8667 672 | 8672 673 | 8673 674 | 8674 675 | 8676 676 | 8684 677 | 8720 678 | 8723 679 | 8750 680 | 8753 681 | 8801 682 | 8815 683 | 8831 684 | 8835 685 | 8842 686 | 8845 687 | 8858 688 | 8897 689 | 8916 690 | 8951 691 | 8954 692 | 8959 693 | 8970 694 | 8976 695 | 8981 696 | 8983 697 | 8989 698 | 8991 699 | 8993 700 | 9019 701 | 9039 702 | 9042 703 | 9043 704 | 9056 705 | 9057 706 | 9070 707 | 9087 708 | 9098 709 | 9106 710 | 9130 711 | 9131 712 | 9155 713 | 9171 714 | 9183 715 | 9198 716 | 9199 717 | 9201 718 | 9204 719 | 9212 720 | 9221 721 | 9225 722 | 9229 723 | 9250 724 | 9260 725 | 9271 726 | 9279 727 | 9295 728 | 9300 729 | 9310 730 | 9322 731 | 9345 732 | 9352 733 | 9376 734 | 9377 735 | 9382 736 | 9392 737 | 9401 738 | 9405 739 | 9441 740 | 9449 741 | 9464 742 | 9475 743 | 9502 744 | 9505 745 | 9514 746 | 9515 747 | 9545 748 | 9567 749 | 9576 750 | 9608 751 | 9609 752 | 9624 753 | 9633 754 | 9639 755 | 9643 756 | 9656 757 | 9674 758 | 9740 759 | 9752 760 | 9760 761 | 9767 762 | 9778 763 | 9802 764 | 9820 765 | 9839 766 | 9879 767 | 9924 768 | 9956 769 | 9961 770 | 9963 771 | 9970 772 | 9997 773 | 10010 774 | 10031 775 | 10040 776 | 10052 777 | 10073 778 | 10075 779 | 10078 780 | 10094 781 | 10097 782 | 10109 783 | 10118 784 | 10121 785 | 10124 786 | 10158 787 | 10226 788 | 10276 789 | 10304 790 | 10307 791 | 10314 792 | 10315 793 | 10332 794 | 10337 795 | 10338 796 | 10413 797 | 10423 798 | 10451 799 | 10463 800 | 10465 801 | 10487 802 | 10519 803 | 10522 804 | 10523 805 | 10532 806 | 10534 807 | 10535 808 | 10551 809 | 10559 810 | 10574 811 | 10583 812 | 10586 813 | 10589 814 | 10612 815 | 10626 816 | 10635 817 | 10638 818 | 10677 819 | 10683 820 | 10726 821 | 10776 822 | 10782 823 | 10783 824 | 10807 825 | 10837 826 | 10840 827 | 10848 828 | 10859 829 | 10871 830 | 10881 831 | 10884 832 | 10908 833 | 10914 834 | 10921 835 | 10936 836 | 10947 837 | 10951 838 | 10952 839 | 10957 840 | 10999 841 | 11003 842 | 11018 843 | 11023 844 | 11025 845 | 11027 846 | 11045 847 | 11055 848 | 11095 849 | 11110 850 | 11137 851 | 5564 852 | 11168 853 | 11186 854 | 11221 855 | 11223 856 | 11242 857 | 11255 858 | 11259 859 | 11279 860 | 11306 861 | 11311 862 | 11331 863 | 11367 864 | 11377 865 | 11389 866 | 11392 867 | 11401 868 | 11407 869 | 11437 870 | 11449 871 | 11466 872 | 11469 873 | 11473 874 | 11478 875 | 11483 876 | 11484 877 | 11507 878 | 11536 879 | 11558 880 | 11566 881 | 11575 882 | 11584 883 | 11594 884 | 11611 885 | 11612 886 | 11619 887 | 11621 888 | 11640 889 | 11643 890 | 11664 891 | 11674 892 | 11689 893 | 11709 894 | 11710 895 | 11716 896 | 11721 897 | 11726 898 | 11729 899 | 11743 900 | 11760 901 | 11771 902 | 11837 903 | 11839 904 | 11856 905 | 11876 906 | 11878 907 | 11884 908 | 11889 909 | 11896 910 | 11917 911 | 11923 912 | 11930 913 | 11944 
914 | 11952 915 | 11980 916 | 11984 917 | 12214 918 | 12229 919 | 12239 920 | 12241 921 | 12242 922 | 12247 923 | 12283 924 | 12349 925 | 12369 926 | 12373 927 | 12422 928 | 12560 929 | 12566 930 | 12575 931 | 12688 932 | 12755 933 | 12768 934 | 12778 935 | 12780 936 | 12812 937 | 12832 938 | 12835 939 | 12836 940 | 12843 941 | 12847 942 | 12849 943 | 12850 944 | 12856 945 | 12858 946 | 12873 947 | 12938 948 | 12971 949 | 13017 950 | 13038 951 | 13046 952 | 13059 953 | 13085 954 | 13086 955 | 13088 956 | 13094 957 | 13134 958 | 13182 959 | 13230 960 | 13406 961 | 13444 962 | 13614 963 | 13690 964 | 13698 965 | 13709 966 | 13749 967 | 13804 968 | 13982 969 | 14051 970 | 14059 971 | 14219 972 | 14246 973 | 14256 974 | 14264 975 | 14294 976 | 14324 977 | 14367 978 | 14389 979 | 14394 980 | 14438 981 | 14442 982 | 14965 983 | 15732 984 | 16744 985 | 18037 986 | 18205 987 | 18535 988 | 18792 989 | 19102 990 | 20019 991 | 20462 992 | 21026 993 | 21045 994 | 21163 995 | 21171 996 | 21181 997 | 21196 998 | 21200 999 | 21369 1000 | 21817 -------------------------------------------------------------------------------- /data/samplers.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | 11 | class SubsetRandomSampler(torch.utils.data.Sampler): 12 | r"""Samples elements randomly from a given list of indices, without replacement. 13 | 14 | Arguments: 15 | indices (sequence): a sequence of indices 16 | """ 17 | 18 | def __init__(self, indices): 19 | self.epoch = 0 20 | self.indices = indices 21 | 22 | def __iter__(self): 23 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 24 | 25 | def __len__(self): 26 | return len(self.indices) 27 | 28 | def set_epoch(self, epoch): 29 | self.epoch = epoch 30 | -------------------------------------------------------------------------------- /data/zipreader.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import zipfile 10 | import io 11 | import numpy as np 12 | from PIL import Image 13 | from PIL import ImageFile 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | def is_zip_path(img_or_path): 19 | """judge if this is a zip path""" 20 | return '.zip@' in img_or_path 21 | 22 | 23 | class ZipReader(object): 24 | """A class to read zipped files""" 25 | zip_bank = dict() 26 | 27 | def __init__(self): 28 | super(ZipReader, self).__init__() 29 | 30 | @staticmethod 31 | def get_zipfile(path): 32 | zip_bank = ZipReader.zip_bank 33 | if path not in zip_bank: 34 | zfile = zipfile.ZipFile(path, 'r') 35 | zip_bank[path] = zfile 36 | return zip_bank[path] 37 | 38 | @staticmethod 39 | def split_zip_style_path(path): 40 | pos_at = path.index('@') 41 | assert pos_at != -1, "character '@' is not found from the given path '%s'" % path 42 | 43 | zip_path = path[0: pos_at] 44 | folder_path = path[pos_at + 1:] 45 | folder_path = str.strip(folder_path, '/') 46 | return zip_path, folder_path 47 | 48 | 
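# Note on the path convention handled here (this repo's zipped-ImageNet layout,
# not a new API): paths look like '<archive>.zip@/<member>', e.g.
# 'train.zip@/n01440764/n01440764_10026.JPEG'. split_zip_style_path() above
# splits on the first '@' into the zip archive path and the member path inside
# it; the methods below (list_folder/list_files/read/imread) operate on members.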
@staticmethod 49 | def list_folder(path): 50 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 51 | 52 | zfile = ZipReader.get_zipfile(zip_path) 53 | folder_list = [] 54 | for file_folder_name in zfile.namelist(): 55 | file_folder_name = str.strip(file_folder_name, '/') 56 | if file_folder_name.startswith(folder_path) and \ 57 | len(os.path.splitext(file_folder_name)[-1]) == 0 and \ 58 | file_folder_name != folder_path: 59 | if len(folder_path) == 0: 60 | folder_list.append(file_folder_name) 61 | else: 62 | folder_list.append(file_folder_name[len(folder_path) + 1:]) 63 | 64 | return folder_list 65 | 66 | @staticmethod 67 | def list_files(path, extension=None): 68 | if extension is None: 69 | extension = ['.*'] 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_folder_name in zfile.namelist(): 75 | file_folder_name = str.strip(file_folder_name, '/') 76 | if file_folder_name.startswith(folder_path) and \ 77 | str.lower(os.path.splitext(file_folder_name)[-1]) in extension: 78 | if len(folder_path) == 0: 79 | file_lists.append(file_folder_name) 80 | else: 81 | file_lists.append(file_folder_name[len(folder_path) + 1:]) 82 | 83 | return file_lists 84 | 85 | @staticmethod 86 | def read(path): 87 | zip_path, path_img = ZipReader.split_zip_style_path(path) 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | data = zfile.read(path_img) 90 | return data 91 | 92 | @staticmethod 93 | def imread(path): 94 | zip_path, path_img = ZipReader.split_zip_style_path(path) 95 | zfile = ZipReader.get_zipfile(zip_path) 96 | data = zfile.read(path_img) 97 | try: 98 | im = Image.open(io.BytesIO(data)) 99 | except Exception: 100 | print("ERROR IMG LOADED: ", path_img) 101 | random_img = np.random.rand(224, 224, 3) * 255 102 | im = Image.fromarray(np.uint8(random_img)) 103 | return im 104 | -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/Swin-Transformer/f82860bfb5225915aca09c3227159ee9e1df874d/figures/teaser.png -------------------------------------------------------------------------------- /get_started.md: -------------------------------------------------------------------------------- 1 | # Swin Transformer for Image Classification 2 | 3 | This folder contains the implementation of the Swin Transformer for image classification. 4 | 5 | ## Model Zoo 6 | 7 | Please refer to [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models) for more pre-trained models. 8 | 9 | ## Usage 10 | 11 | ### Install 12 | 13 | We recommend using the PyTorch docker image `nvcr>=21.05` from 14 | NVIDIA: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
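Once the steps below are done, a quick environment sanity check helps before launching training (a minimal sketch; the exact versions printed depend on your setup):

```python
# Sketch: verify that PyTorch, CUDA, and cuDNN are visible.
import torch

print(torch.__version__)               # expect >= 1.8.0
print(torch.cuda.is_available())       # expect True on a GPU machine
print(torch.backends.cudnn.version())  # expect a cuDNN >= 7 build
```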
15 | 16 | - Clone this repo: 17 | 18 | ```bash 19 | git clone https://github.com/microsoft/Swin-Transformer.git 20 | cd Swin-Transformer 21 | ``` 22 | 23 | - Create a conda virtual environment and activate it: 24 | 25 | ```bash 26 | conda create -n swin python=3.7 -y 27 | conda activate swin 28 | ``` 29 | 30 | - Install `CUDA>=10.2` with `cudnn>=7` following 31 | the [official installation instructions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html). 32 | - Install `PyTorch>=1.8.0` and `torchvision>=0.9.0` with `CUDA>=10.2`: 33 | 34 | ```bash 35 | conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch 36 | ``` 37 | 38 | - Install `timm==0.4.12`: 39 | 40 | ```bash 41 | pip install timm==0.4.12 42 | ``` 43 | 44 | - Install other requirements: 45 | 46 | ```bash 47 | pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy 48 | ``` 49 | 50 | - Install the fused window process kernel for acceleration, activated by passing `--fused_window_process` in the running script: 51 | ```bash 52 | cd kernels/window_process 53 | python setup.py install #--user 54 | ``` 55 | 56 | ### Data preparation 57 | 58 | We use the standard ImageNet dataset, which you can download from http://image-net.org/. We provide the following two ways to 59 | load data: 60 | 61 | - For the standard folder dataset, move validation images into labeled sub-folders. The file structure should look like: 62 | ```bash 63 | $ tree data 64 | imagenet 65 | ├── train 66 | │ ├── class1 67 | │ │ ├── img1.jpeg 68 | │ │ ├── img2.jpeg 69 | │ │ └── ... 70 | │ ├── class2 71 | │ │ ├── img3.jpeg 72 | │ │ └── ... 73 | │ └── ... 74 | └── val 75 | ├── class1 76 | │ ├── img4.jpeg 77 | │ ├── img5.jpeg 78 | │ └── ... 79 | ├── class2 80 | │ ├── img6.jpeg 81 | │ └── ... 82 | └── ... 83 | 84 | ``` 85 | - To avoid the slow reads that come with massive numbers of small files, we also support zipped ImageNet, which includes 86 | four files: 87 | - `train.zip`, `val.zip`: which store the zipped folders for the train and validation splits. 88 | - `train_map.txt`, `val_map.txt`: which store the relative paths in the corresponding zip file and the ground-truth 89 | labels. Make sure the data folder looks like this: 90 | 91 | ```bash 92 | $ tree data 93 | data 94 | └── ImageNet-Zip 95 | ├── train_map.txt 96 | ├── train.zip 97 | ├── val_map.txt 98 | └── val.zip 99 | 100 | $ head -n 5 data/ImageNet-Zip/val_map.txt 101 | ILSVRC2012_val_00000001.JPEG 65 102 | ILSVRC2012_val_00000002.JPEG 970 103 | ILSVRC2012_val_00000003.JPEG 230 104 | ILSVRC2012_val_00000004.JPEG 809 105 | ILSVRC2012_val_00000005.JPEG 516 106 | 107 | $ head -n 5 data/ImageNet-Zip/train_map.txt 108 | n01440764/n01440764_10026.JPEG 0 109 | n01440764/n01440764_10027.JPEG 0 110 | n01440764/n01440764_10029.JPEG 0 111 | n01440764/n01440764_10040.JPEG 0 112 | n01440764/n01440764_10042.JPEG 0 113 | ``` 114 | - For the ImageNet-22K dataset, make a folder named `fall11_whole` and move all images to labeled sub-folders in this 115 | folder. Then download the train-val split 116 | files ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt) 117 | & [ILSVRC2011fall_whole_map_val.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_val.txt)) 118 | , and put them in the parent directory of `fall11_whole`.
The file structure should look like: 119 | 120 | ```bash 121 | $ tree imagenet22k/ 122 | imagenet22k/ 123 | ├── ILSVRC2011fall_whole_map_train.txt 124 | ├── ILSVRC2011fall_whole_map_val.txt 125 | └── fall11_whole 126 | ├── n00004475 127 | ├── n00005787 128 | ├── n00006024 129 | ├── n00006484 130 | └── ... 131 | ``` 132 | 133 | ### Evaluation 134 | 135 | To evaluate a pre-trained `Swin Transformer` on ImageNet val, run: 136 | 137 | ```bash 138 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 12345 main.py --eval \ 139 | --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 140 | ``` 141 | 142 | For example, to evaluate `Swin-B` with a single GPU: 143 | 144 | ```bash 145 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \ 146 | --cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path> 147 | ``` 148 | 149 | ### Training from scratch on ImageNet-1K 150 | 151 | To train a `Swin Transformer` on ImageNet from scratch, run: 152 | 153 | ```bash 154 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> --master_port 12345 main.py \ 155 | --cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 156 | ``` 157 | 158 | **Notes**: 159 | 160 | - To use zipped ImageNet instead of the folder dataset, add `--zip` to the parameters. 161 | - To cache the dataset in memory instead of reading from files every time, add `--cache-mode part`, which will 162 | shard the dataset into non-overlapping pieces for different GPUs and only load the corresponding one for each GPU. 163 | - If GPU memory is insufficient, you can try the following suggestions: 164 | - Use gradient accumulation by adding `--accumulation-steps <steps>`, setting an appropriate `<steps>` value according to your needs. 165 | - Use gradient checkpointing by adding `--use-checkpoint`; e.g., it saves about 60% memory when training `Swin-B`. 166 | Please refer to [this page](https://pytorch.org/docs/stable/checkpoint.html) for more details. 167 | - We recommend using multiple nodes with more GPUs to train very large models; a tutorial can be found 168 | on [this page](https://pytorch.org/tutorials/intermediate/dist_tuto.html). 169 | - To change config options in general, you can use `--opts KEY1 VALUE1 KEY2 VALUE2`, e.g., 170 | `--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5` will change total epochs to 100 and warm-up epochs to 5. 171 | - For additional options, see [config](config.py) and run `python main.py --help` for detailed messages.
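For reference, the `--opts KEY VALUE` overrides are applied to the yacs config as a flat list; a minimal sketch of the mechanism follows (the real logic lives in [config.py](config.py), and the field values here are illustrative defaults):

```python
# Sketch: how '--opts TRAIN.EPOCHS 100 TRAIN.WARMUP_EPOCHS 5' is merged.
from yacs.config import CfgNode as CN

config = CN()
config.TRAIN = CN()
config.TRAIN.EPOCHS = 300          # default from the base config
config.TRAIN.WARMUP_EPOCHS = 20

config.defrost()                   # the config is frozen after creation
config.merge_from_list(['TRAIN.EPOCHS', '100', 'TRAIN.WARMUP_EPOCHS', '5'])
config.freeze()

print(config.TRAIN.EPOCHS, config.TRAIN.WARMUP_EPOCHS)  # -> 100 5
```

yacs coerces each string value to the type of the existing field, so `'100'` becomes the integer `100`.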
172 | 173 | For example, to train `Swin Transformer` with 8 GPUs on a single node for 300 epochs, run: 174 | 175 | `Swin-T`: 176 | 177 | ```bash 178 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 179 | --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 180 | ``` 181 | 182 | `Swin-S`: 183 | 184 | ```bash 185 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 186 | --cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 187 | ``` 188 | 189 | `Swin-B`: 190 | 191 | ```bash 192 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 193 | --cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \ 194 | --accumulation-steps 2 [--use-checkpoint] 195 | ``` 196 | 197 | ### Pre-training on ImageNet-22K 198 | 199 | For example, to pre-train a `Swin-B` model on ImageNet-22K: 200 | 201 | ```bash 202 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 203 | --cfg configs/swin/swin_base_patch4_window7_224_22k.yaml --data-path <imagenet22k-path> --batch-size 64 \ 204 | --accumulation-steps 8 [--use-checkpoint] 205 | ``` 206 | 207 | ### Fine-tuning on higher resolution 208 | 209 | For example, to fine-tune a `Swin-B` model pre-trained at 224x224 resolution to 384x384 resolution: 210 | 211 | ```bash 212 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 213 | --cfg configs/swin/swin_base_patch4_window12_384_finetune.yaml --pretrained swin_base_patch4_window7_224.pth \ 214 | --data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint] 215 | ``` 216 | 217 | ### Fine-tuning from an ImageNet-22K(21K) pre-trained model 218 | 219 | For example, to fine-tune a `Swin-B` model pre-trained on ImageNet-22K(21K): 220 | 221 | ```bash 222 | python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \ 223 | --cfg configs/swin/swin_base_patch4_window7_224_22kto1k_finetune.yaml --pretrained swin_base_patch4_window7_224_22k.pth \ 224 | --data-path <imagenet-path> --batch-size 64 --accumulation-steps 2 [--use-checkpoint] 225 | ``` 226 | 227 | ### Throughput 228 | 229 | To measure the throughput, run: 230 | 231 | ```bash 232 | python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py \ 233 | --cfg <config-file> --data-path <imagenet-path> --batch-size 64 --throughput --disable_amp 234 | ``` 235 | 236 | 237 | ## Mixture-of-Experts Support 238 | 239 | ### Install [Tutel](https://github.com/microsoft/tutel) 240 | ```bash 241 | python3 -m pip uninstall tutel -y 242 | python3 -m pip install --user --upgrade git+https://github.com/microsoft/tutel@main 243 | ``` 244 | 245 | ### Training Swin-MoE 246 | For example, to train a `Swin-MoE-S` model with 32 experts on ImageNet-22K with 32 GPUs (4 nodes): 247 | 248 | ```bash 249 | python -m torch.distributed.launch --nproc_per_node 8 --nnode=4 \ 250 | --node_rank=<node-rank> --master_addr=<master-addr> --master_port 12345 main_moe.py \ 251 | --cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 252 | ``` 253 | 254 | ### Evaluating Swin-MoE 255 | 256 | To evaluate a `Swin-MoE-S` with 32 experts on ImageNet-22K with 32 GPUs (4 nodes): 257 | 258 | 1.
Download the zip file [swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip](https://github.com/SwinTransformer/storage/releases/download/v2.0.2/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.zip), which contains the pre-trained models for each rank, and unzip them to the folder "swin_moe_small_patch4_window12_192_32expert_32gpu_22k". 259 | 2. Run the following evaluation command; note that the checkpoint path should not contain the `.rank<x>` suffix. 260 | 261 | ```bash 262 | python -m torch.distributed.launch --nproc_per_node 8 --nnode=4 \ 263 | --node_rank=<node-rank> --master_addr=<master-addr> --master_port 12345 main_moe.py \ 264 | --cfg configs/swinmoe/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.yaml --data-path <imagenet22k-path> --batch-size 128 \ 265 | --resume swin_moe_small_patch4_window12_192_32expert_32gpu_22k/swin_moe_small_patch4_window12_192_32expert_32gpu_22k.pth 266 | ``` 267 | 268 | More Swin-MoE models can be found in [MODEL HUB](MODELHUB.md#imagenet-22k-pretrained-swin-moe-models). 269 | 270 | ## SimMIM Support 271 | 272 | ### Evaluating provided models 273 | 274 | To evaluate a provided model on the ImageNet validation set, run: 275 | ```bash 276 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> main_simmim_ft.py \ 277 | --eval --cfg <config-file> --resume <checkpoint> --data-path <imagenet-path> 278 | ``` 279 | 280 | For example, to evaluate the `Swin Base` model on a single GPU, run: 281 | ```bash 282 | python -m torch.distributed.launch --nproc_per_node 1 main_simmim_ft.py \ 283 | --eval --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --resume simmim_finetune__swin_base__img224_window7__800ep.pth --data-path <imagenet-path> 284 | ``` 285 | 286 | ### Pre-training with SimMIM 287 | To pre-train models with `SimMIM`, run: 288 | ```bash 289 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> main_simmim_pt.py \ 290 | --cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 291 | ``` 292 | 293 | For example, to pre-train `Swin Base` for 800 epochs on one DGX-2 server, run: 294 | ```bash 295 | python -m torch.distributed.launch --nproc_per_node 16 main_simmim_pt.py \ 296 | --cfg configs/simmim/simmim_pretrain__swin_base__img192_window6__800ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>] 297 | ``` 298 | 299 | ### Fine-tuning pre-trained models 300 | To fine-tune models pre-trained by `SimMIM`, run: 301 | ```bash 302 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus> main_simmim_ft.py \ 303 | --cfg <config-file> --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 304 | ``` 305 | 306 | For example, to fine-tune `Swin Base` pre-trained by `SimMIM` on one DGX-2 server, run: 307 | ```bash 308 | python -m torch.distributed.launch --nproc_per_node 16 main_simmim_ft.py \ 309 | --cfg configs/simmim/simmim_finetune__swin_base__img224_window7__800ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <pretrained-ckpt> [--output <output-directory> --tag <job-tag>] 310 | ``` -------------------------------------------------------------------------------- /kernels/window_process/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup(name='swin_window_process', 6 | ext_modules=[ 7 | CUDAExtension('swin_window_process', [ 8 | 'swin_window_process.cpp', 9 | 'swin_window_process_kernel.cu', 10 | ]) 11 | ], 12 | cmdclass={'build_ext': BuildExtension}) -------------------------------------------------------------------------------- /kernels/window_process/swin_window_process.cpp:
-------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | 20 | 21 | at::Tensor roll_and_window_partition_forward_cuda( 22 | at::Tensor & input, 23 | //at::Tensor & output, 24 | const int B, 25 | const int H, 26 | const int W, 27 | const int C, 28 | const int shift_size, 29 | const int window_size); 30 | 31 | 32 | at::Tensor roll_and_window_partition_backward_cuda( 33 | at::Tensor & grad_in, 34 | //at::Tensor & grad_out, 35 | const int B, 36 | const int H, 37 | const int W, 38 | const int C, 39 | const int shift_size, 40 | const int window_size); 41 | 42 | 43 | at::Tensor window_merge_and_roll_forward_cuda( 44 | at::Tensor & input, 45 | //at::Tensor & output, 46 | const int B, 47 | const int H, 48 | const int W, 49 | const int C, 50 | const int shift_size, 51 | const int window_size); 52 | 53 | at::Tensor window_merge_and_roll_backward_cuda( 54 | at::Tensor & grad_in, 55 | //at::Tensor & grad_out, 56 | const int B, 57 | const int H, 58 | const int W, 59 | const int C, 60 | const int shift_size, 61 | const int window_size); 62 | 63 | 64 | #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor") 65 | #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") 66 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 67 | 68 | 69 | 70 | at::Tensor roll_and_window_partition_forward( 71 | at::Tensor & input, 72 | //at::Tensor & output, 73 | const int B, 74 | const int H, 75 | const int W, 76 | const int C, 77 | const int shift_size, 78 | const int window_size){ 79 | CHECK_INPUT(input); 80 | return roll_and_window_partition_forward_cuda(input, B, H, W, C, shift_size, window_size); 81 | } 82 | 83 | 84 | at::Tensor roll_and_window_partition_backward( 85 | at::Tensor & grad_in, 86 | //at::Tensor & grad_out, 87 | const int B, 88 | const int H, 89 | const int W, 90 | const int C, 91 | const int shift_size, 92 | const int window_size){ 93 | CHECK_INPUT(grad_in); 94 | return roll_and_window_partition_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); 95 | } 96 | 97 | 98 | at::Tensor window_merge_and_roll_forward( 99 | at::Tensor & input, 100 | //at::Tensor & output, 101 | const int B, 102 | const int H, 103 | const int W, 104 | const int C, 105 | const int shift_size, 106 | const int window_size){ 107 | CHECK_INPUT(input); 108 | return window_merge_and_roll_forward_cuda(input, B, H, W, C, shift_size, window_size); 109 | } 110 | 111 | 112 | at::Tensor window_merge_and_roll_backward( 113 | at::Tensor & grad_in, 114 | //at::Tensor & grad_out, 115 | const int B, 116 | const int H, 117 | const int W, 118 | const int C, 119 | const int shift_size, 120 | const int window_size){ 121 | CHECK_INPUT(grad_in); 122 | return window_merge_and_roll_backward_cuda(grad_in, B, H, W, C, shift_size, window_size); 123 | } 124 | 125 | 126 
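// Descriptive note: the four wrappers above come in forward/backward pairs and
// fuse torch.roll with window partition/merge, replacing the separate
// roll + view/permute sequence in the Python model. Each wrapper runs
// CHECK_INPUT (contiguous CUDA tensor) before dispatching to the CUDA kernels
// declared at the top of this file; the bindings below expose them to Python
// as the swin_window_process extension module.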
| 127 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 128 | m.def("roll_and_window_partition_forward", &roll_and_window_partition_forward, "torch.roll and window_partition."); 129 | m.def("roll_and_window_partition_backward", &roll_and_window_partition_backward, "torch.roll and window_partition."); 130 | m.def("window_merge_and_roll_forward", &window_merge_and_roll_forward, "window merge and torch.roll."); 131 | m.def("window_merge_and_roll_backward", &window_merge_and_roll_backward, "window merge and torch.roll."); 132 | } -------------------------------------------------------------------------------- /kernels/window_process/swin_window_process_kernel.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 3 | * 4 | * Licensed under the Apache License, Version 2.0 (the "License"); 5 | * you may not use this file except in compliance with the License. 6 | * You may obtain a copy of the License at 7 | * 8 | * http://www.apache.org/licenses/LICENSE-2.0 9 | * 10 | * Unless required by applicable law or agreed to in writing, software 11 | * distributed under the License is distributed on an "AS IS" BASIS, 12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | * See the License for the specific language governing permissions and 14 | * limitations under the License. 15 | */ 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | int best_block_dim(int feat_dim){ 25 | int best_dim; 26 | if (feat_dim < 384){ 27 | best_dim = 64; 28 | } 29 | else{ 30 | if (feat_dim < 1024){ 31 | best_dim = 128; 32 | } 33 | else{ 34 | best_dim = 256; 35 | } 36 | } 37 | return best_dim; 38 | } 39 | 40 | 41 | template 42 | __global__ void roll_and_window_partition_forward_cuda_kernel( 43 | T* input, 44 | T* output, 45 | const int B, 46 | const int H, 47 | const int W, 48 | const int C, 49 | const int shift_size, 50 | const int window_size, 51 | const int nH, 52 | const int nW){ 53 | // start 54 | //bool qual = threadIdx.x < C; 55 | int index = threadIdx.x; 56 | int offset; 57 | for (int i = index; i < C; i += blockDim.x) { 58 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 59 | int input_offset = blockIdx.z / (nH * nW) * H * W * C + 60 | (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y - shift_size + H) % H * W * C + 61 | (blockIdx.z % nW * window_size + blockIdx.x - shift_size + W) % W * C + 62 | i; 63 | output[offset] = (T)(__ldg(input + input_offset)); 64 | } 65 | } 66 | 67 | 68 | template 69 | __global__ void roll_and_window_partition_backward_cuda_kernel( 70 | T* grad_in, 71 | T* grad_out, 72 | const int B, 73 | const int H, 74 | const int W, 75 | const int C, 76 | const int shift_size, 77 | const int window_size, 78 | const int nH, 79 | const int nW){ 80 | // start 81 | int index = threadIdx.x; 82 | int offset; 83 | for (int i = index; i < C; i += blockDim.x) { 84 | offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // C = blocksize 85 | int input_offset = 86 | (blockIdx.z * nH * nW + (blockIdx.y + shift_size + H) % H / window_size * nW + (blockIdx.x + shift_size + W) % W / window_size) * window_size * window_size * C + 87 | (blockIdx.y + shift_size + H ) % H % window_size * window_size * C + 88 | (blockIdx.x + shift_size + W ) % W % window_size * C + 89 | i; 90 | grad_out[offset] = (T)(__ldg(grad_in + input_offset)); 91 | } 92 | } 93 | 94 | 95 | 
95 | template <typename T>
96 | __global__ void window_merge_and_roll_forward_cuda_kernel(
97 |     T* input,
98 |     T* output,
99 |     const int B,
100 |     const int H,
101 |     const int W,
102 |     const int C,
103 |     const int shift_size,
104 |     const int window_size,
105 |     const int nH,
106 |     const int nW){
107 |     // start
108 |     int index = threadIdx.x;
109 |     int offset;
110 |     for (int i = index; i < C; i += blockDim.x) {
111 |         offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // blockDim.x threads stride over the C channels
112 |         int input_offset =
113 |             (blockIdx.z * nH * nW + (blockIdx.y - shift_size + H) % H / window_size * nW + (blockIdx.x - shift_size + W) % W / window_size) * window_size * window_size * C + // window index = row * nW + col (nW == nH for the square inputs used here)
114 |             (blockIdx.y - shift_size + H) % window_size * window_size * C +
115 |             (blockIdx.x - shift_size + W) % window_size * C +
116 |             i;
117 |         output[offset] = (T)(__ldg(input + input_offset));
118 |     }
119 | }
120 |
121 |
122 |
123 | template <typename T>
124 | __global__ void window_merge_and_roll_backward_cuda_kernel(
125 |     T* grad_in,
126 |     T* grad_out,
127 |     const int B,
128 |     const int H,
129 |     const int W,
130 |     const int C,
131 |     const int shift_size,
132 |     const int window_size,
133 |     const int nH,
134 |     const int nW){
135 |     // start
136 |     int index = threadIdx.x;
137 |     int offset;
138 |     for (int i = index; i < C; i += blockDim.x) {
139 |         offset = ((blockIdx.z * gridDim.y + blockIdx.y) * gridDim.x + blockIdx.x) * C + i; // blockDim.x threads stride over the C channels
140 |         int input_offset =
141 |             (blockIdx.z / (nH * nW)) * H * W * C +
142 |             (blockIdx.z % (nH * nW) / nW * window_size + blockIdx.y + shift_size + H) % H * W * C +
143 |             (blockIdx.z % nW * window_size + blockIdx.x + shift_size + W) % W * C +
144 |             i;
145 |         grad_out[offset] = (T)(__ldg(grad_in + input_offset));
146 |     }
147 | }
148 |
149 | // input: [B, H, W, C]
150 | // output: [B*nH*nW, window_size, window_size, C]
151 | at::Tensor roll_and_window_partition_forward_cuda(
152 |     at::Tensor & input,
153 |     //at::Tensor & output,
154 |     const int B,
155 |     const int H,
156 |     const int W,
157 |     const int C,
158 |     const int shift_size,
159 |     const int window_size){
160 |
161 |     int nH = H / window_size;
162 |     int nW = W / window_size;
163 |
164 |     dim3 grid(window_size, window_size, B * nH * nW);
165 |     //dim3 block((C + 31) / 32 * 32);
166 |     int blocknum = best_block_dim(C);
167 |     dim3 block(blocknum);
168 |
169 |     at::Tensor output;
170 |     if (input.scalar_type() == torch::kFloat16){
171 |         output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
172 |     }
173 |     else{
174 |         output = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
175 |     }
176 |
177 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "roll_and_window_partition_forward_cuda_kernel", ([&] {
178 |         roll_and_window_partition_forward_cuda_kernel<scalar_t><<<grid, block>>>(
179 |             input.data<scalar_t>(),
180 |             output.data<scalar_t>(),
181 |             B,
182 |             H,
183 |             W,
184 |             C,
185 |             shift_size,
186 |             window_size,
187 |             nH,
188 |             nW);
189 |     }));
190 |     return output;
191 | }
192 |
193 |
194 | // grad_in: [B*nH*nW, window_size, window_size, C]
195 | // grad_out: [B, H, W, C]
196 | at::Tensor roll_and_window_partition_backward_cuda(
197 |     at::Tensor & grad_in,
198 |     const int B,
199 |     const int H,
200 |     const int W,
201 |     const int C,
202 |     const int shift_size,
203 |     const int window_size){
204 |
205 |     int nH = H / window_size;
206 |     int nW = W / window_size;
207 |
208 |     dim3 grid(W, H, B);
209 |     //dim3 block((C + 31) / 32 * 32);
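// best_block_dim (defined near the top of this file) picks 64/128/256 threads per
// block from the channel width C, so narrow feature maps do not launch mostly-idle threads.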
210 |     int blocknum = best_block_dim(C);
211 |     dim3 block(blocknum);
212 |
213 |     at::Tensor grad_out;
214 |     if (grad_in.scalar_type() == torch::kFloat16){
215 |         grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
216 |     }
217 |     else{
218 |         grad_out = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
219 |     }
220 |
221 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "roll_and_window_partition_backward_cuda_kernel", ([&] {
222 |         roll_and_window_partition_backward_cuda_kernel<scalar_t><<<grid, block>>>(
223 |             grad_in.data<scalar_t>(),
224 |             grad_out.data<scalar_t>(),
225 |             B,
226 |             H,
227 |             W,
228 |             C,
229 |             shift_size,
230 |             window_size,
231 |             nH,
232 |             nW);
233 |     }));
234 |     return grad_out;
235 | }
236 |
237 |
238 | // input: [B*nH*nW, window_size, window_size, C]
239 | // output: [B, H, W, C]
240 | at::Tensor window_merge_and_roll_forward_cuda(
241 |     at::Tensor & input,
242 |     //at::Tensor & output,
243 |     const int B,
244 |     const int H,
245 |     const int W,
246 |     const int C,
247 |     const int shift_size,
248 |     const int window_size){
249 |
250 |     int nH = H / window_size;
251 |     int nW = W / window_size;
252 |
253 |     dim3 grid(W, H, B);
254 |     //dim3 block((C + 31) / 32 * 32);
255 |     int blocknum = best_block_dim(C);
256 |     dim3 block(blocknum);
257 |
258 |     //generate output tensor inside
259 |     at::Tensor output;
260 |     if (input.scalar_type() == torch::kFloat16){
261 |         output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(true));
262 |     }
263 |     else{
264 |         output = torch::empty({B, H, W, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(true));
265 |     }
266 |
267 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "window_merge_and_roll_forward_cuda_kernel", ([&] {
268 |         window_merge_and_roll_forward_cuda_kernel<scalar_t><<<grid, block>>>(
269 |             input.data<scalar_t>(),
270 |             output.data<scalar_t>(),
271 |             B,
272 |             H,
273 |             W,
274 |             C,
275 |             shift_size,
276 |             window_size,
277 |             nH,
278 |             nW);
279 |     }));
280 |     return output;
281 | }
282 |
283 |
284 | at::Tensor window_merge_and_roll_backward_cuda(
285 |     at::Tensor & grad_in,
286 |     const int B,
287 |     const int H,
288 |     const int W,
289 |     const int C,
290 |     const int shift_size,
291 |     const int window_size){
292 |
293 |     int nH = H / window_size;
294 |     int nW = W / window_size;
295 |
296 |     dim3 grid(window_size, window_size, B * nH * nW);
297 |     //dim3 block((C + 31) / 32 * 32);
298 |     int blocknum = best_block_dim(C);
299 |     dim3 block(blocknum);
300 |
301 |     at::Tensor grad_out;
302 |     if (grad_in.scalar_type() == torch::kFloat16){
303 |         grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat16).device(torch::kCUDA).requires_grad(false));
304 |     }
305 |     else{
306 |         grad_out = torch::empty({B*nH*nW, window_size, window_size, C}, torch::dtype(torch::kFloat32).device(torch::kCUDA).requires_grad(false));
307 |     }
308 |
309 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad_in.type(), "window_merge_and_roll_backward_cuda_kernel", ([&] {
310 |         window_merge_and_roll_backward_cuda_kernel<scalar_t><<<grid, block>>>(
311 |             grad_in.data<scalar_t>(),
312 |             grad_out.data<scalar_t>(),
313 |             B,
314 |             H,
315 |             W,
316 |             C,
317 |             shift_size,
318 |             window_size,
319 |             nH,
320 |             nW);
321 |     }));
322 |     return grad_out;
323 | }
--------------------------------------------------------------------------------
/kernels/window_process/unit_test.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fused kernel for window
process for SwinTransformer 3 | # Copyright (c) 2022 Nvidia 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # -------------------------------------------------------- 6 | 7 | import torch 8 | import swin_window_process 9 | import random 10 | import time 11 | import unittest 12 | 13 | 14 | class WindowProcess(torch.autograd.Function): 15 | @staticmethod 16 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 17 | output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size) 18 | 19 | ctx.B = B 20 | ctx.H = H 21 | ctx.W = W 22 | ctx.C = C 23 | ctx.shift_size = shift_size 24 | ctx.window_size = window_size 25 | return output 26 | 27 | @staticmethod 28 | def backward(ctx, grad_in): 29 | B = ctx.B 30 | H = ctx.H 31 | W = ctx.W 32 | C = ctx.C 33 | shift_size = ctx.shift_size 34 | window_size = ctx.window_size 35 | 36 | grad_out = swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) 37 | return grad_out, None, None, None, None, None, None, None 38 | 39 | 40 | class WindowProcessReverse(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 43 | output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) 44 | 45 | ctx.B = B 46 | ctx.H = H 47 | ctx.W = W 48 | ctx.C = C 49 | ctx.shift_size = shift_size 50 | ctx.window_size = window_size 51 | 52 | return output 53 | 54 | @staticmethod 55 | def backward(ctx, grad_in): 56 | B = ctx.B 57 | H = ctx.H 58 | W = ctx.W 59 | C = ctx.C 60 | shift_size = ctx.shift_size 61 | window_size = ctx.window_size 62 | 63 | grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) 64 | return grad_out, None, None, None, None, None, None, None 65 | 66 | 67 | def window_partition(x, window_size): 68 | """ 69 | Args: 70 | x: (B, H, W, C) 71 | window_size (int): window size 72 | Returns: 73 | windows: (num_windows*B, window_size, window_size, C) 74 | """ 75 | B, H, W, C = x.shape 76 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 77 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 78 | return windows 79 | 80 | def window_reverse(windows, window_size, H, W): 81 | """ 82 | Args: 83 | windows: (num_windows*B, window_size, window_size, C) 84 | window_size (int): Window size 85 | H (int): Height of image 86 | W (int): Width of image 87 | Returns: 88 | x: (B, H, W, C) 89 | """ 90 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 91 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 92 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 93 | return x 94 | 95 | 96 | def pyt_forward(x, shift_size, window_size): 97 | # x in shape(B, H, W, C) 98 | # cyclic shift 99 | if shift_size > 0: 100 | shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2)) 101 | else: 102 | shifted_x = x 103 | # partition windows 104 | x_windows = window_partition(shifted_x, window_size) 105 | return x_windows 106 | 107 | 108 | def reverse_pyt_forward(attn_windows, shift_size, window_size, H, W): 109 | # x in shape(B*nH*nW, window_size, window_size, C) 110 | shifted_x = window_reverse(attn_windows, window_size, H, W) 111 | if shift_size > 0: 112 | x = torch.roll(shifted_x, shifts=(shift_size, shift_size), dims=(1, 2)) 113 | else: 114 | x = shifted_x 115 | return x 116 | 117 | 118 | def 
copy_one_tensor(input, requires_grad=True): 119 | input1 = input.clone().detach().requires_grad_(requires_grad).cuda() 120 | return input1 121 | 122 | class Test_WindowProcess(unittest.TestCase): 123 | def setUp(self): 124 | self.B = 192 125 | self.H = 56 126 | self.W = 56 127 | self.C = 96 128 | self.shift_size = 2 129 | self.window_size = 7 130 | self.nH = self.H // self.window_size 131 | self.nW = self.W // self.window_size 132 | 133 | def test_roll_and_window_partition_forward(self, dtype=torch.float32): 134 | input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 135 | 136 | input1 = copy_one_tensor(input, True) 137 | input2 = copy_one_tensor(input, True) 138 | 139 | with torch.no_grad(): 140 | # ori 141 | expected = pyt_forward(input1, self.shift_size, self.window_size) 142 | # fused kernel 143 | fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) 144 | 145 | self.assertTrue(torch.equal(expected, fused_output)) 146 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 147 | 148 | def test_roll_and_window_partition_backward(self, dtype=torch.float32): 149 | input = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 150 | d_loss_tensor = torch.randn((self.B*self.nW*self.nH, self.window_size, self.window_size, self.C), dtype=dtype).cuda() 151 | 152 | input1 = copy_one_tensor(input, True) 153 | input2 = copy_one_tensor(input, True) 154 | 155 | # ori 156 | expected = pyt_forward(input1, self.shift_size, self.window_size) 157 | expected.backward(d_loss_tensor) 158 | # fused kernel 159 | fused_output = WindowProcess.apply(input2, self.B, self.H, self.W, self.C, -self.shift_size, self.window_size) 160 | fused_output.backward(d_loss_tensor) 161 | 162 | self.assertTrue(torch.equal(expected, fused_output)) 163 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 164 | 165 | def test_window_merge_and_roll_forward(self, dtype=torch.float32): 166 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 167 | 168 | input1 = copy_one_tensor(input, True) 169 | input2 = copy_one_tensor(input, True) 170 | 171 | with torch.no_grad(): 172 | # ori 173 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 174 | # fused kernel 175 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 176 | 177 | self.assertTrue(torch.equal(expected, fused_output)) 178 | #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08)) 179 | 180 | 181 | def test_window_merge_and_roll_backward(self, dtype=torch.float32): 182 | input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda() 183 | d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda() 184 | 185 | input1 = copy_one_tensor(input, True) 186 | input2 = copy_one_tensor(input, True) 187 | 188 | # ori 189 | expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W) 190 | expected.backward(d_loss_tensor) 191 | # fused kernel 192 | fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size) 193 | fused_output.backward(d_loss_tensor) 194 | 195 | self.assertTrue(torch.equal(expected, 
fused_output))
196 |         #self.assertTrue(torch.allclose(expected, fused_output, rtol=1e-05, atol=1e-08))
197 |
198 |     def test_forward_backward_speed(self, dtype=torch.float32, times=1000):
199 |         input = torch.randn((self.B*self.nH*self.nW, self.window_size, self.window_size, self.C), dtype=dtype, requires_grad=True).cuda()
200 |         d_loss_tensor = torch.randn((self.B, self.H, self.W, self.C), dtype=dtype, requires_grad=True).cuda()
201 |
202 |         input1 = copy_one_tensor(input, True)
203 |         input2 = copy_one_tensor(input, True)
204 |
205 |         # SwinTransformer official
206 |         def run_pyt(t=1000):
207 |             for _ in range(t):
208 |                 expected = reverse_pyt_forward(input1, self.shift_size, self.window_size, self.H, self.W)
209 |                 expected.backward(d_loss_tensor)
210 |
211 |         # fused op
212 |         def run_fusedop(t=1000):
213 |             for _ in range(t):
214 |                 fused_output = WindowProcessReverse.apply(input2, self.B, self.H, self.W, self.C, self.shift_size, self.window_size)
215 |                 fused_output.backward(d_loss_tensor)
216 |
217 |         torch.cuda.synchronize()
218 |         t1 = time.time()
219 |         run_pyt(t=times)
220 |         torch.cuda.synchronize()
221 |         t2 = time.time()
222 |         run_fusedop(t=times)
223 |         torch.cuda.synchronize()
224 |         t3 = time.time()
225 |         self.assertTrue((t3 - t2) < (t2 - t1))
226 |
227 |         print('Run {} times'.format(times))
228 |         print('Original time cost: {}'.format(t2 - t1))
229 |         print('Fused op time cost: {}'.format(t3 - t2))
230 |
231 |     def test_roll_and_window_partition_forward_fp16(self, dtype=torch.float16):
232 |         self.test_roll_and_window_partition_forward(dtype=dtype)
233 |
234 |     def test_roll_and_window_partition_backward_fp16(self, dtype=torch.float16):
235 |         self.test_roll_and_window_partition_backward(dtype=dtype)
236 |
237 |     def test_window_merge_and_roll_forward_fp16(self, dtype=torch.float16):
238 |         self.test_window_merge_and_roll_forward(dtype=dtype)
239 |
240 |     def test_window_merge_and_roll_backward_fp16(self, dtype=torch.float16):
241 |         self.test_window_merge_and_roll_backward(dtype=dtype)
242 |
243 |     def test_forward_backward_speed_fp16(self, dtype=torch.float16, times=1000):
244 |         self.test_forward_backward_speed(dtype=dtype, times=times)
245 |
246 |
247 | if __name__ == '__main__':
248 |     print('Tests pass only when the two tensors are exactly the same (using torch.equal).\n')
249 |     torch.manual_seed(0)
250 |     unittest.main(verbosity=2)
251 |
--------------------------------------------------------------------------------
/kernels/window_process/window_process.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Fused kernel for window process for SwinTransformer
3 | # Copyright (c) 2022 Nvidia
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # --------------------------------------------------------
6 |
7 | import torch
8 | import swin_window_process
9 |
10 |
11 | class WindowProcess(torch.autograd.Function):
12 |     @staticmethod
13 |     def forward(ctx, input, B, H, W, C, shift_size, window_size):
14 |         output = swin_window_process.roll_and_window_partition_forward(input, B, H, W, C, shift_size, window_size)
15 |
16 |         ctx.B = B
17 |         ctx.H = H
18 |         ctx.W = W
19 |         ctx.C = C
20 |         ctx.shift_size = shift_size
21 |         ctx.window_size = window_size
22 |         return output
23 |
24 |     @staticmethod
25 |     def backward(ctx, grad_in):
26 |         B = ctx.B
27 |         H = ctx.H
28 |         W = ctx.W
29 |         C = ctx.C
30 |         shift_size = ctx.shift_size
31 |         window_size = ctx.window_size
32 |
33 |         grad_out =
swin_window_process.roll_and_window_partition_backward(grad_in, B, H, W, C, shift_size, window_size) 34 | return grad_out, None, None, None, None, None, None, None 35 | 36 | 37 | class WindowProcessReverse(torch.autograd.Function): 38 | @staticmethod 39 | def forward(ctx, input, B, H, W, C, shift_size, window_size): 40 | output = swin_window_process.window_merge_and_roll_forward(input, B, H, W, C, shift_size, window_size) 41 | 42 | ctx.B = B 43 | ctx.H = H 44 | ctx.W = W 45 | ctx.C = C 46 | ctx.shift_size = shift_size 47 | ctx.window_size = window_size 48 | 49 | return output 50 | 51 | @staticmethod 52 | def backward(ctx, grad_in): 53 | B = ctx.B 54 | H = ctx.H 55 | W = ctx.W 56 | C = ctx.C 57 | shift_size = ctx.shift_size 58 | window_size = ctx.window_size 59 | 60 | #grad_out = ctx.saved_tensors[0] 61 | #grad_out = torch.zeros((B, H, W, C), dtype=dtype).cuda() 62 | grad_out = swin_window_process.window_merge_and_roll_backward(grad_in, B, H, W, C, shift_size, window_size) 63 | return grad_out, None, None, None, None, None, None, None 64 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import bisect 9 | 10 | import torch 11 | from timm.scheduler.cosine_lr import CosineLRScheduler 12 | from timm.scheduler.step_lr import StepLRScheduler 13 | from timm.scheduler.scheduler import Scheduler 14 | 15 | 16 | def build_scheduler(config, optimizer, n_iter_per_epoch): 17 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 18 
| warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 19 | decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 20 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] 21 | 22 | lr_scheduler = None 23 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 24 | lr_scheduler = CosineLRScheduler( 25 | optimizer, 26 | t_initial=(num_steps - warmup_steps) if config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX else num_steps, 27 | t_mul=1., 28 | lr_min=config.TRAIN.MIN_LR, 29 | warmup_lr_init=config.TRAIN.WARMUP_LR, 30 | warmup_t=warmup_steps, 31 | cycle_limit=1, 32 | t_in_epochs=False, 33 | warmup_prefix=config.TRAIN.LR_SCHEDULER.WARMUP_PREFIX, 34 | ) 35 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 36 | lr_scheduler = LinearLRScheduler( 37 | optimizer, 38 | t_initial=num_steps, 39 | lr_min_rate=0.01, 40 | warmup_lr_init=config.TRAIN.WARMUP_LR, 41 | warmup_t=warmup_steps, 42 | t_in_epochs=False, 43 | ) 44 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 45 | lr_scheduler = StepLRScheduler( 46 | optimizer, 47 | decay_t=decay_steps, 48 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 49 | warmup_lr_init=config.TRAIN.WARMUP_LR, 50 | warmup_t=warmup_steps, 51 | t_in_epochs=False, 52 | ) 53 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': 54 | lr_scheduler = MultiStepLRScheduler( 55 | optimizer, 56 | milestones=multi_steps, 57 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA, 58 | warmup_lr_init=config.TRAIN.WARMUP_LR, 59 | warmup_t=warmup_steps, 60 | t_in_epochs=False, 61 | ) 62 | 63 | return lr_scheduler 64 | 65 | 66 | class LinearLRScheduler(Scheduler): 67 | def __init__(self, 68 | optimizer: torch.optim.Optimizer, 69 | t_initial: int, 70 | lr_min_rate: float, 71 | warmup_t=0, 72 | warmup_lr_init=0., 73 | t_in_epochs=True, 74 | noise_range_t=None, 75 | noise_pct=0.67, 76 | noise_std=1.0, 77 | noise_seed=42, 78 | initialize=True, 79 | ) -> None: 80 | super().__init__( 81 | optimizer, param_group_field="lr", 82 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 83 | initialize=initialize) 84 | 85 | self.t_initial = t_initial 86 | self.lr_min_rate = lr_min_rate 87 | self.warmup_t = warmup_t 88 | self.warmup_lr_init = warmup_lr_init 89 | self.t_in_epochs = t_in_epochs 90 | if self.warmup_t: 91 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 92 | super().update_groups(self.warmup_lr_init) 93 | else: 94 | self.warmup_steps = [1 for _ in self.base_values] 95 | 96 | def _get_lr(self, t): 97 | if t < self.warmup_t: 98 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 99 | else: 100 | t = t - self.warmup_t 101 | total_t = self.t_initial - self.warmup_t 102 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 103 | return lrs 104 | 105 | def get_epoch_values(self, epoch: int): 106 | if self.t_in_epochs: 107 | return self._get_lr(epoch) 108 | else: 109 | return None 110 | 111 | def get_update_values(self, num_updates: int): 112 | if not self.t_in_epochs: 113 | return self._get_lr(num_updates) 114 | else: 115 | return None 116 | 117 | 118 | class MultiStepLRScheduler(Scheduler): 119 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: 120 | super().__init__(optimizer, param_group_field="lr") 121 | 122 | self.milestones = milestones 123 | self.gamma = gamma 124 | self.warmup_t = warmup_t 125 | self.warmup_lr_init = warmup_lr_init 126 | 
self.t_in_epochs = t_in_epochs 127 | if self.warmup_t: 128 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 129 | super().update_groups(self.warmup_lr_init) 130 | else: 131 | self.warmup_steps = [1 for _ in self.base_values] 132 | 133 | assert self.warmup_t <= min(self.milestones) 134 | 135 | def _get_lr(self, t): 136 | if t < self.warmup_t: 137 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 138 | else: 139 | lrs = [v * (self.gamma ** bisect.bisect_right(self.milestones, t)) for v in self.base_values] 140 | return lrs 141 | 142 | def get_epoch_values(self, epoch: int): 143 | if self.t_in_epochs: 144 | return self._get_lr(epoch) 145 | else: 146 | return None 147 | 148 | def get_update_values(self, num_updates: int): 149 | if not self.t_in_epochs: 150 | return self._get_lr(num_updates) 151 | else: 152 | return None 153 | -------------------------------------------------------------------------------- /main_simmim_ft.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.cuda.amp as amp 19 | 20 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 21 | from timm.utils import accuracy, AverageMeter 22 | 23 | from config import get_config 24 | from models import build_model 25 | from data import build_loader 26 | from lr_scheduler import build_scheduler 27 | from optimizer import build_optimizer 28 | from logger import create_logger 29 | from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \ 30 | reduce_tensor 31 | 32 | # pytorch major version (1.x or 2.x) 33 | PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0]) 34 | 35 | 36 | def parse_option(): 37 | parser = argparse.ArgumentParser('SimMIM fine-tuning script', add_help=False) 38 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 39 | parser.add_argument( 40 | "--opts", 41 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 42 | default=None, 43 | nargs='+', 44 | ) 45 | 46 | # easy config modification 47 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 48 | parser.add_argument('--data-path', type=str, help='path to dataset') 49 | parser.add_argument('--pretrained', type=str, help='path to pre-trained model') 50 | parser.add_argument('--resume', help='resume from checkpoint') 51 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 52 | parser.add_argument('--use-checkpoint', action='store_true', 53 | help="whether to use gradient checkpointing to save memory") 54 | parser.add_argument('--enable-amp', action='store_true') 55 | parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') 56 | parser.set_defaults(enable_amp=True) 57 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 58 | help='root of output folder, the full path is // (default: output)') 59 | parser.add_argument('--tag', help='tag of experiment') 60 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 61 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 62 | 63 | # distributed training 64 | # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead 65 | # (see https://pytorch.org/docs/stable/distributed.html#launch-utility) 66 | if PYTORCH_MAJOR_VERSION == 1: 67 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 68 | 69 | args = parser.parse_args() 70 | 71 | config = get_config(args) 72 | 73 | return args, config 74 | 75 | 76 | def main(config): 77 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True, 78 | is_pretrain=False) 79 | 80 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 81 | model = build_model(config, is_pretrain=False) 82 | model.cuda() 83 | logger.info(str(model)) 84 | 85 | optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False) 86 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 87 | model_without_ddp = model.module 88 | 89 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 90 | logger.info(f"number of params: {n_parameters}") 91 | if hasattr(model_without_ddp, 'flops'): 92 | flops = model_without_ddp.flops() 93 | logger.info(f"number of GFLOPs: {flops / 1e9}") 94 | 95 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 96 | scaler = amp.GradScaler() 97 | 98 | if config.AUG.MIXUP > 0.: 99 | # smoothing is handled with mixup label transform 100 | criterion = SoftTargetCrossEntropy() 101 | elif config.MODEL.LABEL_SMOOTHING > 0.: 102 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 103 | else: 104 | criterion = torch.nn.CrossEntropyLoss() 105 | 106 | max_accuracy = 0.0 107 | 108 | if config.TRAIN.AUTO_RESUME: 109 | resume_file = auto_resume_helper(config.OUTPUT, logger) 110 | if resume_file: 111 | if config.MODEL.RESUME: 112 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 113 | config.defrost() 114 | config.MODEL.RESUME = resume_file 115 | config.freeze() 116 | logger.info(f'auto resuming from {resume_file}') 117 | else: 118 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 119 | 120 | if config.MODEL.RESUME: 121 | max_accuracy = load_checkpoint(config, 
model_without_ddp, optimizer, lr_scheduler, scaler, logger) 122 | acc1, acc5, loss = validate(config, data_loader_val, model) 123 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 124 | if config.EVAL_MODE: 125 | return 126 | 127 | if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): 128 | load_pretrained(config, model_without_ddp, logger) 129 | acc1, acc5, loss = validate(config, data_loader_val, model) 130 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 131 | 132 | if config.THROUGHPUT_MODE: 133 | throughput(data_loader_val, model, logger) 134 | return 135 | 136 | logger.info("Start training") 137 | start_time = time.time() 138 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 139 | data_loader_train.sampler.set_epoch(epoch) 140 | 141 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler) 142 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 143 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger) 144 | 145 | acc1, acc5, loss = validate(config, data_loader_val, model) 146 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 147 | max_accuracy = max(max_accuracy, acc1) 148 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 149 | 150 | total_time = time.time() - start_time 151 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 152 | logger.info('Training time {}'.format(total_time_str)) 153 | 154 | 155 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler): 156 | model.train() 157 | optimizer.zero_grad() 158 | 159 | logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') 160 | 161 | num_steps = len(data_loader) 162 | batch_time = AverageMeter() 163 | loss_meter = AverageMeter() 164 | norm_meter = AverageMeter() 165 | loss_scale_meter = AverageMeter() 166 | 167 | start = time.time() 168 | end = time.time() 169 | for idx, (samples, targets) in enumerate(data_loader): 170 | samples = samples.cuda(non_blocking=True) 171 | targets = targets.cuda(non_blocking=True) 172 | 173 | if mixup_fn is not None: 174 | samples, targets = mixup_fn(samples, targets) 175 | 176 | outputs = model(samples) 177 | 178 | if config.TRAIN.ACCUMULATION_STEPS > 1: 179 | loss = criterion(outputs, targets) 180 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 181 | scaler.scale(loss).backward() 182 | if config.TRAIN.CLIP_GRAD: 183 | scaler.unscale_(optimizer) 184 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 185 | else: 186 | grad_norm = get_grad_norm(model.parameters()) 187 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 188 | scaler.step(optimizer) 189 | optimizer.zero_grad() 190 | scaler.update() 191 | lr_scheduler.step_update(epoch * num_steps + idx) 192 | else: 193 | loss = criterion(outputs, targets) 194 | optimizer.zero_grad() 195 | scaler.scale(loss).backward() 196 | if config.TRAIN.CLIP_GRAD: 197 | scaler.unscale_(optimizer) 198 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 199 | else: 200 | grad_norm = get_grad_norm(model.parameters()) 201 | scaler.step(optimizer) 202 | scaler.update() 203 | lr_scheduler.step_update(epoch * num_steps + idx) 204 | 205 | torch.cuda.synchronize() 206 | 
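# per-step bookkeeping: running averages of loss, gradient norm and AMP loss scale feed the periodic log below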
207 |         loss_meter.update(loss.item(), targets.size(0))
208 |         norm_meter.update(grad_norm)
209 |         loss_scale_meter.update(scaler.get_scale())
210 |         batch_time.update(time.time() - end)
211 |         end = time.time()
212 |
213 |         if idx % config.PRINT_FREQ == 0:
214 |             lr = optimizer.param_groups[-1]['lr']
215 |             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
216 |             etas = batch_time.avg * (num_steps - idx)
217 |             logger.info(
218 |                 f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
219 |                 f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
220 |                 f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
221 |                 f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
222 |                 f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
223 |                 f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
224 |                 f'mem {memory_used:.0f}MB')
225 |     epoch_time = time.time() - start
226 |     logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
227 |
228 |
229 | @torch.no_grad()
230 | def validate(config, data_loader, model):
231 |     criterion = torch.nn.CrossEntropyLoss()
232 |     model.eval()
233 |
234 |     batch_time = AverageMeter()
235 |     loss_meter = AverageMeter()
236 |     acc1_meter = AverageMeter()
237 |     acc5_meter = AverageMeter()
238 |
239 |     end = time.time()
240 |     for idx, (images, target) in enumerate(data_loader):
241 |         images = images.cuda(non_blocking=True)
242 |         target = target.cuda(non_blocking=True)
243 |
244 |         # compute output
245 |         output = model(images)
246 |
247 |         # measure accuracy and record loss
248 |         loss = criterion(output, target)
249 |         acc1, acc5 = accuracy(output, target, topk=(1, 5))
250 |
251 |         acc1 = reduce_tensor(acc1)
252 |         acc5 = reduce_tensor(acc5)
253 |         loss = reduce_tensor(loss)
254 |
255 |         loss_meter.update(loss.item(), target.size(0))
256 |         acc1_meter.update(acc1.item(), target.size(0))
257 |         acc5_meter.update(acc5.item(), target.size(0))
258 |
259 |         # measure elapsed time
260 |         batch_time.update(time.time() - end)
261 |         end = time.time()
262 |
263 |         if idx % config.PRINT_FREQ == 0:
264 |             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
265 |             logger.info(
266 |                 f'Test: [{idx}/{len(data_loader)}]\t'
267 |                 f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
268 |                 f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
269 |                 f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
270 |                 f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
271 |                 f'Mem {memory_used:.0f}MB')
272 |     logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
273 |     return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
274 |
275 |
276 | @torch.no_grad()
277 | def throughput(data_loader, model, logger):
278 |     model.eval()
279 |
280 |     for idx, (images, _) in enumerate(data_loader):
281 |         images = images.cuda(non_blocking=True)
282 |         batch_size = images.shape[0]
283 |         for i in range(50):  # warm-up iterations
284 |             model(images)
285 |         torch.cuda.synchronize()
286 |         logger.info("throughput averaged over 30 iterations")
287 |         tic1 = time.time()
288 |         for i in range(30):
289 |             model(images)
290 |         torch.cuda.synchronize()
291 |         tic2 = time.time()
292 |         logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
293 |         return
294 |
295 |
296 | if __name__ == '__main__':
297 |     _, config = parse_option()
298 |
299 |     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
300 |         rank = int(os.environ["RANK"])
301 |         world_size = int(os.environ['WORLD_SIZE'])
302 |         print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
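# no launcher-provided RANK/WORLD_SIZE: pass -1 and let init_process_group resolve them from the environment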
303 |     else:
304 |         rank = -1
305 |         world_size = -1
306 |     torch.cuda.set_device(config.LOCAL_RANK)
307 |     torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
308 |     torch.distributed.barrier()
309 |
310 |     seed = config.SEED + dist.get_rank()
311 |     torch.manual_seed(seed)
312 |     np.random.seed(seed)
313 |     cudnn.benchmark = True
314 |
315 |     # linearly scale the learning rate with the total batch size; may not be optimal
316 |     linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
317 |     linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
318 |     linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
319 |     # gradient accumulation also needs to scale the learning rate
320 |     if config.TRAIN.ACCUMULATION_STEPS > 1:
321 |         linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
322 |         linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
323 |         linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
324 |     config.defrost()
325 |     config.TRAIN.BASE_LR = linear_scaled_lr
326 |     config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
327 |     config.TRAIN.MIN_LR = linear_scaled_min_lr
328 |     config.freeze()
329 |
330 |     os.makedirs(config.OUTPUT, exist_ok=True)
331 |     logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
332 |
333 |     if dist.get_rank() == 0:
334 |         path = os.path.join(config.OUTPUT, "config.json")
335 |         with open(path, "w") as f:
336 |             f.write(config.dump())
337 |         logger.info(f"Full config saved to {path}")
338 |
339 |     # print config
340 |     logger.info(config.dump())
341 |
342 |     main(config)
343 |
--------------------------------------------------------------------------------
/main_simmim_pt.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # SimMIM
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # Modified by Zhenda Xie
7 | # --------------------------------------------------------
8 |
9 | import os
10 | import time
11 | import argparse
12 | import datetime
13 | import numpy as np
14 |
15 | import torch
16 | import torch.backends.cudnn as cudnn
17 | import torch.distributed as dist
18 | import torch.cuda.amp as amp
19 | from timm.utils import AverageMeter
20 |
21 | from config import get_config
22 | from models import build_model
23 | from data import build_loader
24 | from lr_scheduler import build_scheduler
25 | from optimizer import build_optimizer
26 | from logger import create_logger
27 | from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper
28 |
29 | # pytorch major version (1.x or 2.x)
30 | PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
31 |
32 |
33 | def parse_option():
34 |     parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)
35 |     parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
36 |     parser.add_argument(
37 |         "--opts",
38 |         help="Modify config options by adding 'KEY VALUE' pairs.",
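# e.g.: --opts TRAIN.EPOCHS 100 TRAIN.BASE_LR 1e-3  (illustrative values; any config key works)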
", 39 | default=None, 40 | nargs='+', 41 | ) 42 | 43 | # easy config modification 44 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 45 | parser.add_argument('--data-path', type=str, help='path to dataset') 46 | parser.add_argument('--resume', help='resume from checkpoint') 47 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 48 | parser.add_argument('--use-checkpoint', action='store_true', 49 | help="whether to use gradient checkpointing to save memory") 50 | parser.add_argument('--enable-amp', action='store_true') 51 | parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') 52 | parser.set_defaults(enable_amp=True) 53 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 54 | help='root of output folder, the full path is // (default: output)') 55 | parser.add_argument('--tag', help='tag of experiment') 56 | 57 | # distributed training 58 | # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead 59 | # (see https://pytorch.org/docs/stable/distributed.html#launch-utility) 60 | if PYTORCH_MAJOR_VERSION == 1: 61 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 62 | 63 | args = parser.parse_args() 64 | 65 | config = get_config(args) 66 | 67 | return args, config 68 | 69 | 70 | def main(config): 71 | data_loader_train = build_loader(config, simmim=True, is_pretrain=True) 72 | 73 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 74 | model = build_model(config, is_pretrain=True) 75 | model.cuda() 76 | logger.info(str(model)) 77 | 78 | optimizer = build_optimizer(config, model, simmim=True, is_pretrain=True) 79 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 80 | model_without_ddp = model.module 81 | 82 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 83 | logger.info(f"number of params: {n_parameters}") 84 | if hasattr(model_without_ddp, 'flops'): 85 | flops = model_without_ddp.flops() 86 | logger.info(f"number of GFLOPs: {flops / 1e9}") 87 | 88 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 89 | scaler = amp.GradScaler() 90 | 91 | if config.TRAIN.AUTO_RESUME: 92 | resume_file = auto_resume_helper(config.OUTPUT, logger) 93 | if resume_file: 94 | if config.MODEL.RESUME: 95 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 96 | config.defrost() 97 | config.MODEL.RESUME = resume_file 98 | config.freeze() 99 | logger.info(f'auto resuming from {resume_file}') 100 | else: 101 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 102 | 103 | if config.MODEL.RESUME: 104 | load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) 105 | 106 | logger.info("Start training") 107 | start_time = time.time() 108 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 109 | data_loader_train.sampler.set_epoch(epoch) 110 | 111 | train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler) 112 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 113 | save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger) 114 | 115 | total_time = time.time() - start_time 116 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 117 | logger.info('Training time 
118 |
119 |
120 | def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler):
121 |     model.train()
122 |     optimizer.zero_grad()
123 |
124 |     num_steps = len(data_loader)
125 |     batch_time = AverageMeter()
126 |     loss_meter = AverageMeter()
127 |     norm_meter = AverageMeter()
128 |     loss_scale_meter = AverageMeter()
129 |
130 |     start = time.time()
131 |     end = time.time()
132 |     for idx, (img, mask, _) in enumerate(data_loader):
133 |         img = img.cuda(non_blocking=True)
134 |         mask = mask.cuda(non_blocking=True)
135 |
136 |         with amp.autocast(enabled=config.ENABLE_AMP):
137 |             loss = model(img, mask)
138 |
139 |         if config.TRAIN.ACCUMULATION_STEPS > 1:
140 |             loss = loss / config.TRAIN.ACCUMULATION_STEPS
141 |             scaler.scale(loss).backward()
142 |             if config.TRAIN.CLIP_GRAD:
143 |                 scaler.unscale_(optimizer)
144 |                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
145 |             else:
146 |                 grad_norm = get_grad_norm(model.parameters())
147 |             if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
148 |                 scaler.step(optimizer)
149 |                 optimizer.zero_grad()
150 |                 scaler.update()
151 |                 lr_scheduler.step_update(epoch * num_steps + idx)
152 |         else:
153 |             optimizer.zero_grad()
154 |             scaler.scale(loss).backward()
155 |             if config.TRAIN.CLIP_GRAD:
156 |                 scaler.unscale_(optimizer)
157 |                 grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
158 |             else:
159 |                 grad_norm = get_grad_norm(model.parameters())
160 |             scaler.step(optimizer)
161 |             scaler.update()
162 |             lr_scheduler.step_update(epoch * num_steps + idx)
163 |
164 |         torch.cuda.synchronize()
165 |
166 |         loss_meter.update(loss.item(), img.size(0))
167 |         norm_meter.update(grad_norm)
168 |         loss_scale_meter.update(scaler.get_scale())
169 |         batch_time.update(time.time() - end)
170 |         end = time.time()
171 |
172 |         if idx % config.PRINT_FREQ == 0:
173 |             lr = optimizer.param_groups[0]['lr']
174 |             memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
175 |             etas = batch_time.avg * (num_steps - idx)
176 |             logger.info(
177 |                 f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
178 |                 f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
179 |                 f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
180 |                 f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
181 |                 f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
182 |                 f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
183 |                 f'mem {memory_used:.0f}MB')
184 |     epoch_time = time.time() - start
185 |     logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
186 |
187 |
188 | if __name__ == '__main__':
189 |     _, config = parse_option()
190 |
191 |     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
192 |         rank = int(os.environ["RANK"])
193 |         world_size = int(os.environ['WORLD_SIZE'])
194 |         print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
195 |     else:
196 |         rank = -1
197 |         world_size = -1
198 |     torch.cuda.set_device(config.LOCAL_RANK)
199 |     torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
200 |     torch.distributed.barrier()
201 |
202 |     seed = config.SEED + dist.get_rank()
203 |     torch.manual_seed(seed)
204 |     np.random.seed(seed)
205 |     cudnn.benchmark = True
206 |
207 |     # linearly scale the learning rate with the total batch size; may not be optimal
208 |     linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
209 |
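# worked example (illustrative numbers): BASE_LR 5e-4, batch 128 per GPU, 8 GPUs
# -> 5e-4 * (128 * 8) / 512 = 1e-3 effective base learning rate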
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
210 |     linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
211 |     # gradient accumulation also needs to scale the learning rate
212 |     if config.TRAIN.ACCUMULATION_STEPS > 1:
213 |         linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
214 |         linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
215 |         linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
216 |     config.defrost()
217 |     config.TRAIN.BASE_LR = linear_scaled_lr
218 |     config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
219 |     config.TRAIN.MIN_LR = linear_scaled_min_lr
220 |     config.freeze()
221 |
222 |     os.makedirs(config.OUTPUT, exist_ok=True)
223 |     logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")
224 |
225 |     if dist.get_rank() == 0:
226 |         path = os.path.join(config.OUTPUT, "config.json")
227 |         with open(path, "w") as f:
228 |             f.write(config.dump())
229 |         logger.info(f"Full config saved to {path}")
230 |
231 |     # print config
232 |     logger.info(config.dump())
233 |
234 |     main(config)
235 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .build import build_model
--------------------------------------------------------------------------------
/models/build.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Swin Transformer
3 | # Copyright (c) 2021 Microsoft
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ze Liu
6 | # --------------------------------------------------------
7 |
8 | from .swin_transformer import SwinTransformer
9 | from .swin_transformer_v2 import SwinTransformerV2
10 | from .swin_transformer_moe import SwinTransformerMoE
11 | from .swin_mlp import SwinMLP
12 | from .simmim import build_simmim
13 |
14 |
15 | def build_model(config, is_pretrain=False):
16 |     model_type = config.MODEL.TYPE
17 |
18 |     # accelerate layernorm
19 |     if config.FUSED_LAYERNORM:
20 |         try:
21 |             import apex as amp
22 |             layernorm = amp.normalization.FusedLayerNorm
23 |         except ImportError:
24 |             layernorm = None
25 |             print("To use FusedLayerNorm, please install apex.")
26 |     else:
27 |         import torch.nn as nn
28 |         layernorm = nn.LayerNorm
29 |
30 |     if is_pretrain:
31 |         model = build_simmim(config)
32 |         return model
33 |
34 |     if model_type == 'swin':
35 |         model = SwinTransformer(img_size=config.DATA.IMG_SIZE,
36 |                                 patch_size=config.MODEL.SWIN.PATCH_SIZE,
37 |                                 in_chans=config.MODEL.SWIN.IN_CHANS,
38 |                                 num_classes=config.MODEL.NUM_CLASSES,
39 |                                 embed_dim=config.MODEL.SWIN.EMBED_DIM,
40 |                                 depths=config.MODEL.SWIN.DEPTHS,
41 |                                 num_heads=config.MODEL.SWIN.NUM_HEADS,
42 |                                 window_size=config.MODEL.SWIN.WINDOW_SIZE,
43 |                                 mlp_ratio=config.MODEL.SWIN.MLP_RATIO,
44 |                                 qkv_bias=config.MODEL.SWIN.QKV_BIAS,
45 |                                 qk_scale=config.MODEL.SWIN.QK_SCALE,
46 |                                 drop_rate=config.MODEL.DROP_RATE,
47 |                                 drop_path_rate=config.MODEL.DROP_PATH_RATE,
48 |                                 ape=config.MODEL.SWIN.APE,
49 |                                 norm_layer=layernorm,
50 |                                 patch_norm=config.MODEL.SWIN.PATCH_NORM,
51 |                                 use_checkpoint=config.TRAIN.USE_CHECKPOINT,
52 |                                 fused_window_process=config.FUSED_WINDOW_PROCESS)
53 |     elif model_type == 'swinv2':
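# SwinV2 takes no qk_scale argument (it uses scaled-cosine attention) and adds
# pretrained_window_sizes for transferring the relative-position bias across window sizes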
54 |         model = SwinTransformerV2(img_size=config.DATA.IMG_SIZE,
55 |                                   patch_size=config.MODEL.SWINV2.PATCH_SIZE,
56 |                                   in_chans=config.MODEL.SWINV2.IN_CHANS,
57 |                                   num_classes=config.MODEL.NUM_CLASSES,
58 |                                   embed_dim=config.MODEL.SWINV2.EMBED_DIM,
59 |                                   depths=config.MODEL.SWINV2.DEPTHS,
60 |                                   num_heads=config.MODEL.SWINV2.NUM_HEADS,
61 |                                   window_size=config.MODEL.SWINV2.WINDOW_SIZE,
62 |                                   mlp_ratio=config.MODEL.SWINV2.MLP_RATIO,
63 |                                   qkv_bias=config.MODEL.SWINV2.QKV_BIAS,
64 |                                   drop_rate=config.MODEL.DROP_RATE,
65 |                                   drop_path_rate=config.MODEL.DROP_PATH_RATE,
66 |                                   ape=config.MODEL.SWINV2.APE,
67 |                                   patch_norm=config.MODEL.SWINV2.PATCH_NORM,
68 |                                   use_checkpoint=config.TRAIN.USE_CHECKPOINT,
69 |                                   pretrained_window_sizes=config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES)
70 |     elif model_type == 'swin_moe':
71 |         model = SwinTransformerMoE(img_size=config.DATA.IMG_SIZE,
72 |                                    patch_size=config.MODEL.SWIN_MOE.PATCH_SIZE,
73 |                                    in_chans=config.MODEL.SWIN_MOE.IN_CHANS,
74 |                                    num_classes=config.MODEL.NUM_CLASSES,
75 |                                    embed_dim=config.MODEL.SWIN_MOE.EMBED_DIM,
76 |                                    depths=config.MODEL.SWIN_MOE.DEPTHS,
77 |                                    num_heads=config.MODEL.SWIN_MOE.NUM_HEADS,
78 |                                    window_size=config.MODEL.SWIN_MOE.WINDOW_SIZE,
79 |                                    mlp_ratio=config.MODEL.SWIN_MOE.MLP_RATIO,
80 |                                    qkv_bias=config.MODEL.SWIN_MOE.QKV_BIAS,
81 |                                    qk_scale=config.MODEL.SWIN_MOE.QK_SCALE,
82 |                                    drop_rate=config.MODEL.DROP_RATE,
83 |                                    drop_path_rate=config.MODEL.DROP_PATH_RATE,
84 |                                    ape=config.MODEL.SWIN_MOE.APE,
85 |                                    patch_norm=config.MODEL.SWIN_MOE.PATCH_NORM,
86 |                                    mlp_fc2_bias=config.MODEL.SWIN_MOE.MLP_FC2_BIAS,
87 |                                    init_std=config.MODEL.SWIN_MOE.INIT_STD,
88 |                                    use_checkpoint=config.TRAIN.USE_CHECKPOINT,
89 |                                    pretrained_window_sizes=config.MODEL.SWIN_MOE.PRETRAINED_WINDOW_SIZES,
90 |                                    moe_blocks=config.MODEL.SWIN_MOE.MOE_BLOCKS,
91 |                                    num_local_experts=config.MODEL.SWIN_MOE.NUM_LOCAL_EXPERTS,
92 |                                    top_value=config.MODEL.SWIN_MOE.TOP_VALUE,
93 |                                    capacity_factor=config.MODEL.SWIN_MOE.CAPACITY_FACTOR,
94 |                                    cosine_router=config.MODEL.SWIN_MOE.COSINE_ROUTER,
95 |                                    normalize_gate=config.MODEL.SWIN_MOE.NORMALIZE_GATE,
96 |                                    use_bpr=config.MODEL.SWIN_MOE.USE_BPR,
97 |                                    is_gshard_loss=config.MODEL.SWIN_MOE.IS_GSHARD_LOSS,
98 |                                    gate_noise=config.MODEL.SWIN_MOE.GATE_NOISE,
99 |                                    cosine_router_dim=config.MODEL.SWIN_MOE.COSINE_ROUTER_DIM,
100 |                                    cosine_router_init_t=config.MODEL.SWIN_MOE.COSINE_ROUTER_INIT_T,
101 |                                    moe_drop=config.MODEL.SWIN_MOE.MOE_DROP,
102 |                                    aux_loss_weight=config.MODEL.SWIN_MOE.AUX_LOSS_WEIGHT)
103 |     elif model_type == 'swin_mlp':
104 |         model = SwinMLP(img_size=config.DATA.IMG_SIZE,
105 |                         patch_size=config.MODEL.SWIN_MLP.PATCH_SIZE,
106 |                         in_chans=config.MODEL.SWIN_MLP.IN_CHANS,
107 |                         num_classes=config.MODEL.NUM_CLASSES,
108 |                         embed_dim=config.MODEL.SWIN_MLP.EMBED_DIM,
109 |                         depths=config.MODEL.SWIN_MLP.DEPTHS,
110 |                         num_heads=config.MODEL.SWIN_MLP.NUM_HEADS,
111 |                         window_size=config.MODEL.SWIN_MLP.WINDOW_SIZE,
112 |                         mlp_ratio=config.MODEL.SWIN_MLP.MLP_RATIO,
113 |                         drop_rate=config.MODEL.DROP_RATE,
114 |                         drop_path_rate=config.MODEL.DROP_PATH_RATE,
115 |                         ape=config.MODEL.SWIN_MLP.APE,
116 |                         patch_norm=config.MODEL.SWIN_MLP.PATCH_NORM,
117 |                         use_checkpoint=config.TRAIN.USE_CHECKPOINT)
118 |     else:
119 |         raise NotImplementedError(f"Unknown model: {model_type}")
120 |
121 |     return model
122 |
--------------------------------------------------------------------------------
/models/simmim.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # --------------------------------------------------------
4 | # SimMIM
5 | # Copyright (c) 2021 Microsoft
6 | # Licensed under The MIT License [see LICENSE
for details] 7 | # Written by Zhenda Xie 8 | # -------------------------------------------------------- 9 | 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from timm.models.layers import trunc_normal_ 16 | 17 | from .swin_transformer import SwinTransformer 18 | from .swin_transformer_v2 import SwinTransformerV2 19 | 20 | 21 | def norm_targets(targets, patch_size): 22 | assert patch_size % 2 == 1 23 | 24 | targets_ = targets 25 | targets_count = torch.ones_like(targets) 26 | 27 | targets_square = targets ** 2. 28 | 29 | targets_mean = F.avg_pool2d(targets, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False) 30 | targets_square_mean = F.avg_pool2d(targets_square, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=False) 31 | targets_count = F.avg_pool2d(targets_count, kernel_size=patch_size, stride=1, padding=patch_size // 2, count_include_pad=True) * (patch_size ** 2) 32 | 33 | targets_var = (targets_square_mean - targets_mean ** 2.) * (targets_count / (targets_count - 1)) 34 | targets_var = torch.clamp(targets_var, min=0.) 35 | 36 | targets_ = (targets_ - targets_mean) / (targets_var + 1.e-6) ** 0.5 37 | 38 | return targets_ 39 | 40 | 41 | class SwinTransformerForSimMIM(SwinTransformer): 42 | def __init__(self, **kwargs): 43 | super().__init__(**kwargs) 44 | 45 | assert self.num_classes == 0 46 | 47 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 48 | trunc_normal_(self.mask_token, mean=0., std=.02) 49 | 50 | def forward(self, x, mask): 51 | x = self.patch_embed(x) 52 | 53 | assert mask is not None 54 | B, L, _ = x.shape 55 | 56 | mask_tokens = self.mask_token.expand(B, L, -1) 57 | w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) 58 | x = x * (1. - w) + mask_tokens * w 59 | 60 | if self.ape: 61 | x = x + self.absolute_pos_embed 62 | x = self.pos_drop(x) 63 | 64 | for layer in self.layers: 65 | x = layer(x) 66 | x = self.norm(x) 67 | 68 | x = x.transpose(1, 2) 69 | B, C, L = x.shape 70 | H = W = int(L ** 0.5) 71 | x = x.reshape(B, C, H, W) 72 | return x 73 | 74 | @torch.jit.ignore 75 | def no_weight_decay(self): 76 | return super().no_weight_decay() | {'mask_token'} 77 | 78 | 79 | class SwinTransformerV2ForSimMIM(SwinTransformerV2): 80 | def __init__(self, **kwargs): 81 | super().__init__(**kwargs) 82 | 83 | assert self.num_classes == 0 84 | 85 | self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 86 | trunc_normal_(self.mask_token, mean=0., std=.02) 87 | 88 | def forward(self, x, mask): 89 | x = self.patch_embed(x) 90 | 91 | assert mask is not None 92 | B, L, _ = x.shape 93 | 94 | mask_tokens = self.mask_token.expand(B, L, -1) 95 | w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) 96 | x = x * (1. 
- w) + mask_tokens * w 97 | 98 | if self.ape: 99 | x = x + self.absolute_pos_embed 100 | x = self.pos_drop(x) 101 | 102 | for layer in self.layers: 103 | x = layer(x) 104 | x = self.norm(x) 105 | 106 | x = x.transpose(1, 2) 107 | B, C, L = x.shape 108 | H = W = int(L ** 0.5) 109 | x = x.reshape(B, C, H, W) 110 | return x 111 | 112 | @torch.jit.ignore 113 | def no_weight_decay(self): 114 | return super().no_weight_decay() | {'mask_token'} 115 | 116 | 117 | class SimMIM(nn.Module): 118 | def __init__(self, config, encoder, encoder_stride, in_chans, patch_size): 119 | super().__init__() 120 | self.config = config 121 | self.encoder = encoder 122 | self.encoder_stride = encoder_stride 123 | 124 | self.decoder = nn.Sequential( 125 | nn.Conv2d( 126 | in_channels=self.encoder.num_features, 127 | out_channels=self.encoder_stride ** 2 * 3, kernel_size=1), 128 | nn.PixelShuffle(self.encoder_stride), 129 | ) 130 | 131 | self.in_chans = in_chans 132 | self.patch_size = patch_size 133 | 134 | def forward(self, x, mask): 135 | z = self.encoder(x, mask) 136 | x_rec = self.decoder(z) 137 | 138 | mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()  # upsample the patch-level mask to pixel resolution 139 | 140 | # normalize reconstruction targets with local patch statistics when enabled in the config 141 | if self.config.NORM_TARGET.ENABLE: 142 | x = norm_targets(x, self.config.NORM_TARGET.PATCH_SIZE) 143 | 144 | loss_recon = F.l1_loss(x, x_rec, reduction='none') 145 | loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans 146 | return loss 147 | 148 | @torch.jit.ignore 149 | def no_weight_decay(self): 150 | if hasattr(self.encoder, 'no_weight_decay'): 151 | return {'encoder.' + i for i in self.encoder.no_weight_decay()} 152 | return {} 153 | 154 | @torch.jit.ignore 155 | def no_weight_decay_keywords(self): 156 | if hasattr(self.encoder, 'no_weight_decay_keywords'): 157 | return {'encoder.'
+ i for i in self.encoder.no_weight_decay_keywords()} 158 | return {} 159 | 160 | 161 | def build_simmim(config): 162 | model_type = config.MODEL.TYPE 163 | if model_type == 'swin': 164 | encoder = SwinTransformerForSimMIM( 165 | img_size=config.DATA.IMG_SIZE, 166 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 167 | in_chans=config.MODEL.SWIN.IN_CHANS, 168 | num_classes=0, 169 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 170 | depths=config.MODEL.SWIN.DEPTHS, 171 | num_heads=config.MODEL.SWIN.NUM_HEADS, 172 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 173 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 174 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 175 | qk_scale=config.MODEL.SWIN.QK_SCALE, 176 | drop_rate=config.MODEL.DROP_RATE, 177 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 178 | ape=config.MODEL.SWIN.APE, 179 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 180 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 181 | encoder_stride = 32 182 | in_chans = config.MODEL.SWIN.IN_CHANS 183 | patch_size = config.MODEL.SWIN.PATCH_SIZE 184 | elif model_type == 'swinv2': 185 | encoder = SwinTransformerV2ForSimMIM( 186 | img_size=config.DATA.IMG_SIZE, 187 | patch_size=config.MODEL.SWINV2.PATCH_SIZE, 188 | in_chans=config.MODEL.SWINV2.IN_CHANS, 189 | num_classes=0, 190 | embed_dim=config.MODEL.SWINV2.EMBED_DIM, 191 | depths=config.MODEL.SWINV2.DEPTHS, 192 | num_heads=config.MODEL.SWINV2.NUM_HEADS, 193 | window_size=config.MODEL.SWINV2.WINDOW_SIZE, 194 | mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, 195 | qkv_bias=config.MODEL.SWINV2.QKV_BIAS, 196 | drop_rate=config.MODEL.DROP_RATE, 197 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 198 | ape=config.MODEL.SWINV2.APE, 199 | patch_norm=config.MODEL.SWINV2.PATCH_NORM, 200 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 201 | encoder_stride = 32 202 | in_chans = config.MODEL.SWINV2.IN_CHANS 203 | patch_size = config.MODEL.SWINV2.PATCH_SIZE 204 | else: 205 | raise NotImplementedError(f"Unknown pre-train model: {model_type}") 206 | 207 | model = SimMIM(config=config.MODEL.SIMMIM, encoder=encoder, encoder_stride=encoder_stride, in_chans=in_chans, patch_size=patch_size) 208 | 209 | return model -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | from functools import partial 9 | from torch import optim as optim 10 | 11 | try: 12 | from apex.optimizers import FusedAdam, FusedLAMB 13 | except ImportError: 14 | FusedAdam = None 15 | FusedLAMB = None 16 | print("To use FusedLAMB or FusedAdam, please install apex.") 17 | 18 | 19 | def build_optimizer(config, model, simmim=False, is_pretrain=False): 20 | """ 21 | Build optimizer, set weight decay of normalization to 0 by default.
22 | """ 23 | skip = {} 24 | skip_keywords = {} 25 | if hasattr(model, 'no_weight_decay'): 26 | skip = model.no_weight_decay() 27 | if hasattr(model, 'no_weight_decay_keywords'): 28 | skip_keywords = model.no_weight_decay_keywords() 29 | if simmim: 30 | if is_pretrain: 31 | parameters = get_pretrain_param_groups(model, skip, skip_keywords) 32 | else: 33 | depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' else config.MODEL.SWINV2.DEPTHS 34 | num_layers = sum(depths) 35 | get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) 36 | scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) 37 | parameters = get_finetune_param_groups(model, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, get_layer_func, scales, skip, skip_keywords) 38 | else: 39 | parameters = set_weight_decay(model, skip, skip_keywords) 40 | 41 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 42 | optimizer = None 43 | if opt_lower == 'sgd': 44 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 45 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 46 | elif opt_lower == 'adamw': 47 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 48 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 49 | elif opt_lower == 'fused_adam': 50 | optimizer = FusedAdam(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 51 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 52 | elif opt_lower == 'fused_lamb': 53 | optimizer = FusedLAMB(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 54 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 55 | 56 | return optimizer 57 | 58 | 59 | def set_weight_decay(model, skip_list=(), skip_keywords=()): 60 | has_decay = [] 61 | no_decay = [] 62 | 63 | for name, param in model.named_parameters(): 64 | if not param.requires_grad: 65 | continue # frozen weights 66 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 67 | check_keywords_in_name(name, skip_keywords): 68 | no_decay.append(param) 69 | # print(f"{name} has no weight decay") 70 | else: 71 | has_decay.append(param) 72 | return [{'params': has_decay}, 73 | {'params': no_decay, 'weight_decay': 0.}] 74 | 75 | 76 | def check_keywords_in_name(name, keywords=()): 77 | isin = False 78 | for keyword in keywords: 79 | if keyword in name: 80 | isin = True 81 | return isin 82 | 83 | 84 | def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): 85 | has_decay = [] 86 | no_decay = [] 87 | has_decay_name = [] 88 | no_decay_name = [] 89 | 90 | for name, param in model.named_parameters(): 91 | if not param.requires_grad: 92 | continue 93 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 94 | check_keywords_in_name(name, skip_keywords): 95 | no_decay.append(param) 96 | no_decay_name.append(name) 97 | else: 98 | has_decay.append(param) 99 | has_decay_name.append(name) 100 | return [{'params': has_decay}, 101 | {'params': no_decay, 'weight_decay': 0.}] 102 | 103 | 104 | def get_swin_layer(name, num_layers, depths): 105 | if name in ("mask_token"): 106 | return 0 107 | elif name.startswith("patch_embed"): 108 | return 0 109 | elif name.startswith("layers"): 110 | layer_id = int(name.split('.')[1]) 111 | block_id = name.split('.')[3] 112 | if block_id == 'reduction' or block_id == 'norm': 113 | return 
sum(depths[:layer_id + 1]) 114 | layer_id = sum(depths[:layer_id]) + int(block_id) 115 | return layer_id + 1 116 | else: 117 | return num_layers - 1 118 | 119 | 120 | def get_finetune_param_groups(model, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 121 | parameter_group_names = {} 122 | parameter_group_vars = {} 123 | 124 | for name, param in model.named_parameters(): 125 | if not param.requires_grad: 126 | continue 127 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 128 | check_keywords_in_name(name, skip_keywords): 129 | group_name = "no_decay" 130 | this_weight_decay = 0. 131 | else: 132 | group_name = "decay" 133 | this_weight_decay = weight_decay 134 | if get_layer_func is not None: 135 | layer_id = get_layer_func(name) 136 | group_name = "layer_%d_%s" % (layer_id, group_name) 137 | else: 138 | layer_id = None 139 | 140 | if group_name not in parameter_group_names: 141 | if scales is not None: 142 | scale = scales[layer_id] 143 | else: 144 | scale = 1. 145 | 146 | parameter_group_names[group_name] = { 147 | "group_name": group_name, 148 | "weight_decay": this_weight_decay, 149 | "params": [], 150 | "lr": lr * scale, 151 | "lr_scale": scale, 152 | } 153 | parameter_group_vars[group_name] = { 154 | "group_name": group_name, 155 | "weight_decay": this_weight_decay, 156 | "params": [], 157 | "lr": lr * scale, 158 | "lr_scale": scale 159 | } 160 | 161 | parameter_group_vars[group_name]["params"].append(param) 162 | parameter_group_names[group_name]["params"].append(name) 163 | return list(parameter_group_vars.values()) 164 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | 12 | try: 13 | from torch._six import inf 14 | except ImportError: 15 | from torch import inf 16 | 17 | 18 | def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger): 19 | logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................") 20 | if config.MODEL.RESUME.startswith('https'): 21 | checkpoint = torch.hub.load_state_dict_from_url( 22 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 23 | else: 24 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 25 | msg = model.load_state_dict(checkpoint['model'], strict=False) 26 | logger.info(msg) 27 | max_accuracy = 0.0 28 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 29 | optimizer.load_state_dict(checkpoint['optimizer']) 30 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 31 | config.defrost() 32 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 33 | config.freeze() 34 | if 'scaler' in checkpoint: 35 | loss_scaler.load_state_dict(checkpoint['scaler']) 36 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 37 | if 'max_accuracy' in checkpoint: 38 | max_accuracy = checkpoint['max_accuracy'] 39 | 40 | del checkpoint 41 | torch.cuda.empty_cache() 42 | return max_accuracy 43 | 44 | 45 | def load_pretrained(config, model, logger): 46 | logger.info(f"==============>
Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") 47 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') 48 | state_dict = checkpoint['model'] 49 | 50 | # delete relative_position_index since we always re-init it 51 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 52 | for k in relative_position_index_keys: 53 | del state_dict[k] 54 | 55 | # delete relative_coords_table since we always re-init it 56 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 57 | for k in relative_position_index_keys: 58 | del state_dict[k] 59 | 60 | # delete attn_mask since we always re-init it 61 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 62 | for k in attn_mask_keys: 63 | del state_dict[k] 64 | 65 | # bicubic interpolate relative_position_bias_table if not match 66 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 67 | for k in relative_position_bias_table_keys: 68 | relative_position_bias_table_pretrained = state_dict[k] 69 | relative_position_bias_table_current = model.state_dict()[k] 70 | L1, nH1 = relative_position_bias_table_pretrained.size() 71 | L2, nH2 = relative_position_bias_table_current.size() 72 | if nH1 != nH2: 73 | logger.warning(f"Error in loading {k}, passing......") 74 | else: 75 | if L1 != L2: 76 | # bicubic interpolate relative_position_bias_table if not match 77 | S1 = int(L1 ** 0.5) 78 | S2 = int(L2 ** 0.5) 79 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 80 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 81 | mode='bicubic') 82 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 83 | 84 | # bicubic interpolate absolute_pos_embed if not match 85 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] 86 | for k in absolute_pos_embed_keys: 87 | # dpe 88 | absolute_pos_embed_pretrained = state_dict[k] 89 | absolute_pos_embed_current = model.state_dict()[k] 90 | _, L1, C1 = absolute_pos_embed_pretrained.size() 91 | _, L2, C2 = absolute_pos_embed_current.size() 92 | if C1 != C2: 93 | logger.warning(f"Error in loading {k}, passing......") 94 | else: 95 | if L1 != L2: 96 | S1 = int(L1 ** 0.5) 97 | S2 = int(L2 ** 0.5) 98 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) 99 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) 100 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 101 | absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') 102 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 103 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 104 | state_dict[k] = absolute_pos_embed_pretrained_resized 105 | 106 | # check classifier, if not match, then re-init classifier to zero 107 | head_bias_pretrained = state_dict['head.bias'] 108 | Nc1 = head_bias_pretrained.shape[0] 109 | Nc2 = model.head.bias.shape[0] 110 | if (Nc1 != Nc2): 111 | if Nc1 == 21841 and Nc2 == 1000: 112 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 113 | map22kto1k_path = f'data/map22kto1k.txt' 114 | with open(map22kto1k_path) as f: 115 | map22kto1k = f.readlines() 116 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 117 |
state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 118 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 119 | else: 120 | torch.nn.init.constant_(model.head.bias, 0.) 121 | torch.nn.init.constant_(model.head.weight, 0.) 122 | del state_dict['head.weight'] 123 | del state_dict['head.bias'] 124 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 125 | 126 | msg = model.load_state_dict(state_dict, strict=False) 127 | logger.warning(msg) 128 | 129 | logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") 130 | 131 | del checkpoint 132 | torch.cuda.empty_cache() 133 | 134 | 135 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger): 136 | save_state = {'model': model.state_dict(), 137 | 'optimizer': optimizer.state_dict(), 138 | 'lr_scheduler': lr_scheduler.state_dict(), 139 | 'max_accuracy': max_accuracy, 140 | 'scaler': loss_scaler.state_dict(), 141 | 'epoch': epoch, 142 | 'config': config} 143 | 144 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 145 | logger.info(f"{save_path} saving......") 146 | torch.save(save_state, save_path) 147 | logger.info(f"{save_path} saved !!!") 148 | 149 | 150 | def get_grad_norm(parameters, norm_type=2): 151 | if isinstance(parameters, torch.Tensor): 152 | parameters = [parameters] 153 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 154 | norm_type = float(norm_type) 155 | total_norm = 0 156 | for p in parameters: 157 | param_norm = p.grad.data.norm(norm_type) 158 | total_norm += param_norm.item() ** norm_type 159 | total_norm = total_norm ** (1. / norm_type) 160 | return total_norm 161 | 162 | 163 | def auto_resume_helper(output_dir): 164 | checkpoints = os.listdir(output_dir) 165 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 166 | print(f"All checkpoints found in {output_dir}: {checkpoints}") 167 | if len(checkpoints) > 0: 168 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 169 | print(f"The latest checkpoint found: {latest_checkpoint}") 170 | resume_file = latest_checkpoint 171 | else: 172 | resume_file = None 173 | return resume_file 174 | 175 | 176 | def reduce_tensor(tensor): 177 | rt = tensor.clone() 178 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 179 | rt /= dist.get_world_size() 180 | return rt 181 | 182 | 183 | def ampscaler_get_grad_norm(parameters, norm_type: float = 2.0) -> torch.Tensor: 184 | if isinstance(parameters, torch.Tensor): 185 | parameters = [parameters] 186 | parameters = [p for p in parameters if p.grad is not None] 187 | norm_type = float(norm_type) 188 | if len(parameters) == 0: 189 | return torch.tensor(0.)
190 | device = parameters[0].grad.device 191 | if norm_type == inf: 192 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 193 | else: 194 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), 195 | norm_type).to(device) for p in parameters]), norm_type) 196 | return total_norm 197 | 198 | 199 | class NativeScalerWithGradNormCount: 200 | state_dict_key = "amp_scaler" 201 | 202 | def __init__(self): 203 | self._scaler = torch.cuda.amp.GradScaler() 204 | 205 | def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True): 206 | self._scaler.scale(loss).backward(create_graph=create_graph) 207 | if update_grad: 208 | if clip_grad is not None: 209 | assert parameters is not None 210 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 211 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 212 | else: 213 | self._scaler.unscale_(optimizer) 214 | norm = ampscaler_get_grad_norm(parameters) 215 | self._scaler.step(optimizer) 216 | self._scaler.update() 217 | else: 218 | norm = None 219 | return norm 220 | 221 | def state_dict(self): 222 | return self._scaler.state_dict() 223 | 224 | def load_state_dict(self, state_dict): 225 | self._scaler.load_state_dict(state_dict) 226 | -------------------------------------------------------------------------------- /utils_moe.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import torch 10 | import torch.distributed as dist 11 | 12 | 13 | def split_moe_model_state_dict(moe_keys, model_state_dict): 14 | moe_model_state_dict = {} 15 | non_moe_model_state_dict = {} 16 | for (k, v) in model_state_dict.items(): 17 | if k in moe_keys: 18 | moe_model_state_dict[k] = v 19 | else: 20 | non_moe_model_state_dict[k] = v 21 | return moe_model_state_dict, non_moe_model_state_dict 22 | 23 | 24 | def merge_moe_model_state_dict(moe_model_state_dict, non_moe_model_state_dict): 25 | model_state_dict = {} 26 | model_state_dict.update(moe_model_state_dict) 27 | model_state_dict.update(non_moe_model_state_dict) 28 | return model_state_dict 29 | 30 | 31 | def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger): 32 | global_rank = dist.get_rank() 33 | logger.info(f"==============> Rank[{global_rank}] Resuming from {config.MODEL.RESUME}....................") 34 | if config.MODEL.RESUME.endswith(f'.pth'): 35 | if config.TRAIN.MOE.SAVE_MASTER: 36 | resume_path = config.MODEL.RESUME + f'.global' 37 | else: 38 | resume_path = config.MODEL.RESUME + f'.rank{global_rank}' 39 | logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {resume_path}......") 40 | else: 41 | resume_path = config.MODEL.RESUME 42 | 43 | checkpoint = torch.load(resume_path, map_location='cpu') 44 | msg = model.load_state_dict(checkpoint['model'], strict=False) 45 | logger.info(msg) 46 | max_accuracy = 0.0 47 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 48 | optimizer.load_state_dict(checkpoint['optimizer']) 49 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 50 | config.defrost() 51 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
52 | config.freeze() 53 | if 'scaler' in checkpoint: 54 | loss_scaler.load_state_dict(checkpoint['scaler']) 55 | logger.info(f"=>Rank[{global_rank}] loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 56 | if 'max_accuracy' in checkpoint: 57 | max_accuracy = checkpoint['max_accuracy'] 58 | 59 | del checkpoint 60 | torch.cuda.empty_cache() 61 | return max_accuracy 62 | 63 | 64 | def load_pretrained(config, model, logger): 65 | global_rank = dist.get_rank() 66 | logger.info(f"==============> Rank[{global_rank}] Loading weight {config.MODEL.PRETRAINED} for fine-tuning......") 67 | if config.MODEL.PRETRAINED.endswith(f'.pth'): 68 | if config.TRAIN.MOE.SAVE_MASTER: 69 | pretrained_path = config.MODEL.PRETRAINED + f'.global' 70 | else: 71 | pretrained_path = config.MODEL.PRETRAINED + f'.rank{global_rank}' 72 | logger.info(f"===> Rank[{global_rank}] Re-formatting checkpoint name to {pretrained_path}......") 73 | else: 74 | pretrained_path = config.MODEL.PRETRAINED 75 | 76 | if pretrained_path.endswith(f'.rank{global_rank}'): 77 | checkpoint = torch.load(pretrained_path, map_location='cpu') 78 | if os.path.exists(pretrained_path.replace(f'.rank{global_rank}', f'.master')): 79 | checkpoint_master = torch.load(pretrained_path.replace(f'.rank{global_rank}', f'.master'), 80 | map_location='cpu') 81 | state_dict = merge_moe_model_state_dict(checkpoint['model'], checkpoint_master['model']) 82 | else: 83 | state_dict = checkpoint['model'] 84 | elif pretrained_path.endswith(f'.pth.global'): 85 | checkpoint = torch.load(pretrained_path, map_location='cpu') 86 | state_dict = checkpoint['model'] 87 | else: 88 | raise NotImplementedError(f"{config.MODEL.PRETRAINED} file error...") 89 | 90 | # delete relative_position_index since we always re-init it 91 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] 92 | for k in relative_position_index_keys: 93 | del state_dict[k] 94 | 95 | # delete relative_coords_table since we always re-init it 96 | relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] 97 | for k in relative_position_index_keys: 98 | del state_dict[k] 99 | 100 | # delete attn_mask since we always re-init it 101 | attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] 102 | for k in attn_mask_keys: 103 | del state_dict[k] 104 | 105 | # bicubic interpolate relative_position_bias_table if not match 106 | relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] 107 | for k in relative_position_bias_table_keys: 108 | relative_position_bias_table_pretrained = state_dict[k] 109 | relative_position_bias_table_current = model.state_dict()[k] 110 | L1, nH1 = relative_position_bias_table_pretrained.size() 111 | L2, nH2 = relative_position_bias_table_current.size() 112 | if nH1 != nH2: 113 | logger.warning(f"Error in loading {k}, passing......") 114 | else: 115 | if L1 != L2: 116 | # bicubic interpolate relative_position_bias_table if not match 117 | S1 = int(L1 ** 0.5) 118 | S2 = int(L2 ** 0.5) 119 | relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( 120 | relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), 121 | mode='bicubic') 122 | state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) 123 | 124 | # bicubic interpolate absolute_pos_embed if not match 125 | absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" 
in k] 126 | for k in absolute_pos_embed_keys: 127 | # dpe 128 | absolute_pos_embed_pretrained = state_dict[k] 129 | absolute_pos_embed_current = model.state_dict()[k] 130 | _, L1, C1 = absolute_pos_embed_pretrained.size() 131 | _, L2, C2 = absolute_pos_embed_current.size() 132 | if C1 != C2: 133 | logger.warning(f"Error in loading {k}, passing......") 134 | else: 135 | if L1 != L2: 136 | S1 = int(L1 ** 0.5) 137 | S2 = int(L2 ** 0.5) 138 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) 139 | absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) 140 | absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( 141 | absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') 142 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) 143 | absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) 144 | state_dict[k] = absolute_pos_embed_pretrained_resized 145 | 146 | # check classifier, if not match, then re-init classifier to zero 147 | head_bias_pretrained = state_dict['head.bias'] 148 | Nc1 = head_bias_pretrained.shape[0] 149 | Nc2 = model.head.bias.shape[0] 150 | if (Nc1 != Nc2): 151 | if Nc1 == 21841 and Nc2 == 1000: 152 | logger.info("loading ImageNet-22K weight to ImageNet-1K ......") 153 | map22kto1k_path = f'data/map22kto1k.txt' 154 | with open(map22kto1k_path) as f: 155 | map22kto1k = f.readlines() 156 | map22kto1k = [int(id22k.strip()) for id22k in map22kto1k] 157 | state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :] 158 | state_dict['head.bias'] = state_dict['head.bias'][map22kto1k] 159 | else: 160 | torch.nn.init.constant_(model.head.bias, 0.) 161 | torch.nn.init.constant_(model.head.weight, 0.)
162 | del state_dict['head.weight'] 163 | del state_dict['head.bias'] 164 | logger.warning(f"Error in loading classifier head, re-init classifier head to 0") 165 | 166 | msg = model.load_state_dict(state_dict, strict=False) 167 | logger.warning(msg) 168 | 169 | logger.info(f"=> loaded successfully '{config.MODEL.PRETRAINED}'") 170 | 171 | del checkpoint 172 | torch.cuda.empty_cache() 173 | 174 | 175 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, loss_scaler, logger, 176 | zero_redundancy=False): 177 | global_rank = dist.get_rank() 178 | 179 | if zero_redundancy: 180 | if config.TRAIN.MOE.SAVE_MASTER: 181 | save_state = {'model': model.state_dict()} 182 | if global_rank == 0: 183 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global') 184 | logger.info(f"{save_path} saving......") 185 | torch.save(save_state, save_path) 186 | logger.info(f"{save_path} saved !!!") 187 | else: 188 | moe_model_state_dict, non_moe_model_state_dict = \ 189 | split_moe_model_state_dict(model._ddp_params_and_buffers_to_ignore, model.state_dict()) 190 | save_state = {'model': moe_model_state_dict} 191 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}') 192 | logger.info(f"{save_path} saving......") 193 | torch.save(save_state, save_path) 194 | logger.info(f"{save_path} saved !!!") 195 | if global_rank == 0: 196 | save_state_master = {'model': non_moe_model_state_dict} 197 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.master') 198 | logger.info(f"{save_path} saving......") 199 | torch.save(save_state_master, save_path) 200 | logger.info(f"{save_path} saved !!!") 201 | else: 202 | save_state = {'model': model.state_dict(), 203 | 'optimizer': optimizer.state_dict(), 204 | 'lr_scheduler': lr_scheduler.state_dict(), 205 | 'max_accuracy': max_accuracy, 206 | 'scaler': loss_scaler.state_dict(), 207 | 'epoch': epoch, 208 | 'config': config} 209 | if config.TRAIN.MOE.SAVE_MASTER: 210 | if global_rank == 0: 211 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.global') 212 | logger.info(f"{save_path} saving......") 213 | torch.save(save_state, save_path) 214 | logger.info(f"{save_path} saved !!!") 215 | else: 216 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth.rank{global_rank}') 217 | logger.info(f"{save_path} saving......") 218 | torch.save(save_state, save_path) 219 | logger.info(f"{save_path} saved !!!") 220 | 221 | 222 | def auto_resume_helper(output_dir, save_master=False): 223 | global_rank = dist.get_rank() 224 | checkpoints = os.listdir(output_dir) 225 | if not save_master: 226 | master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith(f'pth.rank0')] 227 | else: 228 | master_checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith(f'pth.global')] 229 | print(f"All master checkpoints found in {output_dir}: {master_checkpoints}") 230 | if len(master_checkpoints) > 0: 231 | latest_master_checkpoint = max([os.path.join(output_dir, d) for d in master_checkpoints], key=os.path.getmtime) 232 | latest_checkpoint = latest_master_checkpoint.replace('pth.rank0', f'pth.rank{global_rank}') 233 | print(f"The latest checkpoint found: {latest_checkpoint}") 234 | resume_file = latest_checkpoint 235 | else: 236 | resume_file = None 237 | return resume_file 238 | 239 | 240 | def hook_scale_grad(scale, tensor): 241 | return tensor / scale 242 | -------------------------------------------------------------------------------- /utils_simmim.py:
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # SimMIM 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import torch 11 | import torch.distributed as dist 12 | import numpy as np 13 | from scipy import interpolate 14 | 15 | 16 | def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): 17 | logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") 18 | if config.MODEL.RESUME.startswith('https'): 19 | checkpoint = torch.hub.load_state_dict_from_url( 20 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 21 | else: 22 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 23 | 24 | # re-map keys due to name change (only for loading provided models) 25 | rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] 26 | for k in rpe_mlp_keys: 27 | checkpoint['model'][k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) 28 | 29 | msg = model.load_state_dict(checkpoint['model'], strict=False) 30 | logger.info(msg) 31 | 32 | max_accuracy = 0.0 33 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'scaler' in checkpoint and 'epoch' in checkpoint: 34 | optimizer.load_state_dict(checkpoint['optimizer']) 35 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 36 | scaler.load_state_dict(checkpoint['scaler']) 37 | 38 | config.defrost() 39 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 40 | config.freeze() 41 | 42 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 43 | if 'max_accuracy' in checkpoint: 44 | max_accuracy = checkpoint['max_accuracy'] 45 | else: 46 | max_accuracy = 0.0 47 | 48 | del checkpoint 49 | torch.cuda.empty_cache() 50 | return max_accuracy 51 | 52 | 53 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger): 54 | save_state = {'model': model.state_dict(), 55 | 'optimizer': optimizer.state_dict(), 56 | 'lr_scheduler': lr_scheduler.state_dict(), 57 | 'scaler': scaler.state_dict(), 58 | 'max_accuracy': max_accuracy, 59 | 'epoch': epoch, 60 | 'config': config} 61 | 62 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 63 | logger.info(f"{save_path} saving......") 64 | torch.save(save_state, save_path) 65 | logger.info(f"{save_path} saved !!!") 66 | 67 | 68 | def get_grad_norm(parameters, norm_type=2): 69 | if isinstance(parameters, torch.Tensor): 70 | parameters = [parameters] 71 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 72 | norm_type = float(norm_type) 73 | total_norm = 0 74 | for p in parameters: 75 | param_norm = p.grad.data.norm(norm_type) 76 | total_norm += param_norm.item() ** norm_type 77 | total_norm = total_norm ** (1. 
/ norm_type) 78 | return total_norm 79 | 80 | 81 | def auto_resume_helper(output_dir, logger): 82 | checkpoints = os.listdir(output_dir) 83 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 84 | logger.info(f"All checkpoints found in {output_dir}: {checkpoints}") 85 | if len(checkpoints) > 0: 86 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 87 | logger.info(f"The latest checkpoint found: {latest_checkpoint}") 88 | resume_file = latest_checkpoint 89 | else: 90 | resume_file = None 91 | return resume_file 92 | 93 | 94 | def reduce_tensor(tensor): 95 | rt = tensor.clone() 96 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 97 | rt /= dist.get_world_size() 98 | return rt 99 | 100 | 101 | def load_pretrained(config, model, logger): 102 | logger.info(f">>>>>>>>>> Fine-tuning from {config.MODEL.PRETRAINED} ..........") 103 | checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') 104 | checkpoint_model = checkpoint['model'] 105 | 106 | if any('encoder.' in k for k in checkpoint_model.keys()): 107 | checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')} 108 | logger.info('Detected pre-trained model, removing [encoder.] prefix.') 109 | else: 110 | logger.info('Detected non-pre-trained model, loading keys as-is.') 111 | 112 | if config.MODEL.TYPE in ['swin', 'swinv2']: 113 | logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") 114 | checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger) 115 | else: 116 | raise NotImplementedError 117 | 118 | msg = model.load_state_dict(checkpoint_model, strict=False) 119 | logger.info(msg) 120 | 121 | del checkpoint 122 | torch.cuda.empty_cache() 123 | logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") 124 | 125 | 126 | def remap_pretrained_keys_swin(model, checkpoint_model, logger): 127 | state_dict = model.state_dict() 128 | 129 | # geometric interpolation when the pre-trained window size does not match the fine-tuned one 130 | all_keys = list(checkpoint_model.keys()) 131 | for key in all_keys: 132 | if "relative_position_bias_table" in key: 133 | relative_position_bias_table_pretrained = checkpoint_model[key] 134 | relative_position_bias_table_current = state_dict[key] 135 | L1, nH1 = relative_position_bias_table_pretrained.size() 136 | L2, nH2 = relative_position_bias_table_current.size() 137 | if nH1 != nH2: 138 | logger.info(f"Error in loading {key}, passing......") 139 | else: 140 | if L1 != L2: 141 | logger.info(f"{key}: Interpolate relative_position_bias_table using geo.") 142 | src_size = int(L1 ** 0.5) 143 | dst_size = int(L2 ** 0.5) 144 | 145 | def geometric_progression(a, r, n): 146 | return a * (1.0 - r ** n) / (1.0 - r) 147 | 148 | left, right = 1.01, 1.5 149 | while right - left > 1e-6: 150 | q = (left + right) / 2.0 151 | gp = geometric_progression(1, q, src_size // 2) 152 | if gp > dst_size // 2: 153 | right = q 154 | else: 155 | left = q 156 | 157 | # if q > 1.090307: 158 | # q = 1.090307 159 | 160 | dis = [] 161 | cur = 1 162 | for i in range(src_size // 2): 163 | dis.append(cur) 164 | cur += q ** (i + 1) 165 | 166 | r_ids = [-_ for _ in reversed(dis)] 167 | 168 | x = r_ids + [0] + dis 169 | y = r_ids + [0] + dis 170 | 171 | t = dst_size // 2.0 172 | dx = np.arange(-t, t + 0.1, 1.0) 173 | dy = np.arange(-t, t + 0.1, 1.0) 174 | 175 | logger.info("Original positions = %s" % str(x)) 176 | logger.info("Target positions = %s"
% str(dx)) 177 | 178 | all_rel_pos_bias = [] 179 | 180 | for i in range(nH1): 181 | z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy() 182 | f_cubic = interpolate.interp2d(x, y, z, kind='cubic')  # note: interp2d is deprecated and removed in SciPy >= 1.14; pin an older SciPy or port to RectBivariateSpline 183 | all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to( 184 | relative_position_bias_table_pretrained.device)) 185 | 186 | new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 187 | checkpoint_model[key] = new_rel_pos_bias 188 | 189 | # delete relative_position_index since we always re-init it 190 | relative_position_index_keys = [k for k in checkpoint_model.keys() if "relative_position_index" in k] 191 | for k in relative_position_index_keys: 192 | del checkpoint_model[k] 193 | 194 | # delete relative_coords_table since we always re-init it 195 | relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k] 196 | for k in relative_coords_table_keys: 197 | del checkpoint_model[k] 198 | 199 | # re-map keys due to name change 200 | rpe_mlp_keys = [k for k in checkpoint_model.keys() if "rpe_mlp" in k] 201 | for k in rpe_mlp_keys: 202 | checkpoint_model[k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint_model.pop(k) 203 | 204 | # delete attn_mask since we always re-init it 205 | attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] 206 | for k in attn_mask_keys: 207 | del checkpoint_model[k] 208 | 209 | return checkpoint_model 210 | --------------------------------------------------------------------------------
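Usage sketch — SimMIM masked-reconstruction loss (illustrative, not a repository file):

The decoder, mask upsampling, and masked L1 loss below mirror SimMIM.forward from models/simmim.py. ToyEncoder is a hypothetical stand-in for SwinTransformerForSimMIM (it ignores the mask rather than blending in the learnable mask token); encoder_stride=32 matches the value hard-coded in build_simmim, and patch_size=4 is the typical config value.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyEncoder(nn.Module):
    """Hypothetical stand-in: maps (B, 3, H, W) to a (B, C, H/32, W/32) feature map."""
    def __init__(self, num_features=64):
        super().__init__()
        self.num_features = num_features
        self.proj = nn.Conv2d(3, num_features, kernel_size=32, stride=32)

    def forward(self, x, mask):
        return self.proj(x)  # the real encoders also consume `mask` via the mask token

encoder = ToyEncoder()
encoder_stride, patch_size, in_chans = 32, 4, 3

# Same decoder shape as SimMIM.__init__: 1x1 conv to stride**2 * 3 channels, then PixelShuffle.
decoder = nn.Sequential(
    nn.Conv2d(encoder.num_features, encoder_stride ** 2 * 3, kernel_size=1),
    nn.PixelShuffle(encoder_stride),
)

x = torch.randn(2, 3, 32, 32)
mask = torch.randint(0, 2, (2, 32 // patch_size, 32 // patch_size)).float()  # 1 = masked patch

z = encoder(x, mask)
x_rec = decoder(z)  # back to (2, 3, 32, 32)

# Upsample the patch-level mask to pixel resolution, as SimMIM.forward does.
m = mask.repeat_interleave(patch_size, 1).repeat_interleave(patch_size, 2).unsqueeze(1)

loss_recon = F.l1_loss(x, x_rec, reduction='none')
loss = (loss_recon * m).sum() / (m.sum() + 1e-5) / in_chans  # average over masked pixels only
print(loss.item())

With NORM_TARGET enabled, the real model first standardizes x via norm_targets (local mean and Bessel-corrected variance over an odd-sized window) before computing the same masked loss.
--------------------------------------------------------------------------------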
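Usage sketch — layer-wise learning-rate decay (illustrative, not a repository file):

For SimMIM fine-tuning, build_optimizer in optimizer.py gives each parameter group the scale LAYER_DECAY ** (num_layers + 1 - layer_id): get_swin_layer maps mask_token and the patch embedding to layer 0, each transformer block to 1..num_layers, and the remaining parameters (final norm, head) to num_layers + 1. The sketch below reproduces that scheme for an assumed Swin-T depth configuration; toy_layer_id is a simplified get_swin_layer that ignores the 'reduction'/'norm' sub-modules.

depths = [2, 2, 6, 2]      # Swin-T blocks per stage (assumed for illustration)
num_layers = sum(depths)   # 12
layer_decay = 0.9

# scales[i] = layer_decay ** (num_layers + 1 - i), matching the list built in build_optimizer
scales = [layer_decay ** i for i in reversed(range(num_layers + 2))]

def toy_layer_id(name):
    if name.startswith("patch_embed"):
        return 0
    if name.startswith("layers"):
        stage = int(name.split('.')[1])
        block = int(name.split('.')[3])
        return sum(depths[:stage]) + block + 1
    return num_layers + 1  # final norm, head, ...

for name in ["patch_embed.proj.weight",
             "layers.2.blocks.3.mlp.fc1.weight",
             "head.weight"]:
    lid = toy_layer_id(name)
    print(f"{name}: layer {lid}, lr scale {scales[lid]:.4f}")

Later layers get scales closer to 1.0, so fine-tuning moves the head the most and the patch embedding the least.
--------------------------------------------------------------------------------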
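Usage sketch — bicubic resize of relative_position_bias_table (illustrative, not a repository file):

load_pretrained in utils.py resizes the (L1, nH) bias table when the pre-trained and fine-tuned window sizes differ: reshape to a per-head 2D grid, bicubic-interpolate, flatten back. A minimal sketch for a hypothetical 7x7 -> 12x12 window change:

import torch
import torch.nn.functional as F

nH = 4                                   # number of attention heads (assumed)
S1, S2 = 2 * 7 - 1, 2 * 12 - 1           # table side length = 2 * window_size - 1
table = torch.randn(S1 * S1, nH)         # pretrained (L1, nH) bias table

resized = F.interpolate(
    table.permute(1, 0).view(1, nH, S1, S1),  # one (S1, S1) grid per head
    size=(S2, S2), mode='bicubic')
table_new = resized.view(nH, S2 * S2).permute(1, 0)  # back to (L2, nH)
print(table_new.shape)  # torch.Size([529, 4])

utils_simmim.py handles the same mismatch with a geometric cubic-spline scheme instead (remap_pretrained_keys_swin above).
--------------------------------------------------------------------------------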
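Usage sketch — NativeScalerWithGradNormCount in a training step (illustrative, not a repository file):

The loss scaler in utils.py wraps torch.cuda.amp.GradScaler and returns the gradient norm measured before clipping. A sketch of the assumed call pattern (requires a CUDA device):

import torch
from utils import NativeScalerWithGradNormCount  # repository module shown above

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_scaler = NativeScalerWithGradNormCount()

x = torch.randn(4, 10, device='cuda')
with torch.cuda.amp.autocast():
    loss = model(x).square().mean()

optimizer.zero_grad()
# Scales the loss, backprops, unscales, clips to 5.0, then steps and updates the scale.
grad_norm = loss_scaler(loss, optimizer, clip_grad=5.0, parameters=model.parameters())
print(f"grad norm before clipping: {grad_norm:.3f}")
--------------------------------------------------------------------------------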