├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── configs ├── finetune │ ├── fd_finetune__clip_vit_base__img224__100ep.yaml │ ├── fd_finetune__clip_vit_base__img224__300ep.yaml │ ├── fd_finetune__clip_vit_large__img224__300ep.yaml │ ├── fd_finetune__deit_vit_base__img224__300ep.yaml │ ├── fd_finetune__dino_vit_base__img224__300ep.yaml │ └── fd_finetune__esvit_swin_base__img224__300ep.yaml └── pretrain │ ├── fd_pretrain__clip_vit_base__img224__100ep.yaml │ ├── fd_pretrain__clip_vit_base__img224__300ep.yaml │ ├── fd_pretrain__clip_vit_large__img224__300ep.yaml │ ├── fd_pretrain__deit_vit_base__img224__300ep.yaml │ ├── fd_pretrain__dino_vit_base__img224__300ep.yaml │ └── fd_pretrain__esvit_swin_base__img224__300ep.yaml ├── data ├── __init__.py ├── cached_image_folder.py ├── data_fd.py ├── data_finetune.py ├── data_linear.py └── utils.py ├── figures └── teaser.jpg ├── logger.py ├── lr_scheduler.py ├── main_fd.py ├── main_finetune.py ├── main_linear.py ├── models ├── __init__.py ├── build.py ├── clip │ ├── __init__.py │ ├── clip.py │ ├── model.py │ ├── simple_tokenizer.py │ ├── utils.py │ └── vit.py ├── deit.py ├── dino.py ├── esvit.py ├── feature_distillation.py ├── swin_transformer.py ├── swin_transformer_v2.py ├── utils.py └── vision_transformer.py ├── optimizer.py ├── requirements.txt └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # dev files 132 | wandb/ 133 | output/ 134 | visualize/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Feature-Distillation 2 | 3 | By [Yixuan Wei](https://scholar.google.com/citations?user=xwudKb4AAAAJ&hl=en)\*, [Han Hu](https://ancientmooner.github.io/)\*, [Zhenda Xie](https://zdaxie.github.io), [Zheng Zhang](https://stupidzz.github.io/), [Yue Cao](http://yue-cao.me), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](http://www.dongchen.pro) and [Baining Guo](https://scholar.google.com/citations?user=h4kYmRYAAAAJ&hl=en&oi=ao). 4 | 5 | This repo is the official implementation of ["Contrastive Learning Rivals Masked Image Modeling in Fine-tuning via Feature Distillation"](https://arxiv.org/abs/2205.14141). 6 | 7 | ## Updates 8 | ***11/30/2022*** 9 | 10 | 1. Distilled and fine-tuned models on ImageNet-1K (`ViT Large`) are provided. 11 | 12 | ***11/28/2022*** 13 | 14 | Initial commits: 15 | 16 | 1. Distilled and fine-tuned models on ImageNet-1K (`Swin Base`, and `ViT Base`) are provided. 17 | 2. The supported code for ImageNet-1K distillation and fine-tuning is provided. 18 | 19 | ## Introduction 20 | 21 | **FD** is initially described in [arxiv](https://arxiv.org/abs/2205.14141), which is a simple framework to convert the traditional pre-training models, such as image classification (DeiT), instance contrastive learning (DINO) and image-text alignment (CLIP) into new models with better fine-tuning performances. 
Through a set of diagnosing tools, we find that the models distilled with feature maps are endowed with the following good properties, which are also observed in masked image modeling models: 1) more diverse attention heads; 2) more diagonal attention patterns; 3) flatter loss landscapes. 22 | 23 |
24 | 25 |
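The gist of the objective described above is: freeze the teacher, whiten its feature map, and train the student to regress it. The snippet below is only a minimal, illustrative sketch of that idea, not the repository's exact recipe — the class name, the linear projection head, the non-affine LayerNorm whitening, and the smooth-L1 loss are assumptions made for illustration; the actual implementation lives in `models/feature_distillation.py` and the configs under `configs/pretrain`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureDistillationSketch(nn.Module):
    """Illustrative only: match student tokens to whitened, frozen teacher tokens."""

    def __init__(self, student_dim: int, teacher_dim: int):
        super().__init__()
        # project student features to the teacher's channel width (assumed head)
        self.proj = nn.Linear(student_dim, teacher_dim)
        # whiten the teacher features: LayerNorm without learnable affine parameters
        self.whiten = nn.LayerNorm(teacher_dim, elementwise_affine=False)

    def forward(self, student_tokens: torch.Tensor, teacher_tokens: torch.Tensor) -> torch.Tensor:
        # student_tokens: (B, N, C_s); teacher_tokens: (B, N, C_t); teacher is frozen
        target = self.whiten(teacher_tokens.detach())
        pred = self.proj(student_tokens)
        return F.smooth_l1_loss(pred, target)


# usage sketch: `student` and `teacher` stand in for the two backbones
# loss_fn = FeatureDistillationSketch(student_dim=768, teacher_dim=768)
# loss = loss_fn(student(images), teacher(images))
```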
26 | 27 | ## Main Results on ImageNet 28 | 29 | ### Swin Transformer 30 | 31 | **ImageNet-1K Distilled and Fine-tuned Models** 32 | 33 | | name | distillation epochs | teacher model | image resolution | acc@1 | distilled model | fine-tuned model | 34 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 35 | | Swin-Base | 300 | [EsViT-Base](https://github.com/microsoft/esvit) | 224x224 | 85.1 | [google](https://drive.google.com/file/d/11_GQUHgcrUO8PMzl73eJmLSa7f3c5dZY/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__esvit_swin_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1criliGcjpEJxqlsYRGBERBAMYrFYFW--/view?usp=sharing)/[config](configs/finetune/fd_finetune__esvit_swin_base__img224__300ep.yaml) | 36 | 37 | ### Vision Transformer 38 | 39 | **ImageNet-1K Distilled and Fine-tuned Models** 40 | 41 | | name | distillation epochs | teacher model | image resolution | acc@1 | distilled model | fine-tuned model | 42 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 43 | | ViT-Base | 300 | [CLIP-Base](https://github.com/openai/CLIP) | 224x224 | 84.9 | [google](https://drive.google.com/file/d/1XFOZ6rJkv5X08Bu5d04_Xy3iJOj6SLc7/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1mP_JESmcdFeIkpB4aYyFzALtkydy_9qN/view?usp=sharing)/[config](configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml) | 44 | | ViT-Base | 300 | [DINO-Base](https://github.com/facebookresearch/dino) | 224x224 | 83.8 | [google](https://drive.google.com/file/d/1fwBINMxpv5zFOI7Ye6l9msI8GzocpA3z/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__dino_vit_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1Mn_GgepfZXOe7W0UqEQMFo5MjJpMwM_i/view?usp=sharing)/[config](configs/finetune/fd_finetune__dino_vit_base__img224__300ep.yaml) | 45 | | ViT-Base | 300 | [DeiT-Base](https://github.com/facebookresearch/deit) | 224x224 | 83.0 | [google](https://drive.google.com/file/d/1yPezioDc4O6hdfD6VSAIU9DvJiXG4ZSJ/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__deit_vit_base__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1pb0KUlVcCaEGT-xnx6ookrqcC-88Ori5/view?usp=sharing)/[config](configs/finetune/fd_finetune__deit_vit_base__img224__300ep.yaml) | 46 | | ViT-Large | 300 | [CLIP-Large](https://github.com/openai/CLIP) | 224x224 | 87.7 | [google](https://drive.google.com/file/d/1H5USyzqwoS31JHDX874q8a70LdVD9zNY/view?usp=sharing)/[config](configs/pretrain/fd_pretrain__clip_vit_large__img224__300ep.yaml) | [google](https://drive.google.com/file/d/1XDDbDl9jzt8H2Fy6iZNfNA7Yjepf_MGx/view?usp=sharing)/[config](configs/finetune/fd_finetune__clip_vit_large__img224__300ep.yaml) | 47 | 48 | ## Citation 49 | 50 | If you find our work useful in your research, please cite: 51 | 52 | ``` 53 | @article{wei2022FD, 54 | title={Contrastive Learning Rivals Masked Image Modeling in Fine-tuning via Feature Distillation}, 55 | author={Yixuan Wei and Han Hu and Zhenda Xie and Zheng Zhang and Yue Cao and Jianmin Bao and Dong Chen and Baining Guo}, 56 | journal={Tech Report}, 57 | year={2022} 58 | } 59 | ``` 60 | 61 | ## Getting Started 62 | 63 | ### Installation 64 | 65 | - Install `CUDA 11.3` with `cuDNN 8` following the official installation guide of [CUDA](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) and [cuDNN](https://developer.nvidia.com/rdp/cudnn-archive). 
66 | 67 | - Setup conda environment: 68 | ```bash 69 | # Create environment 70 | conda create -n FD python=3.8 -y 71 | conda activate FD 72 | 73 | # Install requirements 74 | pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113 75 | 76 | # Clone codes 77 | git clone https://github.com/SwinTransformer/Feature-Distillation 78 | cd Feature-Distillation 79 | 80 | # Install other requirements 81 | pip install -r requirements.txt 82 | ``` 83 | 84 | ### Feature-Distillation 85 | To distill models, run: 86 | ```bash 87 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_fd.py \ 88 | --cfg <config-file> --data-path <imagenet-path>/train [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 89 | ``` 90 | 91 | For example, to distill `CLIP-Base` for 300 epochs on one DGX-2 server, run: 92 | ```bash 93 | python -m torch.distributed.launch --nproc_per_node=16 main_fd.py --cfg configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml --batch-size 128 --data-path <imagenet-path>/train [--output <output-directory> --tag <job-tag>] 94 | ``` 95 | 96 | To reduce GPU memory consumption, add `--use-checkpoint`. 97 | 98 | ### Fine-tuning distilled models 99 | To fine-tune distilled models, run: 100 | ```bash 101 | python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> main_finetune.py \ 102 | --cfg <config-file> --data-path <imagenet-path> --pretrained <distilled-checkpoint> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>] 103 | ``` 104 | 105 | For example, to fine-tune `Distilled-CLIP-Base` on one DGX-2 server, run: 106 | ```bash 107 | python -m torch.distributed.launch --nproc_per_node 16 main_finetune.py \ 108 | --cfg configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml --batch-size 128 --data-path <imagenet-path> --pretrained <distilled-checkpoint> [--output <output-directory> --tag <job-tag>] 109 | ``` -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Ze Liu 5 | # Modified by Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import yaml 11 | from yacs.config import CfgNode as CN 12 | 13 | _C = CN() 14 | 15 | # Base config files 16 | _C.BASE = [''] 17 | 18 | # ----------------------------------------------------------------------------- 19 | # Dev settings 20 | # ----------------------------------------------------------------------------- 21 | _C.DEV = CN() 22 | # Relative Coords Table Type 23 | _C.DEV.RCT_TYPE = 'norm8_log' 24 | _C.DEV.CHECKPOINT_BLOCKS = [0,0,0,0] 25 | 26 | # [Feature Distillation] 27 | _C.DEV.PRED_FEAT = '' 28 | _C.DEV.PRED_FEAT_AFTERNORM = False # whether to use feature after norm 29 | _C.DEV.PRED_FEAT_S3 = False # when using swin as teacher and vit as student, use the stage-3 feature as target so the token number matches 14*14 30 | _C.DEV.VIT_WITHKBIAS = False 31 | _C.DEV.FT_SKIP_REMAP = False 32 | 33 | # ----------------------------------------------------------------------------- 34 | # Data settings 35 | # ----------------------------------------------------------------------------- 36 | _C.DATA = CN() 37 | # Batch size for a single GPU, could be overwritten by command line argument 38 | _C.DATA.BATCH_SIZE = 128 39 | # Path to dataset, could be overwritten by command line argument 40 | _C.DATA.DATA_PATH = '' 41 | # Dataset name 42 | _C.DATA.DATASET = 'imagenet' 43 | # Input image size 44 | _C.DATA.IMG_SIZE = 224 45 | # Interpolation to resize 
image (random, bilinear, bicubic) 46 | _C.DATA.INTERPOLATION = 'bicubic' 47 | # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. 48 | _C.DATA.PIN_MEMORY = True 49 | # Number of data loading threads 50 | _C.DATA.NUM_WORKERS = 8 51 | # Zip Mode as in Swin Transformer 52 | _C.DATA.ZIP_MODE = False 53 | 54 | # ----------------------------------------------------------------------------- 55 | # Model settings 56 | # ----------------------------------------------------------------------------- 57 | _C.MODEL = CN() 58 | # Model type 59 | _C.MODEL.TYPE = 'swin' 60 | # Model name 61 | _C.MODEL.NAME = 'swin_tiny_patch4_window7_224' 62 | # Checkpoint to resume, could be overwritten by command line argument 63 | _C.MODEL.RESUME = '' 64 | # Number of classes, overwritten in data preparation 65 | _C.MODEL.NUM_CLASSES = 1000 66 | # Dropout rate 67 | _C.MODEL.DROP_RATE = 0.0 68 | # Drop path rate 69 | _C.MODEL.DROP_PATH_RATE = 0.1 70 | # Label Smoothing 71 | _C.MODEL.LABEL_SMOOTHING = 0.1 72 | 73 | # Swin Transformer parameters 74 | _C.MODEL.SWIN = CN() 75 | _C.MODEL.SWIN.PATCH_SIZE = 4 76 | _C.MODEL.SWIN.IN_CHANS = 3 77 | _C.MODEL.SWIN.EMBED_DIM = 96 78 | _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2] 79 | _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24] 80 | _C.MODEL.SWIN.WINDOW_SIZE = 7 81 | _C.MODEL.SWIN.MLP_RATIO = 4. 82 | _C.MODEL.SWIN.QKV_BIAS = True 83 | _C.MODEL.SWIN.QK_SCALE = None 84 | _C.MODEL.SWIN.USE_SHARED_RPB = False 85 | _C.MODEL.SWIN.APE = False 86 | _C.MODEL.SWIN.PATCH_NORM = True 87 | 88 | # Vision Transformer parameters 89 | _C.MODEL.VIT = CN() 90 | _C.MODEL.VIT.PATCH_SIZE = 16 91 | _C.MODEL.VIT.IN_CHANS = 3 92 | _C.MODEL.VIT.EMBED_DIM = 768 93 | _C.MODEL.VIT.DEPTH = 12 94 | _C.MODEL.VIT.NUM_HEADS = 12 95 | _C.MODEL.VIT.MLP_RATIO = 4 96 | _C.MODEL.VIT.QKV_BIAS = True 97 | _C.MODEL.VIT.INIT_VALUES = 0.1 98 | _C.MODEL.VIT.USE_APE = False 99 | _C.MODEL.VIT.USE_RPB = False 100 | _C.MODEL.VIT.USE_SHARED_RPB = True 101 | _C.MODEL.VIT.USE_MEAN_POOLING = False 102 | _C.MODEL.VIT.ATTN_TYPE = 'normal' 103 | _C.MODEL.VIT.WITH_CLS_TOKEN = True 104 | 105 | 106 | # ----------------------------------------------------------------------------- 107 | # Training settings 108 | # ----------------------------------------------------------------------------- 109 | _C.TRAIN = CN() 110 | _C.TRAIN.START_EPOCH = 0 111 | _C.TRAIN.EPOCHS = 300 112 | _C.TRAIN.WARMUP_EPOCHS = 20 113 | _C.TRAIN.WARMUP_EPOCHS_FINE = 0.0 # incase of less than 1ep warmup 114 | _C.TRAIN.WEIGHT_DECAY = 0.05 115 | _C.TRAIN.BASE_LR = 5e-4 116 | _C.TRAIN.WARMUP_LR = 5e-7 117 | _C.TRAIN.MIN_LR = 5e-6 118 | # Clip gradient norm 119 | _C.TRAIN.CLIP_GRAD = 5.0 120 | # Auto resume from latest checkpoint 121 | _C.TRAIN.AUTO_RESUME = True 122 | # Gradient accumulation steps 123 | # could be overwritten by command line argument 124 | _C.TRAIN.ACCUMULATION_STEPS = 0 125 | # Whether to use gradient checkpointing to save memory 126 | # could be overwritten by command line argument 127 | _C.TRAIN.USE_CHECKPOINT = False 128 | 129 | # LR scheduler 130 | _C.TRAIN.LR_SCHEDULER = CN() 131 | _C.TRAIN.LR_SCHEDULER.NAME = 'cosine' 132 | # Epoch interval to decay LR, used in StepLRScheduler 133 | _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30 134 | # LR decay rate, used in StepLRScheduler 135 | _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1 136 | # Gamma / Multi steps value, used in MultiStepLRScheduler 137 | _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 138 | _C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] 139 | 140 | # Optimizer 141 | _C.TRAIN.OPTIMIZER = CN() 142 | 
_C.TRAIN.OPTIMIZER.NAME = 'adamw' 143 | # Optimizer Epsilon 144 | _C.TRAIN.OPTIMIZER.EPS = 1e-8 145 | # Optimizer Betas 146 | _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999) 147 | # SGD momentum 148 | _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9 149 | 150 | # Layer decay for fine-tuning 151 | _C.TRAIN.LAYER_DECAY = 1.0 152 | 153 | # ----------------------------------------------------------------------------- 154 | # Augmentation settings 155 | # ----------------------------------------------------------------------------- 156 | _C.AUG = CN() 157 | # Color jitter factor 158 | _C.AUG.COLOR_JITTER = 0.4 159 | # Use AutoAugment policy. "v0" or "original" 160 | _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' 161 | # Random erase prob 162 | _C.AUG.REPROB = 0.25 163 | # Random erase mode 164 | _C.AUG.REMODE = 'pixel' 165 | # Random erase count 166 | _C.AUG.RECOUNT = 1 167 | # Mixup alpha, mixup enabled if > 0 168 | _C.AUG.MIXUP = 0.8 169 | # Cutmix alpha, cutmix enabled if > 0 170 | _C.AUG.CUTMIX = 1.0 171 | # Cutmix min/max ratio, overrides alpha and enables cutmix if set 172 | _C.AUG.CUTMIX_MINMAX = None 173 | # Probability of performing mixup or cutmix when either/both is enabled 174 | _C.AUG.MIXUP_PROB = 1.0 175 | # Probability of switching to cutmix when both mixup and cutmix enabled 176 | _C.AUG.MIXUP_SWITCH_PROB = 0.5 177 | # How to apply mixup/cutmix params. Per "batch", "pair", or "elem" 178 | _C.AUG.MIXUP_MODE = 'batch' 179 | _C.AUG.MAX_SCALE = 1.0 180 | _C.AUG.MIN_SCALE = 0.67 181 | 182 | # ----------------------------------------------------------------------------- 183 | # Testing settings 184 | # ----------------------------------------------------------------------------- 185 | _C.TEST = CN() 186 | # Whether to use center crop when testing 187 | _C.TEST.CROP = True 188 | 189 | # ----------------------------------------------------------------------------- 190 | # Misc 191 | # ----------------------------------------------------------------------------- 192 | # Whether to enable pytorch amp, overwritten by command line argument 193 | _C.ENABLE_AMP = False 194 | # Path to output folder, overwritten by command line argument 195 | _C.OUTPUT = '' 196 | # Tag of experiment, overwritten by command line argument 197 | _C.TAG = 'default' 198 | # Frequency to save checkpoint 199 | _C.SAVE_FREQ = 1 200 | # Frequency to logging info 201 | _C.PRINT_FREQ = 10 202 | # Fixed random seed 203 | _C.SEED = 0 204 | # Perform evaluation only, overwritten by command line argument 205 | _C.EVAL_MODE = False 206 | # Test throughput only, overwritten by command line argument 207 | _C.THROUGHPUT_MODE = False 208 | # local rank for DistributedDataParallel, given by command line argument 209 | _C.LOCAL_RANK = 0 210 | 211 | # path to pre-trained model 212 | _C.PRETRAINED = '' 213 | 214 | 215 | def _update_config_from_file(config, cfg_file): 216 | config.defrost() 217 | with open(cfg_file, 'r') as f: 218 | yaml_cfg = yaml.load(f, Loader=yaml.FullLoader) 219 | 220 | for cfg in yaml_cfg.setdefault('BASE', ['']): 221 | if cfg: 222 | _update_config_from_file( 223 | config, os.path.join(os.path.dirname(cfg_file), cfg) 224 | ) 225 | print('=> merge config from {}'.format(cfg_file)) 226 | config.merge_from_file(cfg_file) 227 | config.freeze() 228 | 229 | 230 | def update_config(config, args): 231 | _update_config_from_file(config, args.cfg) 232 | 233 | config.defrost() 234 | if args.opts: 235 | config.merge_from_list(args.opts) 236 | 237 | def _check_args(name): 238 | if hasattr(args, name) and eval(f'args.{name}'): 239 | return True 240 | 
return False 241 | 242 | # merge from specific arguments 243 | if _check_args('batch_size'): 244 | config.DATA.BATCH_SIZE = args.batch_size 245 | if _check_args('data_path'): 246 | config.DATA.DATA_PATH = args.data_path 247 | if _check_args('resume'): 248 | config.MODEL.RESUME = args.resume 249 | if _check_args('pretrained'): 250 | config.PRETRAINED = args.pretrained 251 | if _check_args('accumulation_steps'): 252 | config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps 253 | if _check_args('use_checkpoint'): 254 | config.TRAIN.USE_CHECKPOINT = True 255 | if _check_args('enable_amp'): 256 | config.ENABLE_AMP = args.enable_amp 257 | if _check_args('output'): 258 | config.OUTPUT = args.output 259 | if _check_args('tag'): 260 | config.TAG = args.tag 261 | if _check_args('eval'): 262 | config.EVAL_MODE = True 263 | if _check_args('throughput'): 264 | config.THROUGHPUT_MODE = True 265 | 266 | # set local rank for distributed training 267 | config.LOCAL_RANK = args.local_rank 268 | 269 | # output folder 270 | config.OUTPUT = os.path.join(config.OUTPUT, config.TAG) 271 | 272 | config.freeze() 273 | 274 | 275 | def get_config(args): 276 | """Get a yacs CfgNode object with default values.""" 277 | # Return a clone so that the defaults will not be altered 278 | # This is for the "local variable" use pattern 279 | config = _C.clone() 280 | update_config(config, args) 281 | 282 | return config 283 | -------------------------------------------------------------------------------- /configs/finetune/fd_finetune__clip_vit_base__img224__100ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_finetune 4 | DROP_PATH_RATE: 0.1 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | TRAIN: 18 | EPOCHS: 100 19 | WARMUP_EPOCHS: 20 20 | BASE_LR: 1.25e-3 21 | WARMUP_LR: 2.5e-7 22 | MIN_LR: 2.5e-7 23 | WEIGHT_DECAY: 0.05 24 | LAYER_DECAY: 0.65 25 | PRINT_FREQ: 100 26 | SAVE_FREQ: 5 27 | TAG: fd_finetune__clip_vit_base__img224__100ep -------------------------------------------------------------------------------- /configs/finetune/fd_finetune__clip_vit_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_finetune 4 | DROP_PATH_RATE: 0.3 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | TRAIN: 18 | EPOCHS: 100 19 | WARMUP_EPOCHS: 20 20 | BASE_LR: 1.25e-3 21 | WARMUP_LR: 2.5e-7 22 | MIN_LR: 2.5e-7 23 | WEIGHT_DECAY: 0.05 24 | LAYER_DECAY: 0.6 25 | PRINT_FREQ: 100 26 | SAVE_FREQ: 5 27 | TAG: fd_finetune__clip_vit_base__img224__300ep -------------------------------------------------------------------------------- /configs/finetune/fd_finetune__clip_vit_large__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_finetune 4 | DROP_PATH_RATE: 0.4 5 | VIT: 6 | EMBED_DIM: 1024 7 | DEPTH: 24 8 | NUM_HEADS: 16 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | WITH_CLS_TOKEN: True 14 | PATCH_SIZE: 14 15 | DATA: 16 | IMG_SIZE: 224 17 | BATCH_SIZE: 128 18 | TRAIN: 19 | 
EPOCHS: 50 20 | WARMUP_EPOCHS: 5 21 | BASE_LR: 2.5e-4 22 | WARMUP_LR: 5.0e-7 23 | MIN_LR: 5.0e-7 24 | WEIGHT_DECAY: 0.05 25 | LAYER_DECAY: 0.75 26 | PRINT_FREQ: 100 27 | SAVE_FREQ: 5 28 | TAG: fd_finetune__clip_vit_large__img224__300ep -------------------------------------------------------------------------------- /configs/finetune/fd_finetune__deit_vit_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_finetune 4 | DROP_PATH_RATE: 0.3 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | TRAIN: 18 | EPOCHS: 100 19 | WARMUP_EPOCHS: 20 20 | BASE_LR: 1.25e-3 21 | WARMUP_LR: 2.5e-7 22 | MIN_LR: 2.5e-7 23 | WEIGHT_DECAY: 0.05 24 | LAYER_DECAY: 0.65 25 | PRINT_FREQ: 100 26 | SAVE_FREQ: 5 27 | TAG: fd_finetune__deit_vit_base__img224__300ep -------------------------------------------------------------------------------- /configs/finetune/fd_finetune__dino_vit_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_finetune 4 | DROP_PATH_RATE: 0.2 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: True 11 | USE_SHARED_RPB: False 12 | USE_MEAN_POOLING: True 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | TRAIN: 18 | EPOCHS: 100 19 | WARMUP_EPOCHS: 20 20 | BASE_LR: 1.5e-3 21 | WARMUP_LR: 2.5e-7 22 | MIN_LR: 2.5e-7 23 | WEIGHT_DECAY: 0.05 24 | LAYER_DECAY: 0.6 25 | PRINT_FREQ: 100 26 | SAVE_FREQ: 5 27 | TAG: fd_finetune__dino_vit_base__img224__300ep -------------------------------------------------------------------------------- /configs/finetune/fd_finetune__esvit_swin_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin_v2 3 | NAME: fd_finetune 4 | DROP_PATH_RATE: 0.4 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 14 10 | DATA: 11 | IMG_SIZE: 224 12 | BATCH_SIZE: 64 # here is just a random mistake in our experiments. 
we believe 128*16=2048 will lead to similar results 13 | TRAIN: 14 | EPOCHS: 100 15 | WARMUP_EPOCHS: 20 16 | BASE_LR: 1.25e-3 17 | WARMUP_LR: 2.5e-7 18 | MIN_LR: 2.5e-7 19 | WEIGHT_DECAY: 0.05 20 | LAYER_DECAY: 0.8 21 | PRINT_FREQ: 100 22 | SAVE_FREQ: 5 23 | TAG: fd_finetune__esvit_swin_base__img224__300ep -------------------------------------------------------------------------------- /configs/pretrain/fd_pretrain__clip_vit_base__img224__100ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_pretrain 4 | DROP_PATH_RATE: 0.1 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: False 11 | USE_SHARED_RPB: True 12 | USE_MEAN_POOLING: False 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | TRAIN: 18 | EPOCHS: 100 19 | WARMUP_EPOCHS: 10 20 | BASE_LR: 3e-4 21 | WARMUP_LR: 5e-7 22 | MIN_LR: 5e-6 23 | WEIGHT_DECAY: 0.05 24 | CLIP_GRAD: 3.0 25 | DEV: 26 | PRED_FEAT: CLIP_400M 27 | PRED_FEAT_AFTERNORM: True 28 | PRINT_FREQ: 100 29 | SAVE_FREQ: 5 30 | TAG: fd_pretrain__clip_vit_base__img224__100ep 31 | -------------------------------------------------------------------------------- /configs/pretrain/fd_pretrain__clip_vit_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_pretrain 4 | DROP_PATH_RATE: 0.2 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: False 11 | USE_SHARED_RPB: True 12 | USE_MEAN_POOLING: False 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | AUG: 18 | MIN_SCALE: 0.08 19 | TRAIN: 20 | EPOCHS: 300 21 | WARMUP_EPOCHS: 10 22 | BASE_LR: 3e-4 23 | WARMUP_LR: 5e-7 24 | MIN_LR: 5e-6 25 | WEIGHT_DECAY: 0.05 26 | CLIP_GRAD: 3.0 27 | DEV: 28 | PRED_FEAT: CLIP_400M 29 | PRED_FEAT_AFTERNORM: True 30 | PRINT_FREQ: 100 31 | SAVE_FREQ: 5 32 | TAG: fd_pretrain__clip_vit_base__img224__300ep 33 | -------------------------------------------------------------------------------- /configs/pretrain/fd_pretrain__clip_vit_large__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_pretrain 4 | DROP_PATH_RATE: 0.3 5 | VIT: 6 | EMBED_DIM: 1024 7 | DEPTH: 24 8 | NUM_HEADS: 16 9 | USE_APE: False 10 | USE_RPB: False 11 | USE_SHARED_RPB: True 12 | USE_MEAN_POOLING: False 13 | WITH_CLS_TOKEN: True 14 | PATCH_SIZE: 14 15 | DATA: 16 | IMG_SIZE: 224 17 | BATCH_SIZE: 128 18 | AUG: 19 | MIN_SCALE: 0.08 20 | TRAIN: 21 | EPOCHS: 300 22 | WARMUP_EPOCHS: 10 23 | BASE_LR: 3e-4 24 | WARMUP_LR: 5e-7 25 | MIN_LR: 5e-6 26 | WEIGHT_DECAY: 0.05 27 | CLIP_GRAD: 3.0 28 | DEV: 29 | PRED_FEAT: CLIP_400M_Large 30 | PRED_FEAT_AFTERNORM: True 31 | PRINT_FREQ: 100 32 | SAVE_FREQ: 5 33 | TAG: fd_pretrain__clip_vit_large__img224__300ep 34 | -------------------------------------------------------------------------------- /configs/pretrain/fd_pretrain__deit_vit_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_pretrain 4 | DROP_PATH_RATE: 0.3 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: False 11 | USE_SHARED_RPB: True 12 | USE_MEAN_POOLING: False 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | AUG: 18 | MIN_SCALE: 0.08 19 | TRAIN: 20 | EPOCHS: 300 21 | WARMUP_EPOCHS: 10 
22 | BASE_LR: 3e-4 23 | WARMUP_LR: 5e-7 24 | MIN_LR: 5e-6 25 | WEIGHT_DECAY: 0.05 26 | CLIP_GRAD: 3.0 27 | DEV: 28 | PRED_FEAT: DEIT 29 | PRED_FEAT_AFTERNORM: True 30 | PRINT_FREQ: 100 31 | SAVE_FREQ: 5 32 | TAG: fd_pretrain__deit_vit_base__img224__300ep 33 | -------------------------------------------------------------------------------- /configs/pretrain/fd_pretrain__dino_vit_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: vit 3 | NAME: fd_pretrain 4 | DROP_PATH_RATE: 0.3 5 | VIT: 6 | EMBED_DIM: 768 7 | DEPTH: 12 8 | NUM_HEADS: 12 9 | USE_APE: False 10 | USE_RPB: False 11 | USE_SHARED_RPB: True 12 | USE_MEAN_POOLING: False 13 | WITH_CLS_TOKEN: True 14 | DATA: 15 | IMG_SIZE: 224 16 | BATCH_SIZE: 128 17 | AUG: 18 | MIN_SCALE: 0.08 19 | TRAIN: 20 | EPOCHS: 300 21 | WARMUP_EPOCHS: 10 22 | BASE_LR: 3e-4 23 | WARMUP_LR: 5e-7 24 | MIN_LR: 5e-6 25 | WEIGHT_DECAY: 0.05 26 | CLIP_GRAD: 3.0 27 | DEV: 28 | PRED_FEAT: DINO 29 | PRED_FEAT_AFTERNORM: True 30 | PRINT_FREQ: 100 31 | SAVE_FREQ: 5 32 | TAG: fd_pretrain__dino_vit_base__img224__300ep 33 | -------------------------------------------------------------------------------- /configs/pretrain/fd_pretrain__esvit_swin_base__img224__300ep.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: swin_v2 3 | NAME: fd_pretrain 4 | DROP_PATH_RATE: 0.3 5 | SWIN: 6 | EMBED_DIM: 128 7 | DEPTHS: [ 2, 2, 18, 2 ] 8 | NUM_HEADS: [ 4, 8, 16, 32 ] 9 | WINDOW_SIZE: 14 10 | DATA: 11 | IMG_SIZE: 224 12 | BATCH_SIZE: 128 13 | AUG: 14 | MIN_SCALE: 0.08 15 | TRAIN: 16 | EPOCHS: 300 17 | WARMUP_EPOCHS: 10 18 | BASE_LR: 2e-4 19 | WARMUP_LR: 1e-6 20 | MIN_LR: 1e-5 21 | WEIGHT_DECAY: 0.05 22 | CLIP_GRAD: 3.0 23 | DEV: 24 | PRED_FEAT: ESVIT 25 | PRED_FEAT_AFTERNORM: True 26 | PRINT_FREQ: 100 27 | SAVE_FREQ: 5 28 | TAG: fd_pretrain__esvit_swin_base__img224__300ep 29 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_fd import build_loader_fd 2 | from .data_finetune import build_loader_finetune 3 | 4 | def build_loader(config, logger, is_pretrain): 5 | if is_pretrain: 6 | return build_loader_fd(config, logger) 7 | else: 8 | return build_loader_finetune(config, logger) -------------------------------------------------------------------------------- /data/cached_image_folder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Swin Transformer 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # -------------------------------------------------------- 7 | 8 | import io 9 | import os 10 | import time 11 | import torch.distributed as dist 12 | import torch.utils.data as data 13 | from PIL import Image 14 | 15 | from .utils import is_zip_path, ZipReader 16 | 17 | 18 | def has_file_allowed_extension(filename, extensions): 19 | """Checks if a file is an allowed extension. 
20 | Args: 21 | filename (string): path to a file 22 | Returns: 23 | bool: True if the filename ends with a known image extension 24 | """ 25 | filename_lower = filename.lower() 26 | return any(filename_lower.endswith(ext) for ext in extensions) 27 | 28 | 29 | def find_classes(dir): 30 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 31 | classes.sort() 32 | class_to_idx = {classes[i]: i for i in range(len(classes))} 33 | return classes, class_to_idx 34 | 35 | 36 | def make_dataset(dir, class_to_idx, extensions): 37 | images = [] 38 | dir = os.path.expanduser(dir) 39 | for target in sorted(os.listdir(dir)): 40 | d = os.path.join(dir, target) 41 | if not os.path.isdir(d): 42 | continue 43 | 44 | for root, _, fnames in sorted(os.walk(d)): 45 | for fname in sorted(fnames): 46 | if has_file_allowed_extension(fname, extensions): 47 | path = os.path.join(root, fname) 48 | item = (path, class_to_idx[target]) 49 | images.append(item) 50 | 51 | return images 52 | 53 | 54 | def make_dataset_with_ann(ann_file, img_prefix, extensions): 55 | images = [] 56 | with open(ann_file, "r") as f: 57 | contents = f.readlines() 58 | for line_str in contents: 59 | path_contents = [c for c in line_str.split('\t')] 60 | im_file_name = path_contents[0] 61 | class_index = int(path_contents[1]) 62 | 63 | assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions 64 | item = (os.path.join(img_prefix, im_file_name), class_index) 65 | 66 | images.append(item) 67 | 68 | return images 69 | 70 | 71 | class DatasetFolder(data.Dataset): 72 | """A generic data loader where the samples are arranged in this way: :: 73 | root/class_x/xxx.ext 74 | root/class_x/xxy.ext 75 | root/class_x/xxz.ext 76 | root/class_y/123.ext 77 | root/class_y/nsdf3.ext 78 | root/class_y/asd932_.ext 79 | Args: 80 | root (string): Root directory path. 81 | loader (callable): A function to load a sample given its path. 82 | extensions (list[string]): A list of allowed extensions. 83 | transform (callable, optional): A function/transform that takes in 84 | a sample and returns a transformed version. 85 | E.g, ``transforms.RandomCrop`` for images. 86 | target_transform (callable, optional): A function/transform that takes 87 | in the target and transforms it. 
88 | Attributes: 89 | samples (list): List of (sample path, class_index) tuples 90 | """ 91 | 92 | def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, 93 | cache_mode="no"): 94 | # image folder mode 95 | if ann_file == '': 96 | _, class_to_idx = find_classes(root) 97 | samples = make_dataset(root, class_to_idx, extensions) 98 | # zip mode 99 | else: 100 | samples = make_dataset_with_ann(os.path.join(root, ann_file), 101 | os.path.join(root, img_prefix), 102 | extensions) 103 | 104 | if len(samples) == 0: 105 | raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.samples = samples 113 | self.labels = [y_1k for _, y_1k in samples] 114 | self.classes = list(set(self.labels)) 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.cache_mode = cache_mode 120 | if self.cache_mode != "no": 121 | self.init_cache() 122 | 123 | def init_cache(self): 124 | assert self.cache_mode in ["part", "full"] 125 | n_sample = len(self.samples) 126 | global_rank = dist.get_rank() 127 | world_size = dist.get_world_size() 128 | 129 | samples_bytes = [None for _ in range(n_sample)] 130 | start_time = time.time() 131 | for index in range(n_sample): 132 | if index % (n_sample // 10) == 0: 133 | t = time.time() - start_time 134 | print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') 135 | start_time = time.time() 136 | path, target = self.samples[index] 137 | if self.cache_mode == "full": 138 | samples_bytes[index] = (ZipReader.read(path), target) 139 | elif self.cache_mode == "part" and index % world_size == global_rank: 140 | samples_bytes[index] = (ZipReader.read(path), target) 141 | else: 142 | samples_bytes[index] = (path, target) 143 | self.samples = samples_bytes 144 | 145 | def __getitem__(self, index): 146 | """ 147 | Args: 148 | index (int): Index 149 | Returns: 150 | tuple: (sample, target) where target is class_index of the target class. 
151 | """ 152 | path, target = self.samples[index] 153 | sample = self.loader(path) 154 | if self.transform is not None: 155 | sample = self.transform(sample) 156 | if self.target_transform is not None: 157 | target = self.target_transform(target) 158 | 159 | return sample, target 160 | 161 | def __len__(self): 162 | return len(self.samples) 163 | 164 | def __repr__(self): 165 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 166 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 167 | fmt_str += ' Root Location: {}\n'.format(self.root) 168 | tmp = ' Transforms (if any): ' 169 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 170 | tmp = ' Target Transforms (if any): ' 171 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 172 | return fmt_str 173 | 174 | 175 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 176 | 177 | 178 | def pil_loader(path): 179 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 180 | if isinstance(path, bytes): 181 | img = Image.open(io.BytesIO(path)) 182 | elif is_zip_path(path): 183 | data = ZipReader.read(path) 184 | img = Image.open(io.BytesIO(data)) 185 | else: 186 | with open(path, 'rb') as f: 187 | img = Image.open(f) 188 | return img.convert('RGB') 189 | 190 | 191 | def accimage_loader(path): 192 | import accimage 193 | try: 194 | return accimage.Image(path) 195 | except IOError: 196 | # Potentially a decoding problem, fall back to PIL.Image 197 | return pil_loader(path) 198 | 199 | 200 | def default_img_loader(path): 201 | from torchvision import get_image_backend 202 | if get_image_backend() == 'accimage': 203 | return accimage_loader(path) 204 | else: 205 | return pil_loader(path) 206 | 207 | 208 | class CachedImageFolder(DatasetFolder): 209 | """A generic data loader where the images are arranged in this way: :: 210 | root/dog/xxx.png 211 | root/dog/xxy.png 212 | root/dog/xxz.png 213 | root/cat/123.png 214 | root/cat/nsdf3.png 215 | root/cat/asd932_.png 216 | Args: 217 | root (string): Root directory path. 218 | transform (callable, optional): A function/transform that takes in an PIL image 219 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 220 | target_transform (callable, optional): A function/transform that takes in the 221 | target and transforms it. 222 | loader (callable, optional): A function to load an image given its path. 223 | Attributes: 224 | imgs (list): List of (image path, class_index) tuples 225 | """ 226 | 227 | def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, 228 | loader=default_img_loader, cache_mode="no"): 229 | super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, 230 | ann_file=ann_file, img_prefix=img_prefix, 231 | transform=transform, target_transform=target_transform, 232 | cache_mode=cache_mode) 233 | self.imgs = self.samples 234 | 235 | def __getitem__(self, index): 236 | """ 237 | Args: 238 | index (int): Index 239 | Returns: 240 | tuple: (image, target) where target is class_index of the target class. 
241 | """ 242 | path, target = self.samples[index] 243 | image = self.loader(path) 244 | if self.transform is not None and not isinstance(self.transform, list): 245 | img = self.transform(image) 246 | elif self.transform is not None and isinstance(self.transform, list): 247 | img = [] 248 | for i in range(len(self.transform)): 249 | _img = self.transform[i](image) 250 | if isinstance(_img, list) or isinstance(_img, tuple): 251 | img.extend(_img) 252 | else: 253 | img.append(_img) 254 | else: 255 | img = image 256 | if self.target_transform is not None: 257 | target = self.target_transform(target) 258 | 259 | return img, target 260 | -------------------------------------------------------------------------------- /data/data_fd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.distributed as dist 13 | import torchvision.transforms as T 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | from torchvision.datasets import ImageFolder 16 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension 17 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 18 | from .utils import SubsetRandomSampler 19 | from .cached_image_folder import CachedImageFolder 20 | 21 | from PIL import ImageFile 22 | ImageFile.LOAD_TRUNCATED_IMAGES = True 23 | 24 | class FDTransform: 25 | def __init__(self, config): 26 | self.config = config 27 | 28 | crop_size = config.DATA.IMG_SIZE 29 | self.transform_img = T.Compose([ 30 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 31 | T.RandomResizedCrop(crop_size, scale=(config.AUG.MIN_SCALE, config.AUG.MAX_SCALE), ratio=(3. / 4., 4. 
/ 3.)), 32 | T.RandomHorizontalFlip(), 33 | T.ToTensor(), 34 | T.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN),std=torch.tensor(IMAGENET_DEFAULT_STD)), 35 | ]) 36 | 37 | def __call__(self, img): 38 | img = self.transform_img(img) 39 | return img 40 | 41 | 42 | def is_valid_file(x: str) -> bool: 43 | unvalid_file_list = """ 44 | n01678043_6448.JPEG 45 | n01896844_997.JPEG 46 | n02368116_318.JPEG 47 | n02428089_710.JPEG 48 | n02487347_1956.JPEG 49 | n02597972_5463.JPEG 50 | n03957420_30695.JPEG 51 | n03957420_33553.JPEG 52 | n03957420_8296.JPEG 53 | n04135315_8814.JPEG 54 | n04135315_9318.JPEG 55 | n04257684_9033.JPEG 56 | n04427559_2974.JPEG 57 | n06470073_47249.JPEG 58 | n07930062_4147.JPEG 59 | n09224725_3995.JPEG 60 | n09359803_8155.JPEG 61 | n09620794_5529.JPEG 62 | n09789566_3522.JPEG 63 | n09894445_7463.JPEG 64 | n10175248_583.JPEG 65 | n10316360_4246.JPEG 66 | n10368624_12550.JPEG 67 | n10585217_8484.JPEG 68 | n10721819_1131.JPEG 69 | n12353203_3849.JPEG 70 | n12630763_8018.JPEG 71 | """ 72 | unvalid_file_list = tuple([i.strip() for i in unvalid_file_list.split('\n') if len(i.strip()) > 0]) 73 | assert len(unvalid_file_list) == 27 74 | 75 | return has_file_allowed_extension(x, IMG_EXTENSIONS) and not x.endswith(unvalid_file_list) 76 | 77 | 78 | def build_loader_fd(config, logger): 79 | transform = FDTransform(config) 80 | logger.info(f'Pre-train data transform:\n{transform}') 81 | 82 | if config.DATA.DATASET == 'imagenet': 83 | prefix = 'train' 84 | if config.DATA.ZIP_MODE: 85 | ann_file = prefix + "_map.txt" 86 | prefix = prefix + ".zip@/" 87 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 88 | cache_mode='part') 89 | else: 90 | dataset = ImageFolder(config.DATA.DATA_PATH, transform, is_valid_file=is_valid_file) 91 | elif config.DATA.DATASET == 'imagenet22k': 92 | dataset = ImageFolder(config.DATA.DATA_PATH, transform, is_valid_file=is_valid_file) 93 | 94 | if config.DATA.DATASET == 'imagenet' and config.DATA.ZIP_MODE: 95 | indices = np.arange(dist.get_rank(), len(dataset), dist.get_world_size()) 96 | sampler = SubsetRandomSampler(indices) 97 | else: 98 | sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True) 99 | dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True) 100 | 101 | return dataloader -------------------------------------------------------------------------------- /data/data_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import torch.distributed as dist 11 | from torch.utils.data import DataLoader, DistributedSampler 12 | from torchvision import datasets, transforms 13 | from torchvision.datasets import ImageFolder 14 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension 15 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 16 | from timm.data import Mixup 17 | from timm.data import create_transform 18 | from timm.data.transforms import _pil_interp 19 | 20 | import numpy as np 21 | from .cached_image_folder import CachedImageFolder 22 | from .utils import 
SubsetRandomSampler 23 | IMAGENET_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 24 | IMAGENET_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 25 | 26 | import random 27 | from PIL import ImageFilter, ImageOps 28 | 29 | 30 | class GaussianBlur(object): 31 | """ 32 | Apply Gaussian Blur to the PIL image. 33 | """ 34 | def __init__(self, p=0.1, radius_min=0.1, radius_max=2.): 35 | self.prob = p 36 | self.radius_min = radius_min 37 | self.radius_max = radius_max 38 | 39 | def __call__(self, img): 40 | do_it = random.random() <= self.prob 41 | if not do_it: 42 | return img 43 | 44 | img = img.filter( 45 | ImageFilter.GaussianBlur( 46 | radius=random.uniform(self.radius_min, self.radius_max) 47 | ) 48 | ) 49 | return img 50 | 51 | class Solarization(object): 52 | """ 53 | Apply Solarization to the PIL image. 54 | """ 55 | def __init__(self, p=0.2): 56 | self.p = p 57 | 58 | def __call__(self, img): 59 | if random.random() < self.p: 60 | return ImageOps.solarize(img) 61 | else: 62 | return img 63 | 64 | class gray_scale(object): 65 | """ 66 | Apply Solarization to the PIL image. 67 | """ 68 | def __init__(self, p=0.2): 69 | self.p = p 70 | self.transf = transforms.Grayscale(3) 71 | 72 | def __call__(self, img): 73 | if random.random() < self.p: 74 | return self.transf(img) 75 | else: 76 | return img 77 | 78 | 79 | def build_loader_finetune(config, logger): 80 | config.defrost() 81 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger) 82 | config.freeze() 83 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger) 84 | logger.info(f"Build dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}") 85 | 86 | num_tasks = dist.get_world_size() 87 | global_rank = dist.get_rank() 88 | if config.DATA.ZIP_MODE: 89 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 90 | sampler_train = SubsetRandomSampler(indices) 91 | else: 92 | sampler_train = DistributedSampler( 93 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 94 | ) 95 | 96 | if config.DATA.ZIP_MODE: 97 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 98 | sampler_val = SubsetRandomSampler(indices) 99 | else: 100 | sampler_val = DistributedSampler( 101 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True 102 | ) 103 | 104 | data_loader_train = DataLoader( 105 | dataset_train, sampler=sampler_train, 106 | batch_size=config.DATA.BATCH_SIZE, 107 | num_workers=config.DATA.NUM_WORKERS, 108 | pin_memory=config.DATA.PIN_MEMORY, 109 | drop_last=True, 110 | ) 111 | 112 | data_loader_val = DataLoader( 113 | dataset_val, sampler=sampler_val, 114 | batch_size=config.DATA.BATCH_SIZE, 115 | num_workers=config.DATA.NUM_WORKERS, 116 | pin_memory=config.DATA.PIN_MEMORY, 117 | drop_last=False, 118 | ) 119 | 120 | # setup mixup / cutmix 121 | mixup_fn = None 122 | mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None 123 | if mixup_active: 124 | mixup_fn = Mixup( 125 | mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, 126 | prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, 127 | label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) 128 | 129 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 130 | 131 | 132 | def is_valid_file(x: str) -> bool: 133 | unvalid_file_list = """ 134 | n01678043_6448.JPEG 135 | n01896844_997.JPEG 136 | n02368116_318.JPEG 137 | n02428089_710.JPEG 138 | n02487347_1956.JPEG 139 | n02597972_5463.JPEG 140 | n03957420_30695.JPEG 141 | n03957420_33553.JPEG 142 | n03957420_8296.JPEG 143 | n04135315_8814.JPEG 144 | n04135315_9318.JPEG 145 | n04257684_9033.JPEG 146 | n04427559_2974.JPEG 147 | n06470073_47249.JPEG 148 | n07930062_4147.JPEG 149 | n09224725_3995.JPEG 150 | n09359803_8155.JPEG 151 | n09620794_5529.JPEG 152 | n09789566_3522.JPEG 153 | n09894445_7463.JPEG 154 | n10175248_583.JPEG 155 | n10316360_4246.JPEG 156 | n10368624_12550.JPEG 157 | n10585217_8484.JPEG 158 | n10721819_1131.JPEG 159 | n12353203_3849.JPEG 160 | n12630763_8018.JPEG 161 | """ 162 | unvalid_file_list = tuple([i.strip() for i in unvalid_file_list.split('\n') if len(i.strip()) > 0]) 163 | assert len(unvalid_file_list) == 27 164 | 165 | return has_file_allowed_extension(x, IMG_EXTENSIONS) and not x.endswith(unvalid_file_list) 166 | 167 | 168 | def build_dataset(is_train, config, logger): 169 | transform = build_transform(is_train, config) 170 | logger.info(f'Fine-tune data transform, is_train={is_train}:\n{transform}') 171 | 172 | if config.DATA.DATASET == 'imagenet': 173 | prefix = 'train' if is_train else 'val' 174 | if config.DATA.ZIP_MODE: 175 | ann_file = prefix + "_map.txt" 176 | prefix = prefix + ".zip@/" 177 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 178 | cache_mode='part') 179 | else: 180 | root = os.path.join(config.DATA.DATA_PATH, prefix) 181 | dataset = datasets.ImageFolder(root, transform=transform) 182 | nb_classes = 1000 183 | elif config.DATA.DATASET == 'imagenet22k': 184 | if is_train: 185 | dataset = ImageFolder(config.DATA.DATA_PATH, transform, is_valid_file=is_valid_file) 186 | nb_classes = 21841 187 | else: 188 | nb_classes = 1000 189 | prefix = 'val' 190 | if config.DATA.ZIP_MODE: 191 | ann_file = prefix + "_map.txt" 192 | prefix = prefix + ".zip@/" 193 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 194 | cache_mode='part') 195 | else: 196 | root = os.path.join(config.DATA.DATA_PATH, prefix) 197 | dataset = datasets.ImageFolder(root, transform=transform) 198 | else: 199 | raise NotImplementedError("We only support ImageNet Now.") 200 | 201 | return dataset, nb_classes 202 | 203 | 204 | def build_transform(is_train, config): 205 | resize_im = config.DATA.IMG_SIZE > 32 206 | if is_train: 207 | # this should always dispatch to transforms_imagenet_train 208 | transform = create_transform( 209 | input_size=config.DATA.IMG_SIZE, 210 | is_training=True, 211 | color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, 212 | auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, 213 | re_prob=config.AUG.REPROB, 214 | re_mode=config.AUG.REMODE, 215 | re_count=config.AUG.RECOUNT, 216 | interpolation=config.DATA.INTERPOLATION, 217 | mean=IMAGENET_DEFAULT_MEAN, 218 | 
std=IMAGENET_DEFAULT_STD 219 | ) 220 | if not resize_im: 221 | # replace RandomResizedCropAndInterpolation with 222 | # RandomCrop 223 | transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) 224 | return transform 225 | 226 | t = [] 227 | if resize_im: 228 | if config.TEST.CROP: 229 | size = int((256 / 224) * config.DATA.IMG_SIZE) 230 | t.append( 231 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 232 | # to maintain same ratio w.r.t. 224 images 233 | ) 234 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 235 | else: 236 | t.append( 237 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 238 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 239 | ) 240 | 241 | t.append(transforms.ToTensor()) 242 | t.append(transforms.Normalize(IMAGENET_CLIP_MEAN, IMAGENET_CLIP_STD)) 243 | return transforms.Compose(t) -------------------------------------------------------------------------------- /data/data_linear.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import math 11 | import torch 12 | import torch.distributed as dist 13 | from torchvision.transforms import functional as F 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | from torchvision import datasets, transforms 16 | from torchvision.datasets.folder import IMG_EXTENSIONS, has_file_allowed_extension 17 | from timm.data.transforms import _pil_interp 18 | 19 | import numpy as np 20 | from .cached_image_folder import CachedImageFolder 21 | from .utils import SubsetRandomSampler 22 | IMAGENET_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 23 | IMAGENET_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 24 | 25 | 26 | def build_loader(config, logger): 27 | config.defrost() 28 | dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config, logger=logger) 29 | config.freeze() 30 | dataset_val, _ = build_dataset(is_train=False, config=config, logger=logger) 31 | logger.info(f"Build dataset: train images = {len(dataset_train)}, val images = {len(dataset_val)}") 32 | 33 | num_tasks = dist.get_world_size() 34 | global_rank = dist.get_rank() 35 | if config.DATA.ZIP_MODE: 36 | indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) 37 | sampler_train = SubsetRandomSampler(indices) 38 | else: 39 | sampler_train = DistributedSampler( 40 | dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True 41 | ) 42 | 43 | if config.DATA.ZIP_MODE: 44 | indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) 45 | sampler_val = SubsetRandomSampler(indices) 46 | else: 47 | sampler_val = DistributedSampler( 48 | dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=True 49 | ) 50 | 51 | data_loader_train = DataLoader( 52 | dataset_train, sampler=sampler_train, 53 | batch_size=config.DATA.BATCH_SIZE, 54 | num_workers=config.DATA.NUM_WORKERS, 55 | pin_memory=config.DATA.PIN_MEMORY, 56 | drop_last=True, 57 | ) 58 | 59 | data_loader_val = DataLoader( 60 | dataset_val, sampler=sampler_val, 61 | batch_size=config.DATA.BATCH_SIZE, 62 | num_workers=config.DATA.NUM_WORKERS, 63 | pin_memory=config.DATA.PIN_MEMORY, 64 | 
drop_last=False, 65 | ) 66 | 67 | mixup_fn = None 68 | 69 | return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn 70 | 71 | 72 | def is_valid_file(x: str) -> bool: 73 | unvalid_file_list = """ 74 | n01678043_6448.JPEG 75 | n01896844_997.JPEG 76 | n02368116_318.JPEG 77 | n02428089_710.JPEG 78 | n02487347_1956.JPEG 79 | n02597972_5463.JPEG 80 | n03957420_30695.JPEG 81 | n03957420_33553.JPEG 82 | n03957420_8296.JPEG 83 | n04135315_8814.JPEG 84 | n04135315_9318.JPEG 85 | n04257684_9033.JPEG 86 | n04427559_2974.JPEG 87 | n06470073_47249.JPEG 88 | n07930062_4147.JPEG 89 | n09224725_3995.JPEG 90 | n09359803_8155.JPEG 91 | n09620794_5529.JPEG 92 | n09789566_3522.JPEG 93 | n09894445_7463.JPEG 94 | n10175248_583.JPEG 95 | n10316360_4246.JPEG 96 | n10368624_12550.JPEG 97 | n10585217_8484.JPEG 98 | n10721819_1131.JPEG 99 | n12353203_3849.JPEG 100 | n12630763_8018.JPEG 101 | """ 102 | unvalid_file_list = tuple([i.strip() for i in unvalid_file_list.split('\n') if len(i.strip()) > 0]) 103 | assert len(unvalid_file_list) == 27 104 | 105 | return has_file_allowed_extension(x, IMG_EXTENSIONS) and not x.endswith(unvalid_file_list) 106 | 107 | 108 | def build_dataset(is_train, config, logger): 109 | transform = build_transform(is_train, config) 110 | logger.info(f'Fine-tune data transform, is_train={is_train}:\n{transform}') 111 | 112 | if config.DATA.DATASET == 'imagenet': 113 | prefix = 'train' if is_train else 'val' 114 | if config.DATA.ZIP_MODE: 115 | ann_file = prefix + "_map.txt" 116 | prefix = prefix + ".zip@/" 117 | dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, 118 | cache_mode='part') 119 | else: 120 | root = os.path.join(config.DATA.DATA_PATH, prefix) 121 | dataset = datasets.ImageFolder(root, transform=transform) 122 | nb_classes = 1000 123 | else: 124 | raise NotImplementedError("We only support ImageNet Now.") 125 | 126 | return dataset, nb_classes 127 | 128 | 129 | class RandomResizedCrop(transforms.RandomResizedCrop): 130 | """ 131 | RandomResizedCrop for matching TF/TPU implementation: no for-loop is used. 132 | This may lead to results different with torchvision's version. 
133 | Following BYOL's TF code: 134 | https://github.com/deepmind/deepmind-research/blob/master/byol/utils/dataset.py#L206 135 | """ 136 | @staticmethod 137 | def get_params(img, scale, ratio): 138 | width, height = F._get_image_size(img) 139 | area = height * width 140 | 141 | target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item() 142 | log_ratio = torch.log(torch.tensor(ratio)) 143 | aspect_ratio = torch.exp( 144 | torch.empty(1).uniform_(log_ratio[0], log_ratio[1]) 145 | ).item() 146 | 147 | w = int(round(math.sqrt(target_area * aspect_ratio))) 148 | h = int(round(math.sqrt(target_area / aspect_ratio))) 149 | 150 | w = min(w, width) 151 | h = min(h, height) 152 | 153 | i = torch.randint(0, height - h + 1, size=(1,)).item() 154 | j = torch.randint(0, width - w + 1, size=(1,)).item() 155 | 156 | return i, j, h, w 157 | 158 | 159 | def build_transform(is_train, config): 160 | resize_im = config.DATA.IMG_SIZE > 32 161 | if is_train: 162 | # linear probe: weak augmentation 163 | transform = transforms.Compose([ 164 | RandomResizedCrop(config.DATA.IMG_SIZE, interpolation=3), 165 | transforms.RandomHorizontalFlip(), 166 | transforms.ToTensor(), 167 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 168 | return transform 169 | 170 | t = [] 171 | if resize_im: 172 | if True: # config.TEST.CROP: 173 | size = int((256 / 224) * config.DATA.IMG_SIZE) 174 | t.append( 175 | transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), 176 | # to maintain same ratio w.r.t. 224 images 177 | ) 178 | t.append(transforms.CenterCrop(config.DATA.IMG_SIZE)) 179 | else: 180 | t.append( 181 | transforms.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE), 182 | interpolation=_pil_interp(config.DATA.INTERPOLATION)) 183 | ) 184 | 185 | t.append(transforms.ToTensor()) 186 | t.append(transforms.Normalize(IMAGENET_CLIP_MEAN, IMAGENET_CLIP_STD)) 187 | return transforms.Compose(t) -------------------------------------------------------------------------------- /data/utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Yixuan Wei 6 | # -------------------------------------------------------- 7 | 8 | import torch 9 | 10 | import os 11 | import io 12 | import zipfile 13 | from PIL import Image 14 | 15 | 16 | def is_zip_path(img_or_path): 17 | """judge if this is a zip path""" 18 | return '.zip@' in img_or_path 19 | 20 | 21 | class ZipReader(object): 22 | zip_bank = dict() 23 | 24 | def __init__(self): 25 | super(ZipReader, self).__init__() 26 | 27 | @staticmethod 28 | def get_zipfile(path): 29 | zip_bank = ZipReader.zip_bank 30 | if path in zip_bank: 31 | return zip_bank[path] 32 | else: 33 | zfile = zipfile.ZipFile(path, 'r') 34 | zip_bank[path] = zfile 35 | return zip_bank[path] 36 | 37 | @staticmethod 38 | def split_zip_style_path(path): 39 | pos_zip_at = path.index('.zip@') 40 | if pos_zip_at == len(path): 41 | print("character '@' is not found from the given path '%s'" % (path)) 42 | assert 0 43 | pos_at = pos_zip_at + len('.zip@') - 1 44 | 45 | zip_path = path[0: pos_at] 46 | folder_path = path[pos_at + 1:] 47 | folder_path = str.strip(folder_path, '/') 48 | return zip_path, folder_path 49 | 50 | @staticmethod 51 | def list_folder(path): 52 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 53 | 54 | 
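# Illustrative sketch of the ".zip@" path convention handled above (hypothetical
# filename, not from the repo): split_zip_style_path maps
#   "imagenet/train.zip@/n01440764/example.JPEG"
# to zip_path = "imagenet/train.zip" and folder_path = "n01440764/example.JPEG",
# which matches how the data loaders build paths via `prefix + ".zip@/"` when
# DATA.ZIP_MODE is enabled.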
zfile = ZipReader.get_zipfile(zip_path) 55 | folder_list = [] 56 | for file_foler_name in zfile.namelist(): 57 | file_foler_name = str.strip(file_foler_name, '/') 58 | if file_foler_name.startswith(folder_path) and \ 59 | len(os.path.splitext(file_foler_name)[-1]) == 0 and \ 60 | file_foler_name != folder_path: 61 | if len(folder_path) == 0: 62 | folder_list.append(file_foler_name) 63 | else: 64 | folder_list.append(file_foler_name[len(folder_path)+1:]) 65 | 66 | return folder_list 67 | 68 | @staticmethod 69 | def list_files(path, extension=['.*']): 70 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 71 | 72 | zfile = ZipReader.get_zipfile(zip_path) 73 | file_lists = [] 74 | for file_foler_name in zfile.namelist(): 75 | file_foler_name = str.strip(file_foler_name, '/') 76 | if file_foler_name.startswith(folder_path) and str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 77 | if len(folder_path) == 0: 78 | file_lists.append(file_foler_name) 79 | else: 80 | file_lists.append(file_foler_name[len(folder_path)+1:]) 81 | 82 | return file_lists 83 | 84 | @staticmethod 85 | def list_files_fullpath(path, extension=['.*']): 86 | zip_path, folder_path = ZipReader.split_zip_style_path(path) 87 | 88 | zfile = ZipReader.get_zipfile(zip_path) 89 | file_lists = [] 90 | for file_foler_name in zfile.namelist(): 91 | if file_foler_name.startswith(folder_path) and str.lower(os.path.splitext(file_foler_name)[-1]) in extension: 92 | file_lists.append(file_foler_name) 93 | 94 | return file_lists 95 | 96 | @staticmethod 97 | def imread(path): 98 | zip_path, path_img = ZipReader.split_zip_style_path(path) 99 | zfile = ZipReader.get_zipfile(zip_path) 100 | data = zfile.read(path_img) 101 | im = Image.open(io.BytesIO(data)) 102 | return im 103 | 104 | @staticmethod 105 | def read(path): 106 | zip_path, path_img = ZipReader.split_zip_style_path(path) 107 | zfile = ZipReader.get_zipfile(zip_path) 108 | data = zfile.read(path_img) 109 | return data 110 | 111 | 112 | class SubsetRandomSampler(torch.utils.data.Sampler): 113 | r"""Samples elements randomly from a given list of indices, without replacement. 
114 | 115 | Arguments: 116 | indices (sequence): a sequence of indices 117 | """ 118 | 119 | def __init__(self, indices): 120 | self.epoch = 0 121 | self.indices = indices 122 | 123 | def __iter__(self): 124 | return (self.indices[i] for i in torch.randperm(len(self.indices))) 125 | 126 | def __len__(self): 127 | return len(self.indices) 128 | 129 | def set_epoch(self, epoch): 130 | self.epoch = epoch -------------------------------------------------------------------------------- /figures/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SwinTransformer/Feature-Distillation/2115145a388822bba14c183f9ae74fdf479f7df9/figures/teaser.jpg -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Ze Liu 5 | # Modified by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | import os 9 | import sys 10 | import logging 11 | import functools 12 | from termcolor import colored 13 | 14 | 15 | @functools.lru_cache() 16 | def create_logger(output_dir, dist_rank=0, name=''): 17 | # create logger 18 | logger = logging.getLogger(name) 19 | logger.setLevel(logging.DEBUG) 20 | logger.propagate = False 21 | 22 | # create formatter 23 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' 24 | color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ 25 | colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' 26 | 27 | # create console handlers for master process 28 | if dist_rank == 0: 29 | console_handler = logging.StreamHandler(sys.stdout) 30 | console_handler.setLevel(logging.DEBUG) 31 | console_handler.setFormatter( 32 | logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) 33 | logger.addHandler(console_handler) 34 | 35 | # create file handlers 36 | file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') 37 | file_handler.setLevel(logging.DEBUG) 38 | file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) 39 | logger.addHandler(file_handler) 40 | 41 | return logger 42 | -------------------------------------------------------------------------------- /lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Ze Liu 5 | # Modified by Zhenda Xie 6 | # -------------------------------------------------------- 7 | 8 | from collections import Counter 9 | from bisect import bisect_right 10 | 11 | import torch 12 | from timm.scheduler.cosine_lr import CosineLRScheduler 13 | from timm.scheduler.step_lr import StepLRScheduler 14 | from timm.scheduler.scheduler import Scheduler 15 | 16 | 17 | def build_scheduler(config, optimizer, n_iter_per_epoch): 18 | num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) 19 | if config.TRAIN.WARMUP_EPOCHS < 0 and config.TRAIN.WARMUP_EPOCHS_FINE != 0: 20 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS_FINE * n_iter_per_epoch) 21 | else: 22 | warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) 23 | decay_steps = 
int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) 24 | multi_steps = [i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] 25 | 26 | lr_scheduler = None 27 | if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': 28 | lr_scheduler = CosineLRScheduler( 29 | optimizer, 30 | t_initial=num_steps - warmup_steps, 31 | t_mul=1., 32 | lr_min=config.TRAIN.MIN_LR, 33 | warmup_lr_init=config.TRAIN.WARMUP_LR, 34 | warmup_prefix=True, 35 | warmup_t=warmup_steps, 36 | cycle_limit=1, 37 | t_in_epochs=False, 38 | ) 39 | elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': 40 | lr_scheduler = LinearLRScheduler( 41 | optimizer, 42 | t_initial=num_steps, 43 | lr_min_rate=0.01, 44 | warmup_lr_init=config.TRAIN.WARMUP_LR, 45 | warmup_t=warmup_steps, 46 | t_in_epochs=False, 47 | ) 48 | elif config.TRAIN.LR_SCHEDULER.NAME == 'step': 49 | lr_scheduler = StepLRScheduler( 50 | optimizer, 51 | decay_t=decay_steps, 52 | decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, 53 | warmup_lr_init=config.TRAIN.WARMUP_LR, 54 | warmup_t=warmup_steps, 55 | t_in_epochs=False, 56 | ) 57 | elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': 58 | lr_scheduler = MultiStepLRScheduler( 59 | optimizer, 60 | milestones=multi_steps, 61 | gamma=config.TRAIN.LR_SCHEDULER.GAMMA, 62 | warmup_lr_init=config.TRAIN.WARMUP_LR, 63 | warmup_t=warmup_steps, 64 | t_in_epochs=False, 65 | ) 66 | 67 | return lr_scheduler 68 | 69 | 70 | class LinearLRScheduler(Scheduler): 71 | def __init__(self, 72 | optimizer: torch.optim.Optimizer, 73 | t_initial: int, 74 | lr_min_rate: float, 75 | warmup_t=0, 76 | warmup_lr_init=0., 77 | t_in_epochs=True, 78 | noise_range_t=None, 79 | noise_pct=0.67, 80 | noise_std=1.0, 81 | noise_seed=42, 82 | initialize=True, 83 | ) -> None: 84 | super().__init__( 85 | optimizer, param_group_field="lr", 86 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 87 | initialize=initialize) 88 | 89 | self.t_initial = t_initial 90 | self.lr_min_rate = lr_min_rate 91 | self.warmup_t = warmup_t 92 | self.warmup_lr_init = warmup_lr_init 93 | self.t_in_epochs = t_in_epochs 94 | if self.warmup_t: 95 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 96 | super().update_groups(self.warmup_lr_init) 97 | else: 98 | self.warmup_steps = [1 for _ in self.base_values] 99 | 100 | def _get_lr(self, t): 101 | if t < self.warmup_t: 102 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 103 | else: 104 | t = t - self.warmup_t 105 | total_t = self.t_initial - self.warmup_t 106 | lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] 107 | return lrs 108 | 109 | def get_epoch_values(self, epoch: int): 110 | if self.t_in_epochs: 111 | return self._get_lr(epoch) 112 | else: 113 | return None 114 | 115 | def get_update_values(self, num_updates: int): 116 | if not self.t_in_epochs: 117 | return self._get_lr(num_updates) 118 | else: 119 | return None 120 | 121 | 122 | class MultiStepLRScheduler(Scheduler): 123 | def __init__(self, optimizer: torch.optim.Optimizer, milestones, gamma=0.1, warmup_t=0, warmup_lr_init=0, t_in_epochs=True) -> None: 124 | super().__init__(optimizer, param_group_field="lr") 125 | 126 | self.milestones = milestones 127 | self.gamma = gamma 128 | self.warmup_t = warmup_t 129 | self.warmup_lr_init = warmup_lr_init 130 | self.t_in_epochs = t_in_epochs 131 | if self.warmup_t: 132 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 133 | 
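# Descriptive note: each warmup increment is (base_lr - warmup_lr_init) / warmup_t
# per parameter group, so the learning rate ramps linearly from warmup_lr_init up
# to base_lr over warmup_t updates (see _get_lr below); update_groups() then
# starts every group at warmup_lr_init.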
super().update_groups(self.warmup_lr_init) 134 | else: 135 | self.warmup_steps = [1 for _ in self.base_values] 136 | 137 | assert self.warmup_t <= min(self.milestones) 138 | 139 | def _get_lr(self, t): 140 | if t < self.warmup_t: 141 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 142 | else: 143 | lrs = [v * (self.gamma ** bisect_right(self.milestones, t)) for v in self.base_values] 144 | return lrs 145 | 146 | def get_epoch_values(self, epoch: int): 147 | if self.t_in_epochs: 148 | return self._get_lr(epoch) 149 | else: 150 | return None 151 | 152 | def get_update_values(self, num_updates: int): 153 | if not self.t_in_epochs: 154 | return self._get_lr(num_updates) 155 | else: 156 | return None -------------------------------------------------------------------------------- /main_fd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.cuda.amp as amp 19 | from timm.utils import AverageMeter 20 | 21 | from config import get_config 22 | from models import build_model 23 | from data import build_loader 24 | from lr_scheduler import build_scheduler 25 | from optimizer import build_optimizer 26 | from logger import create_logger 27 | from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 28 | 29 | import wandb 30 | global no_wandb 31 | no_wandb = True 32 | 33 | def wandb_log(*args, **kwargs): 34 | if dist.get_rank() == 0 and not no_wandb: 35 | wandb.log(*args, **kwargs) 36 | 37 | 38 | def parse_option(): 39 | global no_wandb 40 | parser = argparse.ArgumentParser('Feature Distillation pre-training script', add_help=False) 41 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 42 | parser.add_argument( 43 | "--opts", 44 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 45 | default=None, 46 | nargs='+', 47 | ) 48 | 49 | # easy config modification 50 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 51 | parser.add_argument('--data-path', type=str, help='path to dataset') 52 | parser.add_argument('--resume', help='resume from checkpoint') 53 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 54 | parser.add_argument('--use-checkpoint', action='store_true', 55 | help="whether to use gradient checkpointing to save memory") 56 | parser.add_argument('--enable-amp', action='store_true') 57 | parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') 58 | parser.set_defaults(enable_amp=True) 59 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 60 | help='root of output folder, the full path is / (default: output)') 61 | parser.add_argument('--tag', help='tag of experiment') 62 | 63 | # distributed training 64 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 65 | 66 | parser.add_argument('--with-wandb', action='store_true') 67 | 68 | args = parser.parse_args() 69 | 70 | config = get_config(args) 71 | 72 | if args.with_wandb: 73 | no_wandb = False 74 | print(" warning you're using wandb ! ") 75 | return args, config 76 | 77 | 78 | def main(config): 79 | data_loader_train = build_loader(config, logger, is_pretrain=True) 80 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 81 | model = build_model(config, is_pretrain=True) 82 | model.cuda() 83 | logger.info(str(model)) 84 | 85 | optimizer = build_optimizer(config, model, logger, is_pretrain=True) 86 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 87 | model_without_ddp = model.module 88 | 89 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 90 | logger.info(f"number of params: {n_parameters}") 91 | if hasattr(model_without_ddp, 'flops'): 92 | flops = model_without_ddp.flops() 93 | logger.info(f"number of GFLOPs: {flops / 1e9}") 94 | 95 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 96 | scaler = amp.GradScaler() 97 | 98 | if config.TRAIN.AUTO_RESUME: 99 | resume_file = auto_resume_helper(config.OUTPUT, logger) 100 | if resume_file: 101 | if config.MODEL.RESUME: 102 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 103 | config.defrost() 104 | config.MODEL.RESUME = resume_file 105 | config.freeze() 106 | logger.info(f'auto resuming from {resume_file}') 107 | else: 108 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 109 | 110 | if config.MODEL.RESUME: 111 | load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) 112 | elif config.PRETRAINED: 113 | logger.info(f">>>>>>>>>> Continue Pre-train from {config.PRETRAINED} ..........") 114 | checkpoint = torch.load(config.PRETRAINED, map_location='cpu')['model'] 115 | msg = model_without_ddp.load_state_dict(checkpoint, strict=False) 116 | logger.info(msg) 117 | del checkpoint 118 | torch.cuda.empty_cache() 119 | logger.info(f">>>>>>>>>> loaded successfully '{config.PRETRAINED}'") 120 | 121 | logger.info("Start training") 122 | start_time = time.time() 123 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 124 | try: 125 | data_loader_train.sampler.set_epoch(epoch) 126 | except: 127 | pass 128 | 129 | train_one_epoch(config, 
model, data_loader_train, optimizer, epoch, lr_scheduler, scaler) 130 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 131 | save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger) 132 | 133 | total_time = time.time() - start_time 134 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 135 | logger.info('Training time {}'.format(total_time_str)) 136 | 137 | 138 | def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler): 139 | model.train() 140 | optimizer.zero_grad() 141 | 142 | num_steps = len(data_loader) 143 | batch_time = AverageMeter() 144 | loss_meter = AverageMeter() 145 | norm_meter = AverageMeter() 146 | loss_scale_meter = AverageMeter() 147 | 148 | start = time.time() 149 | end = time.time() 150 | for idx, (img, targets) in enumerate(data_loader): 151 | img = img.cuda(non_blocking=True) 152 | targets = targets.cuda(non_blocking=True) 153 | 154 | with amp.autocast(enabled=config.ENABLE_AMP): 155 | if hasattr(model.module, 'require_targets') and model.module.require_targets: 156 | outputs = model(img, targets) 157 | else: 158 | outputs = model(img) 159 | 160 | if isinstance(outputs, dict): 161 | loss = outputs['loss'] 162 | else: 163 | loss = outputs 164 | 165 | if config.TRAIN.ACCUMULATION_STEPS > 1: 166 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 167 | scaler.scale(loss).backward() 168 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 169 | if config.TRAIN.CLIP_GRAD: 170 | scaler.unscale_(optimizer) 171 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 172 | else: 173 | scaler.unscale_(optimizer) 174 | grad_norm = get_grad_norm(model.parameters()) 175 | scaler.step(optimizer) 176 | optimizer.zero_grad() 177 | scaler.update() 178 | lr_scheduler.step_update(epoch * num_steps + idx) 179 | else: 180 | grad_norm = get_grad_norm(model.parameters()) 181 | else: 182 | optimizer.zero_grad() 183 | scaler.scale(loss).backward() 184 | if config.TRAIN.CLIP_GRAD: 185 | scaler.unscale_(optimizer) 186 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 187 | else: 188 | scaler.unscale_(optimizer) 189 | grad_norm = get_grad_norm(model.parameters()) 190 | scaler.step(optimizer) 191 | scaler.update() 192 | lr_scheduler.step_update(epoch * num_steps + idx) 193 | 194 | torch.cuda.synchronize() 195 | 196 | if isinstance(img, list): 197 | img = img[-1] 198 | loss_meter.update(loss.item(), img.size(0)) 199 | norm_meter.update(grad_norm) 200 | loss_scale_meter.update(scaler.get_scale()) 201 | batch_time.update(time.time() - end) 202 | end = time.time() 203 | 204 | if idx % config.PRINT_FREQ == 0: 205 | lr = optimizer.param_groups[0]['lr'] 206 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 207 | etas = batch_time.avg * (num_steps - idx) 208 | metric_info = '' 209 | 210 | logger.info( 211 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 212 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 213 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 214 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 215 | f'{metric_info}' 216 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 217 | f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' 218 | f'mem {memory_used:.0f}MB') 219 | if (epoch * num_steps + idx) % config.PRINT_FREQ == 0: 220 | log_message = dict(lr=lr, time=batch_time.val, epoch=epoch, 
iter=idx, loss=loss_meter.val, loss_ma=loss_meter.avg, grad_norm=norm_meter.val, loss_scale=loss_scale_meter.val) 221 | wandb_log( 222 | data=log_message, 223 | step=epoch * num_steps + idx, 224 | ) 225 | epoch_time = time.time() - start 226 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 227 | 228 | 229 | if __name__ == '__main__': 230 | _, config = parse_option() 231 | 232 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 233 | rank = int(os.environ["RANK"]) 234 | world_size = int(os.environ['WORLD_SIZE']) 235 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 236 | else: 237 | rank = -1 238 | world_size = -1 239 | torch.cuda.set_device(config.LOCAL_RANK) 240 | torch.distributed.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=600), world_size=world_size, rank=rank) 241 | torch.distributed.barrier() 242 | 243 | seed = config.SEED + dist.get_rank() 244 | torch.manual_seed(seed) 245 | np.random.seed(seed) 246 | cudnn.benchmark = True 247 | 248 | # linear scale the learning rate according to total batch size, may not be optimal 249 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 250 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 251 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 252 | # gradient accumulation also need to scale the learning rate 253 | if config.TRAIN.ACCUMULATION_STEPS > 1: 254 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 255 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 256 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 257 | config.defrost() 258 | config.TRAIN.BASE_LR = linear_scaled_lr 259 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 260 | config.TRAIN.MIN_LR = linear_scaled_min_lr 261 | config.freeze() 262 | 263 | os.makedirs(config.OUTPUT, exist_ok=True) 264 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 265 | 266 | if dist.get_rank() == 0: 267 | path = os.path.join(config.OUTPUT, "config.json") 268 | with open(path, "w") as f: 269 | f.write(config.dump()) 270 | logger.info(f"Full config saved to {path}") 271 | 272 | # setup wandb 273 | if not no_wandb: 274 | raise NotImplementedError(" using yourself wandb ") 275 | 276 | # print config 277 | logger.info(config.dump()) 278 | 279 | main(config) 280 | -------------------------------------------------------------------------------- /main_finetune.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feaure Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # Modified by Yixuan Wei 8 | # -------------------------------------------------------- 9 | 10 | import os 11 | import time 12 | import argparse 13 | import datetime 14 | import numpy as np 15 | 16 | import torch 17 | import torch.backends.cudnn as cudnn 18 | import torch.distributed as dist 19 | import torch.cuda.amp as amp 20 | 21 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 22 | from timm.utils import accuracy, AverageMeter 23 | 24 | from config import get_config 25 | from models import 
build_model 26 | from data import build_loader 27 | from lr_scheduler import build_scheduler 28 | from optimizer import build_optimizer 29 | from logger import create_logger 30 | from utils import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 31 | 32 | import warnings 33 | 34 | warnings.filterwarnings('ignore', 35 | 'Argument interpolation should be of type InterpolationMode instead of int', 36 | UserWarning) 37 | 38 | import wandb 39 | global no_wandb 40 | no_wandb = True 41 | 42 | def wandb_log(*args, **kwargs): 43 | if dist.get_rank() == 0 and not no_wandb: 44 | wandb.log(*args, **kwargs) 45 | 46 | 47 | def parse_option(): 48 | global no_wandb 49 | parser = argparse.ArgumentParser('Feature Distillation fine-tuning script', add_help=False) 50 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 51 | parser.add_argument( 52 | "--opts", 53 | help="Modify config options by adding 'KEY VALUE' pairs. ", 54 | default=None, 55 | nargs='+', 56 | ) 57 | 58 | # easy config modification 59 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 60 | parser.add_argument('--data-path', type=str, help='path to dataset') 61 | parser.add_argument('--pretrained', type=str, help='path to pre-trained model') 62 | parser.add_argument('--resume', help='resume from checkpoint') 63 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 64 | parser.add_argument('--use-checkpoint', action='store_true', 65 | help="whether to use gradient checkpointing to save memory") 66 | parser.add_argument('--enable-amp', action='store_true') 67 | parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') 68 | parser.set_defaults(enable_amp=True) 69 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 70 | help='root of output folder, the full path is / (default: output)') 71 | parser.add_argument('--tag', help='tag of experiment') 72 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 73 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 74 | 75 | # distributed training 76 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 77 | 78 | parser.add_argument('--with-wandb', action='store_true') 79 | args = parser.parse_args() 80 | 81 | config = get_config(args) 82 | if args.with_wandb: 83 | no_wandb = False 84 | print(" warning you're using wandb ! 
") 85 | 86 | return args, config 87 | 88 | 89 | def main(config): 90 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, logger, is_pretrain=False) 91 | 92 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 93 | model = build_model(config, is_pretrain=False) 94 | model.cuda() 95 | logger.info(str(model)) 96 | 97 | optimizer = build_optimizer(config, model, logger, is_pretrain=False) 98 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 99 | model_without_ddp = model.module 100 | 101 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 102 | logger.info(f"number of params: {n_parameters}") 103 | if hasattr(model_without_ddp, 'flops'): 104 | flops = model_without_ddp.flops() 105 | logger.info(f"number of GFLOPs: {flops / 1e9}") 106 | 107 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 108 | scaler = amp.GradScaler() 109 | 110 | if config.AUG.MIXUP > 0.: 111 | # smoothing is handled with mixup label transform 112 | criterion = SoftTargetCrossEntropy() 113 | elif config.MODEL.LABEL_SMOOTHING > 0.: 114 | criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) 115 | else: 116 | criterion = torch.nn.CrossEntropyLoss() 117 | 118 | max_accuracy = 0.0 119 | 120 | if config.TRAIN.AUTO_RESUME: 121 | resume_file = auto_resume_helper(config.OUTPUT, logger) 122 | if resume_file: 123 | if config.MODEL.RESUME: 124 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 125 | config.defrost() 126 | config.MODEL.RESUME = resume_file 127 | config.freeze() 128 | logger.info(f'auto resuming from {resume_file}') 129 | else: 130 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 131 | 132 | if config.MODEL.RESUME: 133 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) 134 | acc1, acc5, loss = validate(config, data_loader_val, model) 135 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 136 | if config.EVAL_MODE: 137 | return 138 | elif config.PRETRAINED: 139 | load_pretrained(config, model_without_ddp, logger) 140 | 141 | if config.THROUGHPUT_MODE: 142 | throughput(data_loader_val, model, logger) 143 | return 144 | 145 | logger.info("Start training") 146 | start_time = time.time() 147 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 148 | data_loader_train.sampler.set_epoch(epoch) 149 | 150 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler) 151 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 152 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger) 153 | 154 | acc1, acc5, loss = validate(config, data_loader_val, model) 155 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 156 | max_accuracy = max(max_accuracy, acc1) 157 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 158 | wandb_log( 159 | data=dict(acc1=acc1, acc5=acc5, val_loss=loss, max_acc=max_accuracy), 160 | step=(epoch + 1) * len(data_loader_train), 161 | ) 162 | 163 | total_time = time.time() - start_time 164 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 165 | logger.info('Training time {}'.format(total_time_str)) 166 | 167 | 168 | def 
train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler): 169 | model.train() 170 | optimizer.zero_grad() 171 | 172 | logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') 173 | 174 | num_steps = len(data_loader) 175 | batch_time = AverageMeter() 176 | loss_meter = AverageMeter() 177 | norm_meter = AverageMeter() 178 | loss_scale_meter = AverageMeter() 179 | 180 | start = time.time() 181 | end = time.time() 182 | for idx, (samples, targets) in enumerate(data_loader): 183 | samples = samples.cuda(non_blocking=True) 184 | targets = targets.cuda(non_blocking=True) 185 | 186 | if mixup_fn is not None: 187 | samples, targets = mixup_fn(samples, targets) 188 | 189 | with amp.autocast(enabled=config.ENABLE_AMP): 190 | outputs = model(samples) 191 | 192 | if config.TRAIN.ACCUMULATION_STEPS > 1: 193 | loss = criterion(outputs, targets) 194 | loss = loss / config.TRAIN.ACCUMULATION_STEPS 195 | scaler.scale(loss).backward() 196 | if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: 197 | if config.TRAIN.CLIP_GRAD: 198 | scaler.unscale_(optimizer) 199 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 200 | else: 201 | grad_norm = get_grad_norm(model.parameters()) 202 | scaler.step(optimizer) 203 | optimizer.zero_grad() 204 | scaler.update() 205 | lr_scheduler.step_update(epoch * num_steps + idx) 206 | else: 207 | grad_norm = get_grad_norm(model.parameters()) 208 | else: 209 | loss = criterion(outputs, targets) 210 | optimizer.zero_grad() 211 | scaler.scale(loss).backward() 212 | if config.TRAIN.CLIP_GRAD and (config.TRAIN.OPTIMIZER.NAME.lower() != 'fusedlamb'): 213 | scaler.unscale_(optimizer) 214 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 215 | else: 216 | grad_norm = get_grad_norm(model.parameters()) 217 | scaler.step(optimizer) 218 | scaler.update() 219 | lr_scheduler.step_update(epoch * num_steps + idx) 220 | 221 | torch.cuda.synchronize() 222 | 223 | loss_meter.update(loss.item(), targets.size(0)) 224 | norm_meter.update(grad_norm) 225 | loss_scale_meter.update(scaler.get_scale()) 226 | batch_time.update(time.time() - end) 227 | end = time.time() 228 | 229 | if idx % config.PRINT_FREQ == 0: 230 | lr = optimizer.param_groups[-1]['lr'] 231 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 232 | etas = batch_time.avg * (num_steps - idx) 233 | logger.info( 234 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 235 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 236 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 237 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 238 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 239 | f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' 240 | f'mem {memory_used:.0f}MB') 241 | if (epoch * num_steps + idx) % config.PRINT_FREQ == 0: 242 | wandb_log( 243 | data=dict(lr=lr, time=batch_time.val, epoch=epoch, iter=idx, loss=loss_meter.val, loss_ma=loss_meter.avg, grad_norm=norm_meter.val, loss_scale=loss_scale_meter.val), 244 | step=epoch * num_steps + idx, 245 | ) 246 | epoch_time = time.time() - start 247 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 248 | 249 | 250 | @torch.no_grad() 251 | def validate(config, data_loader, model): 252 | criterion = torch.nn.CrossEntropyLoss() 253 | model.eval() 254 | 255 | batch_time = AverageMeter() 
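# Note: the accuracy and loss metrics in this function are synchronized across
# processes with reduce_tensor (defined in utils.py, not shown here) --
# presumably an all_reduce averaged over the world size -- so the logged
# Acc@1/Acc@5/loss reflect all ranks rather than a single GPU.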
256 | loss_meter = AverageMeter() 257 | acc1_meter = AverageMeter() 258 | acc5_meter = AverageMeter() 259 | 260 | end = time.time() 261 | for idx, (images, target) in enumerate(data_loader): 262 | images = images.cuda(non_blocking=True) 263 | target = target.cuda(non_blocking=True) 264 | 265 | # compute output 266 | output = model(images) 267 | 268 | # measure accuracy and record loss 269 | loss = criterion(output, target) 270 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 271 | 272 | acc1 = reduce_tensor(acc1) 273 | acc5 = reduce_tensor(acc5) 274 | loss = reduce_tensor(loss) 275 | 276 | loss_meter.update(loss.item(), target.size(0)) 277 | acc1_meter.update(acc1.item(), target.size(0)) 278 | acc5_meter.update(acc5.item(), target.size(0)) 279 | 280 | # measure elapsed time 281 | batch_time.update(time.time() - end) 282 | end = time.time() 283 | 284 | if idx % config.PRINT_FREQ == 0: 285 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 286 | logger.info( 287 | f'Test: [{idx}/{len(data_loader)}]\t' 288 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 289 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 290 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 291 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 292 | f'Mem {memory_used:.0f}MB') 293 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 294 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 295 | 296 | 297 | @torch.no_grad() 298 | def throughput(data_loader, model, logger): 299 | model.eval() 300 | 301 | for idx, (images, _) in enumerate(data_loader): 302 | images = images.cuda(non_blocking=True) 303 | batch_size = images.shape[0] 304 | for i in range(50): 305 | model(images) 306 | torch.cuda.synchronize() 307 | logger.info(f"throughput averaged with 30 times") 308 | tic1 = time.time() 309 | for i in range(30): 310 | model(images) 311 | torch.cuda.synchronize() 312 | tic2 = time.time() 313 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 314 | return 315 | 316 | 317 | if __name__ == '__main__': 318 | _, config = parse_option() 319 | 320 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 321 | rank = int(os.environ["RANK"]) 322 | world_size = int(os.environ['WORLD_SIZE']) 323 | print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") 324 | else: 325 | rank = -1 326 | world_size = -1 327 | torch.cuda.set_device(config.LOCAL_RANK) 328 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 329 | torch.distributed.barrier() 330 | 331 | seed = config.SEED + dist.get_rank() 332 | torch.manual_seed(seed) 333 | np.random.seed(seed) 334 | cudnn.benchmark = True 335 | 336 | # linear scale the learning rate according to total batch size, may not be optimal 337 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 338 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 339 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 340 | # gradient accumulation also need to scale the learning rate 341 | if config.TRAIN.ACCUMULATION_STEPS > 1: 342 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 343 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 344 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 345 | config.defrost() 346 | 
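# Worked example of the linear LR-scaling rule above (hypothetical values): with
# TRAIN.BASE_LR = 1.25e-3, DATA.BATCH_SIZE = 128 per GPU, 8 GPUs and
# ACCUMULATION_STEPS = 1, the effective base LR becomes
# 1.25e-3 * 128 * 8 / 512 = 2.5e-3; WARMUP_LR and MIN_LR are scaled the same way.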
config.TRAIN.BASE_LR = linear_scaled_lr 347 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 348 | config.TRAIN.MIN_LR = linear_scaled_min_lr 349 | config.freeze() 350 | 351 | os.makedirs(config.OUTPUT, exist_ok=True) 352 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 353 | 354 | if dist.get_rank() == 0: 355 | path = os.path.join(config.OUTPUT, "config.json") 356 | with open(path, "w") as f: 357 | f.write(config.dump()) 358 | logger.info(f"Full config saved to {path}") 359 | 360 | # setup wandb 361 | if not no_wandb: 362 | raise NotImplementedError(" using yourself wandb ") 363 | 364 | # print config 365 | logger.info(config.dump()) 366 | 367 | main(config) 368 | -------------------------------------------------------------------------------- /main_linear.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Ze Liu 5 | # Modified by Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import os 10 | import time 11 | import argparse 12 | import datetime 13 | import numpy as np 14 | 15 | import torch 16 | import torch.backends.cudnn as cudnn 17 | import torch.distributed as dist 18 | import torch.cuda.amp as amp 19 | 20 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy 21 | from timm.utils import accuracy, AverageMeter, NativeScaler 22 | from timm.models.layers import trunc_normal_ 23 | 24 | from config import get_config 25 | from models import build_model 26 | from data.data_linear import build_loader 27 | from lr_scheduler import build_scheduler 28 | from optimizer import build_optimizer 29 | from logger import create_logger 30 | from utils import LARS, load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor 31 | 32 | import warnings 33 | 34 | warnings.filterwarnings('ignore', 35 | 'Argument interpolation should be of type InterpolationMode instead of int', 36 | UserWarning) 37 | 38 | import wandb 39 | global no_wandb 40 | no_wandb = True 41 | 42 | def wandb_log(*args, **kwargs): 43 | if dist.get_rank() == 0 and not no_wandb: 44 | wandb.log(*args, **kwargs) 45 | 46 | 47 | def parse_option(): 48 | global no_wandb 49 | parser = argparse.ArgumentParser('feature distillation linear probe script', add_help=False) 50 | parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', ) 51 | parser.add_argument( 52 | "--opts", 53 | help="Modify config options by adding 'KEY VALUE' pairs. 
", 54 | default=None, 55 | nargs='+', 56 | ) 57 | 58 | # easy config modification 59 | parser.add_argument('--batch-size', type=int, help="batch size for single GPU") 60 | parser.add_argument('--data-path', type=str, help='path to dataset') 61 | parser.add_argument('--pretrained', type=str, help='path to pre-trained model') 62 | parser.add_argument('--resume', help='resume from checkpoint') 63 | parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") 64 | parser.add_argument('--use-checkpoint', action='store_true', 65 | help="whether to use gradient checkpointing to save memory") 66 | parser.add_argument('--enable-amp', action='store_true') 67 | parser.add_argument('--disable-amp', action='store_false', dest='enable_amp') 68 | parser.set_defaults(enable_amp=True) 69 | parser.add_argument('--output', default='output', type=str, metavar='PATH', 70 | help='root of output folder, the full path is / (default: output)') 71 | parser.add_argument('--tag', help='tag of experiment') 72 | parser.add_argument('--eval', action='store_true', help='Perform evaluation only') 73 | parser.add_argument('--throughput', action='store_true', help='Test throughput only') 74 | 75 | # distributed training 76 | parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel') 77 | 78 | parser.add_argument('--with-wandb', action='store_true') 79 | args = parser.parse_args() 80 | 81 | config = get_config(args) 82 | if args.with_wandb: 83 | no_wandb = False 84 | print(" warning you're using wandb ! ") 85 | 86 | return args, config 87 | 88 | 89 | def main(config): 90 | dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, logger) 91 | 92 | logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") 93 | model = build_model(config, is_pretrain=False) 94 | # ===== add linear probe head like mae ======== 95 | # manually initialize fc layer: following MoCo v3 96 | trunc_normal_(model.head.weight, std=0.01) 97 | 98 | # for linear prob only 99 | # hack: revise model's head with BN 100 | model.head = torch.nn.Sequential(torch.nn.BatchNorm1d(model.head.in_features, affine=False, eps=1e-6), model.head) 101 | # freeze all but the head 102 | for _, p in model.named_parameters(): 103 | p.requires_grad = False 104 | for _, p in model.head.named_parameters(): 105 | p.requires_grad = True 106 | # ====== finish adding ============== 107 | model.cuda() 108 | logger.info(str(model)) 109 | 110 | optimizer = LARS(model.head.parameters(), lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 111 | print(optimizer) 112 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False) 113 | model_without_ddp = model.module 114 | 115 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 116 | logger.info(f"number of params: {n_parameters}") 117 | if hasattr(model_without_ddp, 'flops'): 118 | flops = model_without_ddp.flops() 119 | logger.info(f"number of GFLOPs: {flops / 1e9}") 120 | 121 | lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) 122 | scaler = amp.GradScaler() 123 | 124 | criterion = torch.nn.CrossEntropyLoss() 125 | 126 | max_accuracy = 0.0 127 | 128 | if config.TRAIN.AUTO_RESUME: 129 | resume_file = auto_resume_helper(config.OUTPUT, logger) 130 | if resume_file: 131 | if config.MODEL.RESUME: 132 | logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") 133 
| config.defrost() 134 | config.MODEL.RESUME = resume_file 135 | config.freeze() 136 | logger.info(f'auto resuming from {resume_file}') 137 | else: 138 | logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') 139 | 140 | if config.MODEL.RESUME: 141 | max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger) 142 | acc1, acc5, loss = validate(config, data_loader_val, model) 143 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 144 | if config.EVAL_MODE: 145 | return 146 | elif config.PRETRAINED: 147 | load_pretrained(config, model_without_ddp, logger) 148 | 149 | if config.THROUGHPUT_MODE: 150 | throughput(data_loader_val, model, logger) 151 | return 152 | 153 | logger.info("Start training") 154 | start_time = time.time() 155 | for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): 156 | data_loader_train.sampler.set_epoch(epoch) 157 | 158 | train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler) 159 | if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)): 160 | save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger) 161 | 162 | acc1, acc5, loss = validate(config, data_loader_val, model) 163 | logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%") 164 | max_accuracy = max(max_accuracy, acc1) 165 | logger.info(f'Max accuracy: {max_accuracy:.2f}%') 166 | wandb_log( 167 | data=dict(acc1=acc1, acc5=acc5, val_loss=loss, max_acc=max_accuracy), 168 | step=(epoch + 1) * len(data_loader_train), 169 | ) 170 | 171 | total_time = time.time() - start_time 172 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 173 | logger.info('Training time {}'.format(total_time_str)) 174 | 175 | 176 | def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler): 177 | model.train() 178 | optimizer.zero_grad() 179 | 180 | logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}') 181 | 182 | num_steps = len(data_loader) 183 | batch_time = AverageMeter() 184 | loss_meter = AverageMeter() 185 | norm_meter = AverageMeter() 186 | loss_scale_meter = AverageMeter() 187 | 188 | start = time.time() 189 | end = time.time() 190 | for idx, (samples, targets) in enumerate(data_loader): 191 | samples = samples.cuda(non_blocking=True) 192 | targets = targets.cuda(non_blocking=True) 193 | 194 | if mixup_fn is not None: 195 | samples, targets = mixup_fn(samples, targets) 196 | 197 | with amp.autocast(enabled=config.ENABLE_AMP): 198 | outputs = model(samples) 199 | loss = criterion(outputs, targets) 200 | 201 | optimizer.zero_grad() 202 | scaler.scale(loss).backward() 203 | if config.TRAIN.CLIP_GRAD: 204 | scaler.unscale_(optimizer) 205 | grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) 206 | else: 207 | grad_norm = get_grad_norm(model.parameters()) 208 | scaler.step(optimizer) 209 | scaler.update() 210 | lr_scheduler.step_update(epoch * num_steps + idx) 211 | 212 | torch.cuda.synchronize() 213 | 214 | loss_meter.update(loss.item(), targets.size(0)) 215 | norm_meter.update(grad_norm) 216 | loss_scale_meter.update(scaler.get_scale()) 217 | batch_time.update(time.time() - end) 218 | end = time.time() 219 | 220 | if idx % config.PRINT_FREQ == 0: 221 | lr = 
optimizer.param_groups[-1]['lr'] 222 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 223 | etas = batch_time.avg * (num_steps - idx) 224 | logger.info( 225 | f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' 226 | f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' 227 | f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' 228 | f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 229 | f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' 230 | f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t' 231 | f'mem {memory_used:.0f}MB') 232 | if (epoch * num_steps + idx) % config.PRINT_FREQ == 0: 233 | wandb_log( 234 | data=dict(lr=lr, time=batch_time.val, epoch=epoch, iter=idx, loss=loss_meter.val, loss_ma=loss_meter.avg, grad_norm=norm_meter.val, loss_scale=loss_scale_meter.val), 235 | step=epoch * num_steps + idx, 236 | ) 237 | epoch_time = time.time() - start 238 | logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") 239 | 240 | 241 | @torch.no_grad() 242 | def validate(config, data_loader, model): 243 | criterion = torch.nn.CrossEntropyLoss() 244 | model.eval() 245 | 246 | batch_time = AverageMeter() 247 | loss_meter = AverageMeter() 248 | acc1_meter = AverageMeter() 249 | acc5_meter = AverageMeter() 250 | 251 | end = time.time() 252 | for idx, (images, target) in enumerate(data_loader): 253 | images = images.cuda(non_blocking=True) 254 | target = target.cuda(non_blocking=True) 255 | 256 | # compute output 257 | output = model(images) 258 | 259 | # measure accuracy and record loss 260 | loss = criterion(output, target) 261 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 262 | 263 | acc1 = reduce_tensor(acc1) 264 | acc5 = reduce_tensor(acc5) 265 | loss = reduce_tensor(loss) 266 | 267 | loss_meter.update(loss.item(), target.size(0)) 268 | acc1_meter.update(acc1.item(), target.size(0)) 269 | acc5_meter.update(acc5.item(), target.size(0)) 270 | 271 | # measure elapsed time 272 | batch_time.update(time.time() - end) 273 | end = time.time() 274 | 275 | if idx % config.PRINT_FREQ == 0: 276 | memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) 277 | logger.info( 278 | f'Test: [{idx}/{len(data_loader)}]\t' 279 | f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 280 | f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' 281 | f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' 282 | f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' 283 | f'Mem {memory_used:.0f}MB') 284 | logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') 285 | return acc1_meter.avg, acc5_meter.avg, loss_meter.avg 286 | 287 | 288 | @torch.no_grad() 289 | def throughput(data_loader, model, logger): 290 | model.eval() 291 | 292 | for idx, (images, _) in enumerate(data_loader): 293 | images = images.cuda(non_blocking=True) 294 | batch_size = images.shape[0] 295 | for i in range(50): 296 | model(images) 297 | torch.cuda.synchronize() 298 | logger.info(f"throughput averaged with 30 times") 299 | tic1 = time.time() 300 | for i in range(30): 301 | model(images) 302 | torch.cuda.synchronize() 303 | tic2 = time.time() 304 | logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}") 305 | return 306 | 307 | 308 | if __name__ == '__main__': 309 | _, config = parse_option() 310 | 311 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 312 | rank = int(os.environ["RANK"]) 313 | world_size = int(os.environ['WORLD_SIZE']) 314 | print(f"RANK and WORLD_SIZE 
in environ: {rank}/{world_size}") 315 | else: 316 | rank = -1 317 | world_size = -1 318 | torch.cuda.set_device(config.LOCAL_RANK) 319 | torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) 320 | torch.distributed.barrier() 321 | 322 | seed = config.SEED + dist.get_rank() 323 | torch.manual_seed(seed) 324 | np.random.seed(seed) 325 | cudnn.benchmark = True 326 | 327 | # linear scale the learning rate according to total batch size, may not be optimal 328 | linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 329 | linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 330 | linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0 331 | # gradient accumulation also need to scale the learning rate 332 | if config.TRAIN.ACCUMULATION_STEPS > 1: 333 | linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS 334 | linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS 335 | linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS 336 | config.defrost() 337 | config.TRAIN.BASE_LR = linear_scaled_lr 338 | config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr 339 | config.TRAIN.MIN_LR = linear_scaled_min_lr 340 | config.freeze() 341 | 342 | os.makedirs(config.OUTPUT, exist_ok=True) 343 | logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}") 344 | 345 | if dist.get_rank() == 0: 346 | path = os.path.join(config.OUTPUT, "config.json") 347 | with open(path, "w") as f: 348 | f.write(config.dump()) 349 | logger.info(f"Full config saved to {path}") 350 | 351 | # setup wandb 352 | if not no_wandb: 353 | raise NotImplementedError(" using yourself wandb ") 354 | 355 | # print config 356 | logger.info(config.dump()) 357 | 358 | main(config) 359 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_model -------------------------------------------------------------------------------- /models/build.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # Modified by Yixuan Wei 8 | # -------------------------------------------------------- 9 | 10 | from .swin_transformer import build_swin 11 | from .swin_transformer_v2 import build_swin_v2 12 | from .vision_transformer import build_vit 13 | from .feature_distillation import build_fd 14 | 15 | 16 | def build_model(config, is_pretrain=True): 17 | if is_pretrain: 18 | model = build_fd(config) 19 | else: 20 | model_type = config.MODEL.TYPE 21 | if model_type == 'swin_v2': 22 | model = build_swin_v2(config) 23 | elif model_type == 'swin': 24 | model = build_swin(config) 25 | elif model_type == 'vit': 26 | model = build_vit(config) 27 | else: 28 | raise NotImplementedError(f"Unknown fine-tune model: {model_type}") 29 | 30 | return model 31 | -------------------------------------------------------------------------------- /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip 
import * 2 | -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | import torch.distributed as dist 14 | # from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if torch.__version__.split(".") < ["1", "7", "1"]: 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | # _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip"), sha_check=True): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if (not sha_check) or (sha_check and hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256): 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if sha_check and hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 
| raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | lambda image: image.convert("RGB"), 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False, image_size=None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 103 | 104 | Returns 105 | ------- 106 | model : torch.nn.Module 107 | The CLIP model 108 | 109 | preprocess : Callable[[PIL.Image], torch.Tensor] 110 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 111 | """ 112 | if name in _MODELS: 113 | if ((dist.is_initialized() or dist.is_available()) and int(dist.get_rank()) % torch.cuda.device_count() == 0) or not dist.is_available(): 114 | model_path = _download(_MODELS[name]) 115 | dist.barrier() 116 | model_path = _download(_MODELS[name]) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | with open(model_path, 'rb') as opened_file: 123 | try: 124 | # loading JIT archive 125 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 126 | state_dict = None 127 | except RuntimeError: 128 | # loading saved state dict 129 | if jit: 130 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 131 | jit = False 132 | state_dict = torch.load(opened_file, map_location="cpu") 133 | 134 | if not jit: 135 | model = build_model(state_dict or model.state_dict(), image_size).to(device) 136 | if str(device) == "cpu": 137 | model.float() 138 | return model, _transform(model.visual.input_resolution) 139 | 140 | # patch the device names 141 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 142 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 143 | 144 | def patch_device(module): 145 | try: 146 | graphs = [module.graph] if hasattr(module, "graph") else [] 147 | except RuntimeError: 148 | graphs = [] 149 | 150 | if hasattr(module, "forward1"): 151 | graphs.append(module.forward1.graph) 152 | 153 | for graph in graphs: 154 | for node in graph.findAllNodes("prim::Constant"): 155 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 156 | node.copyAttributes(device_node) 157 | 158 | model.apply(patch_device) 159 | patch_device(model.encode_image) 160 | patch_device(model.encode_text) 161 | 162 | # patch dtype to float32 on CPU 163 | if str(device) == "cpu": 164 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 165 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 166 | float_node = float_input.node() 167 | 168 | def patch_float(module): 169 | try: 170 | graphs = [module.graph] if hasattr(module, "graph") else [] 171 | except RuntimeError: 172 | graphs = [] 173 | 174 | if hasattr(module, "forward1"): 175 | graphs.append(module.forward1.graph) 176 | 177 | for graph in graphs: 178 | for node in graph.findAllNodes("aten::to"): 179 | inputs = list(node.inputs()) 180 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 181 | if inputs[i].node()["value"] == 5: 182 | inputs[i].node().copyAttributes(float_node) 183 | 184 | model.apply(patch_float) 185 | patch_float(model.encode_image) 186 | patch_float(model.encode_text) 187 | 188 | model.float() 189 | 190 | return model, _transform(model.input_resolution.item()) 191 | 192 | 193 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 194 | """ 195 | Returns the tokenized representation of given input string(s) 196 | 197 | Parameters 198 | ---------- 199 | texts : Union[str, List[str]] 200 | An input string or a list of input strings to tokenize 201 | 202 | context_length : int 203 | The context length to use; all CLIP models use 77 as the context length 204 | 205 | truncate: bool 206 | Whether to truncate the text in case its encoding is longer than the context length 207 | 208 | Returns 209 | ------- 210 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 211 | """ 212 | if isinstance(texts, str): 213 | texts = [texts] 214 | 215 | sot_token = _tokenizer.encoder["<|startoftext|>"] 216 | eot_token = _tokenizer.encoder["<|endoftext|>"] 217 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 218 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 219 | 220 | for i, tokens in enumerate(all_tokens): 221 | if len(tokens) > context_length: 222 | if truncate: 223 | tokens = tokens[:context_length] 224 | tokens[-1] = eot_token 225 | else: 226 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 227 | 
result[i, :len(tokens)] = torch.tensor(tokens) 228 | 229 | return result 230 | -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 bytes and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a significant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = 
word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /models/clip/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # -------------------------------------------------------- 7 | # Position embedding utils 8 | # -------------------------------------------------------- 9 | 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | # -------------------------------------------------------- 15 | # 2D sine-cosine position embedding 16 | # References: 17 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 18 | # MoCo v3: https://github.com/facebookresearch/moco-v3 19 | # -------------------------------------------------------- 20 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 21 | """ 22 | grid_size: int of the grid height and width 23 | return: 24 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 25 | """ 26 | grid_h = np.arange(grid_size, dtype=np.float32) 27 | grid_w = np.arange(grid_size, dtype=np.float32) 28 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 29 | grid = np.stack(grid, axis=0) 30 | 31 | grid = grid.reshape([2, 1, grid_size, grid_size]) 32 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 33 | if cls_token: 34 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 35 | return pos_embed 36 | 37 | 38 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 39 | assert embed_dim % 2 == 0 40 | 41 | # use half of dimensions to encode grid_h 42 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 43 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 44 | 45 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 46 | return emb 47 | 48 | 49 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 50 | """ 51 | embed_dim: output dimension for each position 52 | pos: a list of positions to be encoded: size (M,) 53 | out: (M, D) 54 | """ 55 | assert embed_dim % 2 == 0 56 | omega = np.arange(embed_dim // 2, dtype=np.float) 57 | omega /= embed_dim / 2. 
58 | omega = 1. / 10000**omega # (D/2,) 59 | 60 | pos = pos.reshape(-1) # (M,) 61 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 62 | 63 | emb_sin = np.sin(out) # (M, D/2) 64 | emb_cos = np.cos(out) # (M, D/2) 65 | 66 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 67 | return emb 68 | 69 | 70 | # -------------------------------------------------------- 71 | # Interpolate position embeddings for high-resolution 72 | # References: 73 | # DeiT: https://github.com/facebookresearch/deit 74 | # -------------------------------------------------------- 75 | def interpolate_pos_embed(model, checkpoint_model): 76 | if 'pos_embed' in checkpoint_model: 77 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 78 | embedding_size = pos_embed_checkpoint.shape[-1] 79 | num_patches = model.patch_embed.num_patches 80 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 81 | # height (== width) for the checkpoint position embedding 82 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 83 | # height (== width) for the new position embedding 84 | new_size = int(num_patches ** 0.5) 85 | # class_token and dist_token are kept unchanged 86 | if orig_size != new_size: 87 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 88 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 89 | # only the position tokens are interpolated 90 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 91 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 92 | pos_tokens = torch.nn.functional.interpolate( 93 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 94 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 95 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 96 | checkpoint_model['pos_embed'] = new_pos_embed -------------------------------------------------------------------------------- /models/clip/vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | # Modified from github.com/facebook/SLIP 8 | from collections import OrderedDict 9 | 10 | import torch 11 | from torch import nn 12 | 13 | from functools import partial 14 | from timm.models.layers import DropPath 15 | from timm.models.vision_transformer import PatchEmbed, Mlp 16 | 17 | class Attention(nn.Module): 18 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 19 | super().__init__() 20 | self.num_heads = num_heads 21 | head_dim = dim // num_heads 22 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 23 | self.scale = qk_scale or head_dim ** -0.5 24 | 25 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 26 | self.attn_drop = nn.Dropout(attn_drop) 27 | self.proj = nn.Linear(dim, dim) 28 | self.proj_drop = nn.Dropout(proj_drop) 29 | 30 | def forward(self, x): 31 | B, N, C = x.shape 32 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 33 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 34 | 35 | attn = (q @ k.transpose(-2, -1)) * self.scale 36 | attn = attn.softmax(dim=-1) 37 | attn = self.attn_drop(attn) 38 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 39 | x = self.proj(x) 40 | x = self.proj_drop(x) 41 | return x 42 | 43 | 44 | class Block(nn.Module): 45 | 46 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 47 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 48 | super().__init__() 49 | self.norm1 = norm_layer(dim) 50 | self.attn = Attention( 51 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 52 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 53 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 54 | self.norm2 = norm_layer(dim) 55 | mlp_hidden_dim = int(dim * mlp_ratio) 56 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 57 | 58 | def forward(self, x): 59 | attn_x = self.attn(self.norm1(x)) 60 | x = x + self.drop_path(attn_x) 61 | x = x + self.drop_path(self.mlp(self.norm2(x))) 62 | return x 63 | 64 | 65 | class VisionTransformer(nn.Module): 66 | """ Vision Transformer 67 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 68 | - https://arxiv.org/abs/2010.11929 69 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 70 | - https://arxiv.org/abs/2012.12877 71 | """ 72 | 73 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 74 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 75 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 76 | act_layer=None, weight_init=''): 77 | """ 78 | Args: 79 | img_size (int, tuple): input image size 80 | patch_size (int, tuple): patch size 81 | in_chans (int): number of input channels 82 | num_classes (int): number of classes for classification head 83 | embed_dim (int): embedding dimension 84 | depth (int): depth of transformer 85 | num_heads (int): number of attention heads 86 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 87 | qkv_bias (bool): enable bias for qkv if True 88 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 89 | distilled (bool): model includes a distillation token and head as in DeiT models 90 | drop_rate (float): dropout rate 91 | attn_drop_rate (float): attention dropout rate 92 | drop_path_rate (float): stochastic depth rate 93 | embed_layer (nn.Module): patch embedding layer 94 | norm_layer: (nn.Module): normalization layer 95 | weight_init: (str): weight init scheme 96 | """ 97 | super().__init__() 98 | self.num_classes = num_classes 99 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 100 | self.num_tokens = 2 if distilled else 1 101 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 102 | act_layer = act_layer or nn.GELU 103 | 104 | self.patch_embed = embed_layer( 105 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 106 | num_patches = self.patch_embed.num_patches 107 | 108 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 109 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 110 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 111 | self.pos_drop = nn.Dropout(p=drop_rate) 112 | 113 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 114 | self.blocks = nn.Sequential(*[ 115 | Block( 116 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 117 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) 118 | for i in range(depth)]) 119 | self.norm = norm_layer(embed_dim) 120 | 121 | # Representation layer 122 | if representation_size and not distilled: 123 | self.num_features = representation_size 124 | self.pre_logits = nn.Sequential(OrderedDict([ 125 | ('fc', nn.Linear(embed_dim, representation_size)), 126 | ('act', nn.Tanh()) 127 | ])) 128 | else: 129 | self.pre_logits = 
nn.Identity() 130 | 131 | # Classifier head(s) 132 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 133 | self.head_dist = None 134 | if distilled: 135 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 136 | 137 | @torch.jit.ignore 138 | def no_weight_decay(self): 139 | return {'pos_embed', 'cls_token', 'dist_token'} 140 | 141 | def get_classifier(self): 142 | if self.dist_token is None: 143 | return self.head 144 | else: 145 | return self.head, self.head_dist 146 | 147 | def reset_classifier(self, num_classes, global_pool=''): 148 | self.num_classes = num_classes 149 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 150 | if self.num_tokens == 2: 151 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 152 | 153 | def forward_featuremap(self, x): 154 | x = self.patch_embed(x) 155 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 156 | if self.dist_token is None: 157 | x = torch.cat((cls_token, x), dim=1) 158 | else: 159 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 160 | x = self.pos_drop(x + self.pos_embed) 161 | # apply Transformer blocks 162 | for blk_idx, blk in enumerate(self.blocks): 163 | x = blk(x) 164 | return x 165 | 166 | def forward_features(self, x): 167 | x = self.forward_featuremap(x) 168 | x = self.norm(x) 169 | if self.dist_token is None: 170 | return self.pre_logits(x[:, 0]) 171 | else: 172 | return x[:, 0], x[:, 1] 173 | 174 | def forward(self, x): 175 | x = self.forward_features(x) 176 | if self.head_dist is not None: 177 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 178 | if self.training and not torch.jit.is_scripting(): 179 | # during inference, return the average of both classifier predictions 180 | return x, x_dist 181 | else: 182 | return (x + x_dist) / 2 183 | else: 184 | x = self.head(x) 185 | return x 186 | -------------------------------------------------------------------------------- /models/deit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, List 3 | import torch 4 | from .clip.clip import _download 5 | from .clip.vit import VisionTransformer 6 | import torch.distributed as dist 7 | _MODELS = { 8 | "DEIT": "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 9 | } 10 | 11 | class WarpperVisionTransformer(VisionTransformer): 12 | def __init__(self, **kwargs): 13 | super(WarpperVisionTransformer, self).__init__(**kwargs) 14 | 15 | @property 16 | def dtype(self): 17 | return self.norm.weight.dtype 18 | 19 | def encode_image_featuremap(self, image): 20 | return self.forward_featuremap(image.type(self.dtype)) 21 | 22 | 23 | def load_deit(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", **kwargs): 24 | if name in _MODELS: 25 | if ((dist.is_initialized() or dist.is_available()) and int(dist.get_rank()) % torch.cuda.device_count() == 0) or not dist.is_available(): 26 | model_path = _download(_MODELS[name], sha_check=False) 27 | dist.barrier() 28 | model_path = _download(_MODELS[name], sha_check=False) 29 | elif os.path.isfile(name): 30 | model_path = name 31 | else: 32 | raise RuntimeError(f"Model {name} not found; ") 33 | 34 | 35 | state_dict = torch.load(model_path, map_location="cpu")['model'] 36 | model_kwargs = 
dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, num_classes=0) 37 | model = WarpperVisionTransformer(**model_kwargs) 38 | msg = model.load_state_dict(state_dict, strict=False) 39 | print(msg) 40 | return model.to(device) 41 | -------------------------------------------------------------------------------- /models/dino.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, List 3 | import torch 4 | from .clip.clip import _download 5 | from .clip.vit import VisionTransformer 6 | import torch.distributed as dist 7 | _MODELS = { 8 | "DINO": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth", 9 | "DINO_T": "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth" 10 | } 11 | 12 | class WarpperVisionTransformer(VisionTransformer): 13 | def __init__(self, **kwargs): 14 | super(WarpperVisionTransformer, self).__init__(**kwargs) 15 | 16 | @property 17 | def dtype(self): 18 | return self.norm.weight.dtype 19 | 20 | def encode_image_featuremap(self, image): 21 | return self.forward_featuremap(image.type(self.dtype)) 22 | 23 | def load_dino(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", **kwargs): 24 | if name in _MODELS: 25 | if ((dist.is_initialized() or dist.is_available()) and int(dist.get_rank()) % torch.cuda.device_count() == 0) or not dist.is_available(): 26 | model_path = _download(_MODELS[name], sha_check=False) 27 | dist.barrier() 28 | model_path = _download(_MODELS[name], sha_check=False) 29 | elif os.path.isfile(name): 30 | model_path = name 31 | else: 32 | raise RuntimeError(f"Model {name} not found; ") 33 | 34 | 35 | state_dict = torch.load(model_path, map_location="cpu") 36 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, num_classes=0) 37 | model = WarpperVisionTransformer(**model_kwargs) 38 | msg = model.load_state_dict(state_dict) 39 | print(msg) 40 | return model.to(device) -------------------------------------------------------------------------------- /models/esvit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union 3 | import torch 4 | from .clip.clip import _download 5 | from .swin_transformer import SwinTransformer 6 | import torch.distributed as dist 7 | _MODELS = { 8 | "ESVIT": "https://chunyleu.blob.core.windows.net/output/ckpts/esvit/swin/swin_base/bl_lr0.0005_nodes4_gpu16_bs8_multicrop_epoch300_dino_aug_window14_lv/continued_from_epoch0200_dense_norm_true/checkpoint_best.pth", 9 | } 10 | 11 | class WarpperSwinTransformer(SwinTransformer): 12 | def __init__(self, **kwargs): 13 | super(WarpperSwinTransformer, self).__init__(**kwargs) 14 | 15 | @property 16 | def dtype(self): 17 | return self.norm.weight.dtype 18 | 19 | def forward_featuremap(self, x): 20 | x = self.patch_embed(x) 21 | if self.ape: 22 | x = x + self.absolute_pos_embed 23 | x = self.pos_drop(x) 24 | 25 | for layer in self.layers: 26 | x = layer(x) 27 | 28 | # x = self.norm(x) # B L C 29 | # x = self.avgpool(x.transpose(1, 2)) # B C 1 30 | # x = torch.flatten(x, 1) 31 | return x 32 | 33 | def encode_image_featuremap(self, image): 34 | return self.forward_featuremap(image.type(self.dtype)) 35 | 36 | def load_esvit(name: str, return_s3=False, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", **kwargs): 37 | if name in _MODELS: 38 | if ((dist.is_initialized() or 
dist.is_available()) and int(dist.get_rank()) % torch.cuda.device_count() == 0) or not dist.is_available(): 39 | model_path = _download(_MODELS[name], sha_check=False) 40 | dist.barrier() 41 | model_path = _download(_MODELS[name], sha_check=False) 42 | elif os.path.isfile(name): 43 | model_path = name 44 | else: 45 | raise RuntimeError(f"Model {name} not found; ") 46 | 47 | 48 | state_dict = torch.load(model_path, map_location="cpu")['student'] 49 | state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()} 50 | model_kwargs = dict(patch_size=4, embed_dim=128, depths=[2,2,18] if return_s3 else [2,2,18,2], num_heads=[4,8,16] if return_s3 else [4,8,16,32], window_size=14, num_classes=0) 51 | model = WarpperSwinTransformer(**model_kwargs) 52 | if return_s3: 53 | state_dict.pop('norm.weight') 54 | state_dict.pop('norm.bias') 55 | msg = model.load_state_dict(state_dict, strict=False) 56 | print(msg) 57 | return model.to(device) 58 | -------------------------------------------------------------------------------- /models/feature_distillation.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Yixuan Wei 6 | # -------------------------------------------------------- 7 | 8 | from functools import partial 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.utils.checkpoint as checkpoint 13 | from torchvision.transforms import Resize 14 | from timm.models.layers import trunc_normal_ 15 | 16 | # from .swin_transformer import SwinTransformer 17 | from .swin_transformer_v2 import SwinTransformerV2 18 | from .vision_transformer import VisionTransformer 19 | from .clip import load as load_clip 20 | from .dino import load_dino 21 | from .deit import load_deit 22 | from .esvit import load_esvit 23 | 24 | import torchvision.transforms as T 25 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 26 | mean = torch.tensor(IMAGENET_DEFAULT_MEAN) 27 | std = torch.tensor(IMAGENET_DEFAULT_STD) 28 | normalize = T.Normalize(mean=mean, std=std) 29 | unnormalize = T.Normalize(mean=-mean / std, std=1.0 / std) 30 | normalize_clip = T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 31 | 32 | 33 | class VisionTransformerForFD(VisionTransformer): 34 | def __init__(self, use_checkpoint=False, **kwargs): 35 | super().__init__(**kwargs) 36 | 37 | assert self.num_classes == 0 38 | self.use_checkpoint = use_checkpoint 39 | 40 | def _trunc_normal_(self, tensor, mean=0., std=1.): 41 | trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) 42 | 43 | def forward(self, x): 44 | x = self.patch_embed(x) 45 | 46 | B, L, _ = x.shape 47 | if self.cls_token is not None: 48 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 49 | x = torch.cat((cls_tokens, x), dim=1) 50 | 51 | if self.pos_embed is not None: 52 | x = x + self.pos_embed 53 | x = self.pos_drop(x) 54 | 55 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 56 | for blk_idx, blk in enumerate(self.blocks): 57 | if self.use_checkpoint: 58 | x = checkpoint.checkpoint(blk, x, rel_pos_bias) 59 | else: 60 | x = blk(x, rel_pos_bias=rel_pos_bias) 61 | x = self.norm(x) 62 | 63 | return x 64 | 65 | 66 | class SwinV2ForFD(SwinTransformerV2): 67 | def __init__(self, **kwargs): 68 | 
super().__init__(**kwargs) 69 | 70 | assert self.num_classes == 0 71 | 72 | def forward(self, x): 73 | x = self.patch_embed(x) 74 | 75 | if self.ape: 76 | x = x + self.absolute_pos_embed 77 | x = self.pos_drop(x) 78 | 79 | for layer in self.layers: 80 | x = layer(x) 81 | x = self.norm(x) 82 | 83 | return x 84 | 85 | @torch.jit.ignore 86 | def no_weight_decay(self): 87 | return super().no_weight_decay() 88 | 89 | 90 | class FD(nn.Module): 91 | def __init__(self, config, encoder): 92 | super().__init__() 93 | self.encoder = encoder 94 | 95 | self.pred_feat = config.DEV.PRED_FEAT 96 | self.feat_after_norm = config.DEV.PRED_FEAT_AFTERNORM 97 | 98 | # pred target is feature 99 | if config.DEV.PRED_FEAT == 'CLIP_400M': 100 | self.feature_model, _ = load_clip("ViT-B/16", image_size=config.DATA.IMG_SIZE) 101 | self.resize_func = None 102 | if config.DATA.IMG_SIZE != 224: 103 | self.resize_func = Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)) 104 | elif config.DEV.PRED_FEAT == 'CLIP_400M_Large': 105 | self.feature_model, _ = load_clip("ViT-L/14", image_size=224) 106 | self.resize_func = None 107 | if config.DATA.IMG_SIZE != 224: 108 | self.resize_func = Resize((224, 224)) 109 | elif config.DEV.PRED_FEAT == 'DINO': 110 | self.feature_model = load_dino(config.DEV.PRED_FEAT) 111 | elif config.DEV.PRED_FEAT == 'DEIT': 112 | self.feature_model = load_deit(config.DEV.PRED_FEAT) 113 | elif config.DEV.PRED_FEAT == 'ESVIT': 114 | self.feature_model = load_esvit(config.DEV.PRED_FEAT, return_s3=config.DEV.PRED_FEAT_S3) 115 | else: 116 | raise NotImplementedError 117 | for name, params in self.feature_model.named_parameters(): 118 | params.requires_grad = False 119 | 120 | if 'Large' in config.DEV.PRED_FEAT: 121 | embed_dim = 1024 122 | elif 'ESVIT' in config.DEV.PRED_FEAT: 123 | embed_dim = self.feature_model.embed_dim * 4 if config.DEV.PRED_FEAT_S3 else self.feature_model.embed_dim * 8 124 | else: 125 | embed_dim = 768 126 | 127 | self.loss_feat = nn.SmoothL1Loss(beta=2.0) 128 | self.ln_tgt = nn.LayerNorm(embed_dim, elementwise_affine=False) 129 | self.decoder = nn.Sequential( 130 | nn.Conv1d( 131 | in_channels=self.encoder.num_features, 132 | out_channels=embed_dim, kernel_size=1), 133 | ) 134 | 135 | self.in_chans = self.encoder.in_chans 136 | self.patch_size = self.encoder.patch_size 137 | 138 | def forward(self, x): 139 | z = self.encoder(x) 140 | 141 | x_rec = self.decoder(z.permute(0,2,1)).permute(0,2,1) 142 | self.feature_model.eval() 143 | with torch.no_grad(): 144 | # DINO & DeiT don't unnormalize 145 | if 'CLIP' in self.pred_feat: 146 | x = normalize_clip(unnormalize(x)) 147 | if self.pred_feat == 'CLIP_400M_Large' or self.pred_feat == 'CLIP_400M': 148 | # large as teacher: student: 256/p16 or 224/p14; teacher 224/p14 149 | if self.resize_func is not None: 150 | x = self.resize_func(x) 151 | 152 | x_tgt = self.feature_model.encode_image_featuremap(x) 153 | 154 | if self.feat_after_norm: 155 | if 'CLIP' in self.pred_feat: 156 | x_tgt = self.feature_model.visual.ln_post(x_tgt) 157 | elif 'DINO' in self.pred_feat or 'DEIT' in self.pred_feat or 'ESVIT' in self.pred_feat: 158 | x_tgt = self.feature_model.norm(x_tgt) 159 | else: 160 | raise NotImplementedError 161 | x_tgt = x_tgt.detach() 162 | x_tgt = self.ln_tgt(x_tgt) 163 | 164 | loss = self.loss_feat(x_rec, x_tgt) 165 | loss = loss.mean() 166 | return {'loss': loss} 167 | 168 | @torch.jit.ignore 169 | def no_weight_decay(self): 170 | if hasattr(self.encoder, 'no_weight_decay'): 171 | return {'encoder.' 
+ i for i in self.encoder.no_weight_decay()} 172 | return {} 173 | 174 | @torch.jit.ignore 175 | def no_weight_decay_keywords(self): 176 | if hasattr(self.encoder, 'no_weight_decay_keywords'): 177 | return self.encoder.no_weight_decay_keywords() 178 | return {} 179 | 180 | 181 | def build_fd(config): 182 | model_type = config.MODEL.TYPE 183 | if model_type == 'swin_v2': 184 | encoder = SwinV2ForFD( 185 | img_size=config.DATA.IMG_SIZE, 186 | patch_size=config.MODEL.SWIN.PATCH_SIZE, 187 | in_chans=config.MODEL.SWIN.IN_CHANS, 188 | num_classes=0, 189 | embed_dim=config.MODEL.SWIN.EMBED_DIM, 190 | depths=config.MODEL.SWIN.DEPTHS, 191 | num_heads=config.MODEL.SWIN.NUM_HEADS, 192 | window_size=config.MODEL.SWIN.WINDOW_SIZE, 193 | mlp_ratio=config.MODEL.SWIN.MLP_RATIO, 194 | qkv_bias=config.MODEL.SWIN.QKV_BIAS, 195 | qk_scale=config.MODEL.SWIN.QK_SCALE, 196 | drop_rate=config.MODEL.DROP_RATE, 197 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 198 | ape=config.MODEL.SWIN.APE, 199 | patch_norm=config.MODEL.SWIN.PATCH_NORM, 200 | use_shared_rel_pos_bias=config.MODEL.SWIN.USE_SHARED_RPB, 201 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 202 | elif model_type == 'vit': 203 | encoder = VisionTransformerForFD( 204 | img_size=config.DATA.IMG_SIZE, 205 | patch_size=config.MODEL.VIT.PATCH_SIZE, 206 | in_chans=config.MODEL.VIT.IN_CHANS, 207 | num_classes=0, 208 | embed_dim=config.MODEL.VIT.EMBED_DIM, 209 | depth=config.MODEL.VIT.DEPTH, 210 | num_heads=config.MODEL.VIT.NUM_HEADS, 211 | mlp_ratio=config.MODEL.VIT.MLP_RATIO, 212 | qkv_bias=config.MODEL.VIT.QKV_BIAS, 213 | with_k_bias=config.DEV.VIT_WITHKBIAS, 214 | drop_rate=config.MODEL.DROP_RATE, 215 | drop_path_rate=config.MODEL.DROP_PATH_RATE, 216 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 217 | init_values=config.MODEL.VIT.INIT_VALUES, 218 | use_abs_pos_emb=config.MODEL.VIT.USE_APE, 219 | use_rel_pos_bias=config.MODEL.VIT.USE_RPB, 220 | use_shared_rel_pos_bias=config.MODEL.VIT.USE_SHARED_RPB, 221 | use_mean_pooling=config.MODEL.VIT.USE_MEAN_POOLING, 222 | with_cls_token=config.MODEL.VIT.WITH_CLS_TOKEN, 223 | use_checkpoint=config.TRAIN.USE_CHECKPOINT) 224 | else: 225 | raise NotImplementedError(f"Unknown pre-train model: {model_type}") 226 | 227 | model = FD(config=config, encoder=encoder) 228 | 229 | return model 230 | -------------------------------------------------------------------------------- /models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def norm_targets(targets, norm_patch_size, patch_size): 7 | # targets [N, C, H, W] 8 | targets_ = targets 9 | N, _, H, W = targets.shape 10 | assert norm_patch_size >= patch_size 11 | assert (norm_patch_size - patch_size) % 2 == 0 12 | 13 | padding = (norm_patch_size - patch_size) // 2 14 | 15 | targets = F.pad(targets, [padding,] * 4, mode='reflect') # [N, C, H_pad, W_pad] 16 | targets = F.unfold(targets, kernel_size=norm_patch_size, dilation=1, padding=0, stride=patch_size) # [N, norm_patch_size * norm_patch_size * C, H // patch_size * W // patch_size] 17 | 18 | mean = targets.mean(dim=1).reshape((N, 1, H // patch_size, W // patch_size)).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2) 19 | var = targets.var(dim=1).reshape((N, 1, H // patch_size, W // patch_size)).repeat_interleave(patch_size, -1).repeat_interleave(patch_size, -2) 20 | 21 | targets_ = (targets_ - mean) / (var + 1.e-6) ** 0.5 22 | 23 | return targets_ 24 | 25 | 26 | # Copyright (c) Meta 
Platforms, Inc. and affiliates. 27 | # All rights reserved. 28 | 29 | # This source code is licensed under the license found in the 30 | # LICENSE file in the root directory of this source tree. 31 | # -------------------------------------------------------- 32 | # Position embedding utils 33 | # -------------------------------------------------------- 34 | # -------------------------------------------------------- 35 | # 2D sine-cosine position embedding 36 | # References: 37 | # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py 38 | # MoCo v3: https://github.com/facebookresearch/moco-v3 39 | # -------------------------------------------------------- 40 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 41 | """ 42 | grid_size: int of the grid height and width 43 | return: 44 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 45 | """ 46 | grid_h = np.arange(grid_size, dtype=np.float32) 47 | grid_w = np.arange(grid_size, dtype=np.float32) 48 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 49 | grid = np.stack(grid, axis=0) 50 | 51 | grid = grid.reshape([2, 1, grid_size, grid_size]) 52 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 53 | if cls_token: 54 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 55 | return pos_embed 56 | 57 | 58 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 59 | assert embed_dim % 2 == 0 60 | 61 | # use half of dimensions to encode grid_h 62 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 63 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 64 | 65 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 66 | return emb 67 | 68 | 69 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 70 | """ 71 | embed_dim: output dimension for each position 72 | pos: a list of positions to be encoded: size (M,) 73 | out: (M, D) 74 | """ 75 | assert embed_dim % 2 == 0 76 | omega = np.arange(embed_dim // 2, dtype=np.float) 77 | omega /= embed_dim / 2. 78 | omega = 1. 
/ 10000**omega # (D/2,) 79 | 80 | pos = pos.reshape(-1) # (M,) 81 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 82 | 83 | emb_sin = np.sin(out) # (M, D/2) 84 | emb_cos = np.cos(out) # (M, D/2) 85 | 86 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 87 | return emb 88 | 89 | 90 | # -------------------------------------------------------- 91 | # Interpolate position embeddings for high-resolution 92 | # References: 93 | # DeiT: https://github.com/facebookresearch/deit 94 | # -------------------------------------------------------- 95 | def interpolate_pos_embed(model, checkpoint_model): 96 | if 'pos_embed' in checkpoint_model: 97 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 98 | embedding_size = pos_embed_checkpoint.shape[-1] 99 | num_patches = model.patch_embed.num_patches 100 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 101 | # height (== width) for the checkpoint position embedding 102 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 103 | # height (== width) for the new position embedding 104 | new_size = int(num_patches ** 0.5) 105 | # class_token and dist_token are kept unchanged 106 | if orig_size != new_size: 107 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 108 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 109 | # only the position tokens are interpolated 110 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 111 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 112 | pos_tokens = torch.nn.functional.interpolate( 113 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 114 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 115 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 116 | checkpoint_model['pos_embed'] = new_pos_embed 117 | -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Based on BEIT code bases (https://github.com/microsoft/unilm/tree/master/beit) 5 | # Written by Yutong Lin, Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import math 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | from torch import Tensor 16 | import torch.utils.checkpoint as checkpoint 17 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 18 | 19 | class LinearFP32(nn.Linear): 20 | def __init__(self, in_features, out_features, bias=True): 21 | super(LinearFP32, self).__init__(in_features, out_features, bias) 22 | 23 | def forward(self, input: Tensor) -> Tensor: 24 | return F.linear(input.float(), self.weight.float(), 25 | self.bias.float() if self.bias is not None else None) 26 | 27 | class Mlp(nn.Module): 28 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 29 | super().__init__() 30 | out_features = out_features or in_features 31 | hidden_features = hidden_features or in_features 32 | self.fc1 = nn.Linear(in_features, hidden_features) 33 | self.act = act_layer() 34 | self.fc2 = nn.Linear(hidden_features, out_features) 35 | self.drop = 
nn.Dropout(drop) 36 | 37 | def forward(self, x): 38 | x = self.fc1(x) 39 | x = self.act(x) 40 | # x = self.drop(x) 41 | # comment out this for the orignal BERT implement 42 | x = self.fc2(x) 43 | x = self.drop(x) 44 | return x 45 | 46 | 47 | class Attention(nn.Module): 48 | def __init__( 49 | self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., 50 | proj_drop=0., window_size=None, attn_head_dim=None, with_cls_token=True, 51 | with_k_bias=False): 52 | super().__init__() 53 | self.num_heads = num_heads 54 | head_dim = dim // num_heads 55 | if attn_head_dim is not None: 56 | head_dim = attn_head_dim 57 | all_head_dim = head_dim * self.num_heads 58 | self.scale = qk_scale or head_dim ** -0.5 59 | 60 | self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) 61 | if qkv_bias: 62 | self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) 63 | self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) 64 | else: 65 | self.q_bias = None 66 | self.v_bias = None 67 | self.with_k_bias = with_k_bias 68 | if self.with_k_bias: 69 | self.k_bias = nn.Parameter(torch.zeros(all_head_dim)) 70 | self.window_size = window_size 71 | if window_size: 72 | if with_cls_token: 73 | # extra 3: cls to token & token 2 cls & cls to cls 74 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 75 | self.num_tokens = self.window_size[0] * self.window_size[1] + 1 76 | else: 77 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) 78 | self.num_tokens = self.window_size[0] * self.window_size[1] 79 | self.relative_position_bias_table = nn.Parameter( 80 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 81 | 82 | # get pair-wise relative position index for each token inside the window 83 | coords_h = torch.arange(window_size[0]) 84 | coords_w = torch.arange(window_size[1]) 85 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 86 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 87 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 88 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 89 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 90 | relative_coords[:, :, 1] += window_size[1] - 1 91 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 92 | 93 | if with_cls_token: 94 | relative_position_index = \ 95 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 96 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 97 | relative_position_index[0, 0:] = self.num_relative_distance - 3 98 | relative_position_index[0:, 0] = self.num_relative_distance - 2 99 | relative_position_index[0, 0] = self.num_relative_distance - 1 100 | else: 101 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 102 | 103 | self.register_buffer("relative_position_index", relative_position_index) 104 | else: 105 | self.window_size = None 106 | self.relative_position_bias_table = None 107 | self.relative_position_index = None 108 | 109 | self.attn_drop = nn.Dropout(attn_drop) 110 | self.proj = nn.Linear(all_head_dim, dim) 111 | self.proj_drop = nn.Dropout(proj_drop) 112 | 113 | def forward(self, x, rel_pos_bias=None, rpb_mask=None, attn_mask=None): 114 | B, N, C = x.shape 115 | qkv_bias = None 116 | if self.q_bias is not None: 117 | if self.with_k_bias: 118 | qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) 119 | else: 120 | qkv_bias = 
torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) 121 | qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) 122 | qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 123 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 124 | 125 | q = q * self.scale 126 | attn = (q @ k.transpose(-2, -1)) 127 | 128 | if self.relative_position_bias_table is not None: 129 | relative_position_bias = \ 130 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 131 | self.num_tokens, self.num_tokens, -1) # Wh*Ww,Wh*Ww,nH 132 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) # 1, nH, Wh*Ww, Wh*Ww 133 | 134 | if rpb_mask is not None: 135 | relative_position_bias = relative_position_bias.expand(B, -1, -1, -1) # [B, nH, L + 1, L + 1] 136 | relative_position_bias = relative_position_bias[rpb_mask].reshape((B, -1, N, N)) 137 | 138 | attn = attn + relative_position_bias 139 | 140 | if rel_pos_bias is not None: 141 | attn = attn + rel_pos_bias 142 | 143 | if attn_mask is not None: 144 | attn = attn - 1e10 * attn_mask 145 | 146 | attn = attn.softmax(dim=-1) 147 | attn = self.attn_drop(attn) 148 | 149 | x = (attn @ v).transpose(1, 2) 150 | x = x.reshape(B, N, -1) 151 | x = self.proj(x) 152 | x = self.proj_drop(x) 153 | return x 154 | 155 | 156 | class Block(nn.Module): 157 | 158 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 159 | drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, 160 | window_size=None, attn_head_dim=None, with_cls_token=True, 161 | with_k_bias=False): 162 | super().__init__() 163 | self.norm1 = norm_layer(dim) 164 | self.attn = Attention( 165 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 166 | attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim, 167 | with_cls_token=with_cls_token, with_k_bias=with_k_bias) 168 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 169 | self.norm2 = norm_layer(dim) 170 | mlp_hidden_dim = int(dim * mlp_ratio) 171 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 172 | 173 | if init_values is not None: 174 | self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 175 | self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) 176 | else: 177 | self.gamma_1, self.gamma_2 = None, None 178 | 179 | def forward(self, x, rel_pos_bias=None, rpb_mask=None, attn_mask=None): 180 | attn_x = self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias, rpb_mask=rpb_mask, attn_mask=attn_mask) 181 | if self.gamma_1 is None: 182 | x = x + self.drop_path(attn_x) 183 | x = x + self.drop_path(self.mlp(self.norm2(x))) 184 | else: 185 | x = x + self.drop_path(self.gamma_1 * attn_x) 186 | x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) 187 | return x 188 | 189 | 190 | class PatchEmbed(nn.Module): 191 | """ Image to Patch Embedding 192 | """ 193 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 194 | super().__init__() 195 | img_size = to_2tuple(img_size) 196 | patch_size = to_2tuple(patch_size) 197 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 198 | self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 199 | self.img_size = img_size 200 | self.patch_size = patch_size 201 | self.num_patches = num_patches 202 | 203 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 204 | 205 | def forward(self, x, **kwargs): 206 | B, C, H, W = x.shape 207 | # FIXME look at relaxing size constraints 208 | assert H == self.img_size[0] and W == self.img_size[1], \ 209 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
210 | x = self.proj(x).flatten(2).transpose(1, 2) 211 | return x 212 | 213 | 214 | class RelativePositionBias(nn.Module): 215 | 216 | def __init__(self, window_size, num_heads, with_cls_token=True): 217 | super().__init__() 218 | self.window_size = window_size 219 | if with_cls_token: 220 | # extra 3: cls to token & token 2 cls & cls to cls 221 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 222 | self.num_tokens = self.window_size[0] * self.window_size[1] + 1 223 | else: 224 | self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) 225 | self.num_tokens = self.window_size[0] * self.window_size[1] 226 | self.relative_position_bias_table = nn.Parameter( 227 | torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH 228 | 229 | # get pair-wise relative position index for each token inside the window 230 | coords_h = torch.arange(window_size[0]) 231 | coords_w = torch.arange(window_size[1]) 232 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 233 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 234 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 235 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 236 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 237 | relative_coords[:, :, 1] += window_size[1] - 1 238 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 239 | 240 | if with_cls_token: 241 | relative_position_index = \ 242 | torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype) 243 | relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 244 | relative_position_index[0, 0:] = self.num_relative_distance - 3 245 | relative_position_index[0:, 0] = self.num_relative_distance - 2 246 | relative_position_index[0, 0] = self.num_relative_distance - 1 247 | else: 248 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 249 | 250 | self.register_buffer("relative_position_index", relative_position_index) 251 | 252 | def forward(self): 253 | relative_position_bias = \ 254 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 255 | self.num_tokens, self.num_tokens, -1) # Wh*Ww,Wh*Ww,nH 256 | return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 257 | 258 | 259 | class VisionTransformer(nn.Module): 260 | """ Vision Transformer with support for patch or hybrid CNN input stage 261 | """ 262 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 263 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 264 | drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, 265 | use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, 266 | use_mean_pooling=True, init_scale=0.001, with_cls_token=True, use_checkpoint=False, with_k_bias=False): 267 | super().__init__() 268 | self.num_classes = num_classes 269 | self.num_features = self.embed_dim = embed_dim 270 | self.patch_size = patch_size 271 | self.in_chans = in_chans 272 | self.num_heads = num_heads 273 | self.use_abs_pos_emb = use_abs_pos_emb 274 | self.with_cls_token = with_cls_token 275 | self.use_checkpoint = use_checkpoint 276 | 277 | self.patch_embed = PatchEmbed( 278 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 279 | num_patches = self.patch_embed.num_patches 280 | 
self.num_patches = num_patches 281 | 282 | if with_cls_token: 283 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 284 | else: 285 | self.cls_token = None 286 | 287 | if use_abs_pos_emb: 288 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if with_cls_token else nn.Parameter(torch.zeros(1, num_patches, embed_dim)) 289 | else: 290 | self.pos_embed = None 291 | self.pos_drop = nn.Dropout(p=drop_rate) 292 | 293 | if use_shared_rel_pos_bias: 294 | self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads, 295 | with_cls_token=with_cls_token) 296 | else: 297 | self.rel_pos_bias = None 298 | 299 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 300 | self.use_rel_pos_bias = use_rel_pos_bias 301 | self.blocks = nn.ModuleList([ 302 | Block( 303 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 304 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 305 | init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None, 306 | with_cls_token=with_cls_token, with_k_bias=with_k_bias) 307 | for i in range(depth)]) 308 | self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) 309 | self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None 310 | if num_classes > 0: 311 | self.head = nn.Linear(embed_dim, num_classes) 312 | else: 313 | self.head = nn.Identity() 314 | 315 | if self.pos_embed is not None: 316 | self._trunc_normal_(self.pos_embed, std=.02) 317 | if self.cls_token is not None: 318 | self._trunc_normal_(self.cls_token, std=.02) 319 | if num_classes > 0: 320 | self._trunc_normal_(self.head.weight, std=.02) 321 | self.apply(self._init_weights) 322 | self.fix_init_weight() 323 | 324 | if num_classes > 0: 325 | self.head.weight.data.mul_(init_scale) 326 | self.head.bias.data.mul_(init_scale) 327 | 328 | 329 | def _trunc_normal_(self, tensor, mean=0., std=1.): 330 | trunc_normal_(tensor, mean=mean, std=std) 331 | 332 | def fix_init_weight(self): 333 | def rescale(param, layer_id): 334 | param.div_(math.sqrt(2.0 * layer_id)) 335 | 336 | for layer_id, layer in enumerate(self.blocks): 337 | rescale(layer.attn.proj.weight.data, layer_id + 1) 338 | rescale(layer.mlp.fc2.weight.data, layer_id + 1) 339 | 340 | def _init_weights(self, m): 341 | if isinstance(m, nn.Linear): 342 | self._trunc_normal_(m.weight, std=.02) 343 | if isinstance(m, nn.Linear) and m.bias is not None: 344 | nn.init.constant_(m.bias, 0) 345 | elif isinstance(m, nn.LayerNorm): 346 | nn.init.constant_(m.bias, 0) 347 | nn.init.constant_(m.weight, 1.0) 348 | elif isinstance(m, nn.Conv2d): 349 | self._trunc_normal_(m.weight, std=.02) 350 | if m.bias is not None: 351 | nn.init.constant_(m.bias, 0) 352 | 353 | def get_num_layers(self): 354 | return len(self.blocks) 355 | 356 | @torch.jit.ignore 357 | def no_weight_decay(self): 358 | return {'pos_embed', 'cls_token'} 359 | 360 | def get_classifier(self): 361 | return self.head 362 | 363 | def reset_classifier(self, num_classes, global_pool=''): 364 | self.num_classes = num_classes 365 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 366 | 367 | def forward_features(self, x): 368 | x = self.patch_embed(x) 369 | batch_size, seq_len, _ = x.size() 370 | 371 | if self.cls_token is not None: 372 | cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, 
thanks 373 | x = torch.cat((cls_tokens, x), dim=1) 374 | if self.pos_embed is not None: 375 | x = x + self.pos_embed 376 | x = self.pos_drop(x) 377 | 378 | rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None 379 | for blk in self.blocks: 380 | if self.use_checkpoint: 381 | x = checkpoint.checkpoint(blk, x, rel_pos_bias) 382 | else: 383 | x = blk(x, rel_pos_bias=rel_pos_bias) 384 | 385 | x = self.norm(x) 386 | if self.fc_norm is not None: 387 | t = x[:, 1:, :] if self.cls_token is not None else x 388 | return self.fc_norm(t.mean(1)) 389 | else: 390 | if self.cls_token is not None: 391 | return x[:, 0] 392 | else: 393 | raise ValueError 394 | 395 | def forward(self, x): 396 | x = self.forward_features(x) 397 | x = self.head(x) 398 | return x 399 | 400 | 401 | def build_vit(config): 402 | model = VisionTransformer( 403 | img_size=config.DATA.IMG_SIZE, # 224 404 | patch_size=config.MODEL.VIT.PATCH_SIZE, # 16 405 | in_chans=config.MODEL.VIT.IN_CHANS, # 3 406 | num_classes=config.MODEL.NUM_CLASSES, # 0 407 | embed_dim=config.MODEL.VIT.EMBED_DIM, # 768 408 | depth=config.MODEL.VIT.DEPTH, # 12 409 | num_heads=config.MODEL.VIT.NUM_HEADS, # 12 410 | mlp_ratio=config.MODEL.VIT.MLP_RATIO, # 4. 411 | qkv_bias=config.MODEL.VIT.QKV_BIAS, # False 412 | with_k_bias=config.DEV.VIT_WITHKBIAS, # False 413 | drop_rate=config.MODEL.DROP_RATE, # 0.0 414 | drop_path_rate=config.MODEL.DROP_PATH_RATE, # can set to 0.1 415 | norm_layer=partial(nn.LayerNorm, eps=1e-6), 416 | init_values=config.MODEL.VIT.INIT_VALUES, 417 | use_abs_pos_emb=config.MODEL.VIT.USE_APE, # False 418 | use_rel_pos_bias=config.MODEL.VIT.USE_RPB, # True 419 | use_shared_rel_pos_bias=config.MODEL.VIT.USE_SHARED_RPB, # False 420 | with_cls_token=config.MODEL.VIT.WITH_CLS_TOKEN, # True 421 | use_checkpoint=config.TRAIN.USE_CHECKPOINT,) 422 | 423 | return model 424 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Copyright (c) 2021 Microsoft 3 | # Licensed under The MIT License [see LICENSE for details] 4 | # Written by Ze Liu 5 | # Modified by Zhenda Xie 6 | # Modified by Yixuan Wei 7 | # -------------------------------------------------------- 8 | 9 | import json 10 | from functools import partial 11 | from torch import optim as optim 12 | 13 | 14 | def build_optimizer(config, model, logger, is_pretrain): 15 | if is_pretrain: 16 | return build_pretrain_optimizer(config, model, logger) 17 | else: 18 | return build_finetune_optimizer(config, model, logger) 19 | 20 | 21 | def build_pretrain_optimizer(config, model, logger): 22 | logger.info('>>>>>>>>>> Build Optimizer for Pre-training Stage') 23 | skip = {} 24 | skip_keywords = {} 25 | if hasattr(model, 'no_weight_decay'): 26 | skip = model.no_weight_decay() 27 | logger.info(f'No weight decay: {skip}') 28 | if hasattr(model, 'no_weight_decay_keywords'): 29 | skip_keywords = model.no_weight_decay_keywords() 30 | logger.info(f'No weight decay keywords: {skip_keywords}') 31 | 32 | parameters = get_pretrain_param_groups(model, logger, skip, skip_keywords) 33 | 34 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 35 | optimizer = None 36 | if opt_lower == 'sgd': 37 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 38 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 39 | elif opt_lower == 'adamw': 40 | optimizer = 
optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 41 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 42 | 43 | logger.info(optimizer) 44 | return optimizer 45 | 46 | 47 | def get_pretrain_param_groups(model, logger, skip_list=(), skip_keywords=()): 48 | has_decay = [] 49 | no_decay = [] 50 | has_decay_name = [] 51 | no_decay_name = [] 52 | 53 | for name, param in model.named_parameters(): 54 | if not param.requires_grad: 55 | continue 56 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 57 | check_keywords_in_name(name, skip_keywords): 58 | no_decay.append(param) 59 | no_decay_name.append(name) 60 | else: 61 | has_decay.append(param) 62 | has_decay_name.append(name) 63 | logger.info(f'No decay params: {no_decay_name}') 64 | logger.info(f'Has decay params: {has_decay_name}') 65 | return [{'params': has_decay}, 66 | {'params': no_decay, 'weight_decay': 0.}] 67 | 68 | 69 | def build_finetune_optimizer(config, model, logger): 70 | logger.info('>>>>>>>>>> Build Optimizer for Fine-tuning Stage') 71 | if config.MODEL.TYPE in ['swin', 'swin_v2']: 72 | depths = config.MODEL.SWIN.DEPTHS 73 | num_layers = sum(depths) 74 | get_layer_func = partial(get_swin_layer, num_layers=num_layers + 2, depths=depths) 75 | elif config.MODEL.TYPE in ['vit', 'rsvit', 'vit_mae', 'vit_dino', 'vit_feat']: 76 | num_layers = config.MODEL.VIT.DEPTH 77 | get_layer_func = partial(get_vit_layer, num_layers=num_layers + 2) 78 | elif config.MODEL.TYPE in ['resnet']: 79 | layers = model.layers 80 | num_layers = sum(layers) 81 | get_layer_func = partial(get_resnet_layer, num_layers=num_layers + 2, layers=layers) 82 | else: 83 | raise NotImplementedError 84 | 85 | scales = list(config.TRAIN.LAYER_DECAY ** i for i in reversed(range(num_layers + 2))) 86 | 87 | skip = {} 88 | skip_keywords = {} 89 | if hasattr(model, 'no_weight_decay'): 90 | skip = model.no_weight_decay() 91 | logger.info(f'No weight decay: {skip}') 92 | if hasattr(model, 'no_weight_decay_keywords'): 93 | skip_keywords = model.no_weight_decay_keywords() 94 | logger.info(f'No weight decay keywords: {skip_keywords}') 95 | 96 | parameters = get_finetune_param_groups( 97 | model, logger, config.TRAIN.BASE_LR, config.TRAIN.WEIGHT_DECAY, 98 | get_layer_func, scales, skip, skip_keywords) 99 | 100 | opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() 101 | optimizer = None 102 | if opt_lower == 'sgd': 103 | optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, 104 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 105 | elif opt_lower == 'adamw': 106 | optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, 107 | lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) 108 | 109 | logger.info(optimizer) 110 | return optimizer 111 | 112 | 113 | def get_vit_layer(name, num_layers): 114 | if name in ("cls_token", "mask_token", "pos_embed"): 115 | return 0 116 | elif name.startswith("patch_embed"): 117 | return 0 118 | elif name.startswith("rel_pos_bias"): 119 | return num_layers - 1 120 | elif name.startswith("blocks"): 121 | layer_id = int(name.split('.')[1]) 122 | return layer_id + 1 123 | else: 124 | return num_layers - 1 125 | 126 | 127 | def get_swin_layer(name, num_layers, depths): 128 | if name in ("mask_token"): 129 | return 0 130 | elif name.startswith("patch_embed"): 131 | return 0 132 | elif name.startswith("layers"): 133 | layer_id = int(name.split('.')[1]) 134 | block_id 
= name.split('.')[3] 135 | if block_id == 'reduction' or block_id == 'norm': 136 | return sum(depths[:layer_id + 1]) 137 | layer_id = sum(depths[:layer_id]) + int(block_id) 138 | return layer_id + 1 139 | else: 140 | return num_layers - 1 141 | 142 | 143 | def get_resnet_layer(name, num_layers, layers): 144 | if name in ("mask_token"): 145 | return 0 146 | elif name.startswith('conv1') or name.startswith('bn1'): 147 | return 0 148 | elif name.startswith('layer'): 149 | layer_index = int(name.split('.')[0][-1]) 150 | block_index = int(name.split('.')[1]) + sum(layers[:layer_index - 1]) + 1 151 | return block_index 152 | else: 153 | return num_layers - 1 154 | 155 | 156 | def get_finetune_param_groups(model, logger, lr, weight_decay, get_layer_func, scales, skip_list=(), skip_keywords=()): 157 | parameter_group_names = {} 158 | parameter_group_vars = {} 159 | 160 | for name, param in model.named_parameters(): 161 | if not param.requires_grad: 162 | continue 163 | if len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ 164 | check_keywords_in_name(name, skip_keywords): 165 | group_name = "no_decay" 166 | this_weight_decay = 0. 167 | else: 168 | group_name = "decay" 169 | this_weight_decay = weight_decay 170 | if get_layer_func is not None: 171 | layer_id = get_layer_func(name) 172 | group_name = "layer_%d_%s" % (layer_id, group_name) 173 | else: 174 | layer_id = None 175 | 176 | if group_name not in parameter_group_names: 177 | if scales is not None: 178 | scale = scales[layer_id] 179 | else: 180 | scale = 1. 181 | 182 | parameter_group_names[group_name] = { 183 | "group_name": group_name, 184 | "weight_decay": this_weight_decay, 185 | "params": [], 186 | "lr": lr * scale, 187 | "lr_scale": scale, 188 | } 189 | parameter_group_vars[group_name] = { 190 | "group_name": group_name, 191 | "weight_decay": this_weight_decay, 192 | "params": [], 193 | "lr": lr * scale, 194 | "lr_scale": scale 195 | } 196 | 197 | parameter_group_vars[group_name]["params"].append(param) 198 | parameter_group_names[group_name]["params"].append(name) 199 | logger.info("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 200 | return list(parameter_group_vars.values()) 201 | 202 | 203 | def check_keywords_in_name(name, keywords=()): 204 | isin = False 205 | for keyword in keywords: 206 | if keyword in name: 207 | isin = True 208 | return isin -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyyaml 2 | scipy 3 | termcolor 4 | timm==0.4.12 5 | yacs 6 | wandb 7 | jsonlines 8 | einops==0.3.2 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Feature Distillation 3 | # Copyright (c) 2021 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ze Liu 6 | # Modified by Zhenda Xie 7 | # Modified by Yixuan Wei 8 | # -------------------------------------------------------- 9 | 10 | import os 11 | import torch 12 | import torch.distributed as dist 13 | import numpy as np 14 | import torchvision.transforms as T 15 | from scipy import interpolate 16 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 17 | 18 | mean = torch.tensor(IMAGENET_DEFAULT_MEAN) 19 | std = torch.tensor(IMAGENET_DEFAULT_STD) 20 | normalize = 
T.Normalize(mean=mean, std=std) 21 | unnormalize = T.Normalize(mean=-mean / std, std=1.0 / std) 22 | 23 | 24 | def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): 25 | logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") 26 | if config.MODEL.RESUME.startswith('https'): 27 | checkpoint = torch.hub.load_state_dict_from_url( 28 | config.MODEL.RESUME, map_location='cpu', check_hash=True) 29 | else: 30 | checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') 31 | msg = model.load_state_dict(checkpoint['model'], strict=False) 32 | logger.info(msg) 33 | 34 | if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'scaler' in checkpoint and 'epoch' in checkpoint: 35 | optimizer.load_state_dict(checkpoint['optimizer']) 36 | logger.info('Load Lr Scheduler') 37 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 38 | scaler.load_state_dict(checkpoint['scaler']) 39 | 40 | config.defrost() 41 | config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 42 | config.freeze() 43 | 44 | logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})") 45 | if 'max_accuracy' in checkpoint: 46 | max_accuracy = checkpoint['max_accuracy'] 47 | else: 48 | max_accuracy = 0.0 49 | 50 | del checkpoint 51 | torch.cuda.empty_cache() 52 | return max_accuracy 53 | 54 | 55 | def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, scaler, logger): 56 | save_state = {'model': model.state_dict(), 57 | 'optimizer': optimizer.state_dict(), 58 | 'lr_scheduler': lr_scheduler.state_dict(), 59 | 'scaler': scaler.state_dict(), 60 | 'max_accuracy': max_accuracy, 61 | 'epoch': epoch, 62 | 'config': config} 63 | 64 | save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') 65 | logger.info(f"{save_path} saving......") 66 | torch.save(save_state, save_path) 67 | logger.info(f"{save_path} saved !!!") 68 | 69 | 70 | def get_grad_norm(parameters, norm_type=2): 71 | if isinstance(parameters, torch.Tensor): 72 | parameters = [parameters] 73 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 74 | norm_type = float(norm_type) 75 | total_norm = [] 76 | for p in parameters: 77 | param_norm = p.grad.data.norm(norm_type) 78 | # total_norm += param_norm.item() ** norm_type 79 | total_norm.append(param_norm) 80 | # total_norm = total_norm ** (1. 
/ norm_type) 81 | total_norm = torch.stack(total_norm).norm(norm_type).item() 82 | return total_norm 83 | 84 | 85 | def auto_resume_helper(output_dir, logger): 86 | checkpoints = os.listdir(output_dir) 87 | checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] 88 | logger.info(f"All checkpoints found in {output_dir}: {checkpoints}") 89 | if len(checkpoints) > 0: 90 | latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime) 91 | logger.info(f"The latest checkpoint found: {latest_checkpoint}") 92 | resume_file = latest_checkpoint 93 | else: 94 | resume_file = None 95 | return resume_file 96 | 97 | 98 | def reduce_tensor(tensor): 99 | rt = tensor.clone() 100 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 101 | rt /= dist.get_world_size() 102 | return rt 103 | 104 | 105 | def load_pretrained(config, model, logger): 106 | logger.info(f">>>>>>>>>> Fine-tuned from {config.PRETRAINED} ..........") 107 | checkpoint = torch.load(config.PRETRAINED, map_location='cpu') 108 | checkpoint_model = checkpoint['model'] if 'model' in checkpoint else checkpoint 109 | checkpoint_model = checkpoint_model['student'] if 'student' in checkpoint_model else checkpoint_model # for esvit 110 | 111 | if any([True if 'encoder.' in k else False for k in checkpoint_model.keys()]): 112 | checkpoint_model = {k.replace('encoder.', ''): v for k, v in checkpoint_model.items() if k.startswith('encoder.')} 113 | logger.info('Detect pre-trained model, remove [encoder.] prefix.') 114 | elif any([True if 'module.' in k else False for k in checkpoint_model.keys()]): 115 | checkpoint_model = {k.replace('module.', ''): v for k, v in checkpoint_model.items() if k.startswith('module.')} 116 | logger.info('Detect pre-trained model, remove [module.]
prefix.') 117 | else: 118 | logger.info('Detect non-pre-trained model, pass without doing anything.') 119 | 120 | if config.DEV.FT_SKIP_REMAP: 121 | logger.info(f">>>>>>>>>> Skip remapping when loading pre-trained model") 122 | else: 123 | if config.MODEL.TYPE in ['swin', 'swin_v2']: 124 | logger.info(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") 125 | checkpoint_model = remap_pretrained_keys_swin(model, checkpoint_model, logger) 126 | elif config.MODEL.TYPE in ['vit']: 127 | logger.info(f">>>>>>>>>> Remapping pre-trained keys for VIT ..........") 128 | checkpoint_model = remap_pretrained_keys_vit(model, checkpoint_model, logger) 129 | else: 130 | raise NotImplementedError 131 | 132 | msg = model.load_state_dict(checkpoint_model, strict=False) 133 | logger.info(msg) 134 | 135 | del checkpoint 136 | del checkpoint_model 137 | torch.cuda.empty_cache() 138 | logger.info(f">>>>>>>>>> loaded successfully '{config.PRETRAINED}'") 139 | 140 | 141 | def remap_pretrained_keys_swin(model, checkpoint_model, logger): 142 | # Duplicate shared rel_pos_bias to each layer 143 | if "layers.0.rel_pos_bias.relative_coords_table" in checkpoint_model: 144 | # only support swinv2 145 | logger.info("Expand the shared relative position embedding to each transformer block.") 146 | for l in range(model.num_layers): 147 | # relative_coords_table = checkpoint_model.pop(f"layers.{l}.rel_pos_bias.relative_coords_table") 148 | # relative_position_index = checkpoint_model.pop(f"layers.{l}.rel_pos_bias.relative_position_index") 149 | mlp0weight = checkpoint_model.pop(f"layers.{l}.rel_pos_bias.rpe_mlp.0.weight") 150 | mlp0bias = checkpoint_model.pop(f"layers.{l}.rel_pos_bias.rpe_mlp.0.bias") 151 | mlp2bias = checkpoint_model.pop(f"layers.{l}.rel_pos_bias.rpe_mlp.2.weight") 152 | for i in range(model.depths[l]): 153 | # checkpoint_model[f"layers.{l}.blocks.{i}.attn.relative_coords_table"] = relative_coords_table.clone() 154 | # checkpoint_model[f"layers.{l}.blocks.{i}.attn.relative_position_index"] = relative_position_index.clone() 155 | checkpoint_model[f"layers.{l}.blocks.{i}.attn.rpe_mlp.0.weight"] = mlp0weight.clone() 156 | checkpoint_model[f"layers.{l}.blocks.{i}.attn.rpe_mlp.0.bias"] = mlp0bias.clone() 157 | checkpoint_model[f"layers.{l}.blocks.{i}.attn.rpe_mlp.2.weight"] = mlp2bias.clone() 158 | 159 | state_dict = model.state_dict() 160 | 161 | # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size 162 | all_keys = list(checkpoint_model.keys()) 163 | for key in all_keys: 164 | if "relative_position_bias_table" in key: 165 | relative_position_bias_table_pretrained = checkpoint_model[key] 166 | relative_position_bias_table_current = state_dict[key] 167 | L1, nH1 = relative_position_bias_table_pretrained.size() 168 | L2, nH2 = relative_position_bias_table_current.size() 169 | if nH1 != nH2: 170 | logger.info(f"Error in loading {key}, passing......") 171 | else: 172 | if L1 != L2: 173 | logger.info(f"{key}: Interpolate relative_position_bias_table using geo.") 174 | src_size = int(L1 ** 0.5) 175 | dst_size = int(L2 ** 0.5) 176 | 177 | def geometric_progression(a, r, n): 178 | return a * (1.0 - r ** n) / (1.0 - r) 179 | 180 | left, right = 1.01, 1.5 181 | while right - left > 1e-6: 182 | q = (left + right) / 2.0 183 | gp = geometric_progression(1, q, src_size // 2) 184 | if gp > dst_size // 2: 185 | right = q 186 | else: 187 | left = q 188 | 189 | # if q > 1.090307: 190 | # q = 1.090307 191 | 192 | dis = [] 193 | cur = 1 194 | for i in range(src_size // 2): 195 | 
dis.append(cur) 196 | cur += q ** (i + 1) 197 | 198 | r_ids = [-_ for _ in reversed(dis)] 199 | 200 | x = r_ids + [0] + dis 201 | y = r_ids + [0] + dis 202 | 203 | t = dst_size // 2.0 204 | dx = np.arange(-t, t + 0.1, 1.0) 205 | dy = np.arange(-t, t + 0.1, 1.0) 206 | 207 | logger.info("Original positions = %s" % str(x)) 208 | logger.info("Target positions = %s" % str(dx)) 209 | 210 | all_rel_pos_bias = [] 211 | 212 | for i in range(nH1): 213 | z = relative_position_bias_table_pretrained[:, i].view(src_size, src_size).float().numpy() 214 | f_cubic = interpolate.interp2d(x, y, z, kind='cubic') 215 | all_rel_pos_bias.append(torch.Tensor(f_cubic(dx, dy)).contiguous().view(-1, 1).to( 216 | relative_position_bias_table_pretrained.device)) 217 | 218 | new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 219 | checkpoint_model[key] = new_rel_pos_bias 220 | 221 | # delete relative_position_index since we always re-init it 222 | relative_position_index_keys = [k for k in checkpoint_model.keys() if "relative_position_index" in k] 223 | for k in relative_position_index_keys: 224 | del checkpoint_model[k] 225 | 226 | # delete relative_coords_table since we always re-init it 227 | relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k] 228 | for k in relative_coords_table_keys: 229 | del checkpoint_model[k] 230 | 231 | # delete attn_mask since we always re-init it 232 | attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] 233 | for k in attn_mask_keys: 234 | del checkpoint_model[k] 235 | 236 | return checkpoint_model 237 | 238 | 239 | def remap_pretrained_keys_vit(model, checkpoint_model, logger, rpe_method=None): 240 | # Duplicate shared rel_pos_bias to each layer 241 | if getattr(model, 'use_rel_pos_bias', False) and "rel_pos_bias.relative_position_bias_table" in checkpoint_model: 242 | logger.info("Expand the shared relative position embedding to each transformer block.") 243 | num_layers = model.get_num_layers() 244 | if "rel_pos_bias.relative_position_bias_table" in checkpoint_model: 245 | rel_pos_bias = checkpoint_model["rel_pos_bias.relative_position_bias_table"] 246 | for i in range(num_layers): 247 | checkpoint_model["blocks.%d.attn.relative_position_bias_table" % i] = rel_pos_bias.clone() 248 | checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") 249 | 250 | # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size 251 | all_keys = list(checkpoint_model.keys()) 252 | for key in all_keys: 253 | if "relative_position_index" in key: 254 | checkpoint_model.pop(key) 255 | 256 | if "relative_position_bias_table" in key: 257 | rel_pos_bias = checkpoint_model[key] 258 | src_num_pos, num_attn_heads = rel_pos_bias.size() 259 | if key not in model.state_dict(): 260 | # case for additional encoder block 261 | continue 262 | dst_num_pos, _ = model.state_dict()[key].size() 263 | dst_patch_shape = model.patch_embed.patch_shape 264 | if dst_patch_shape[0] != dst_patch_shape[1]: 265 | raise NotImplementedError() 266 | num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) 267 | src_size = int((src_num_pos - num_extra_tokens) ** 0.5) 268 | dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) 269 | if src_size != dst_size: 270 | logger.info("Position interpolate for %s from %dx%d to %dx%d" % (key, src_size, src_size, dst_size, dst_size)) 271 | if rpe_method == 'outer_mask': 272 | pad_size = (dst_size - src_size) // 2 273 | padding = (pad_size, pad_size, pad_size, 
pad_size) 274 | 275 | all_rel_pos_bias = [] 276 | for i in range(num_attn_heads): 277 | z = rel_pos_bias[:, i].view(src_size, src_size) 278 | all_rel_pos_bias.append( 279 | torch.nn.functional.pad(z, padding, "constant", z.min().item() - 3).view(dst_num_pos, 1)) 280 | new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 281 | 282 | checkpoint_model[key] = new_rel_pos_bias 283 | else: 284 | if num_extra_tokens > 0: 285 | extra_tokens = rel_pos_bias[-num_extra_tokens:, :] 286 | rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] 287 | else: 288 | extra_tokens = rel_pos_bias.new_zeros((0, num_attn_heads)) 289 | 290 | def geometric_progression(a, r, n): 291 | return a * (1.0 - r ** n) / (1.0 - r) 292 | 293 | left, right = 1.01, 1.5 294 | while right - left > 1e-6: 295 | q = (left + right) / 2.0 296 | gp = geometric_progression(1, q, src_size // 2) 297 | if gp > dst_size // 2: 298 | right = q 299 | else: 300 | left = q 301 | 302 | # if q > 1.090307: 303 | # q = 1.090307 304 | 305 | dis = [] 306 | cur = 1 307 | for i in range(src_size // 2): 308 | dis.append(cur) 309 | cur += q ** (i + 1) 310 | 311 | r_ids = [-_ for _ in reversed(dis)] 312 | 313 | x = r_ids + [0] + dis 314 | y = r_ids + [0] + dis 315 | 316 | t = dst_size // 2.0 317 | dx = np.arange(-t, t + 0.1, 1.0) 318 | dy = np.arange(-t, t + 0.1, 1.0) 319 | 320 | logger.info("Original positions = %s" % str(x)) 321 | logger.info("Target positions = %s" % str(dx)) 322 | 323 | all_rel_pos_bias = [] 324 | 325 | for i in range(num_attn_heads): 326 | z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy() 327 | f = interpolate.interp2d(x, y, z, kind='cubic') 328 | all_rel_pos_bias.append( 329 | torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device)) 330 | 331 | rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) 332 | 333 | new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0) 334 | checkpoint_model[key] = new_rel_pos_bias 335 | 336 | if 'pos_embed' in all_keys and getattr(model, 'pos_embed', None) is not None: 337 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 338 | embedding_size = pos_embed_checkpoint.shape[-1] 339 | num_patches = model.patch_embed.num_patches 340 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 341 | # height (== width) for the checkpoint position embedding 342 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 343 | # height (== width) for the new position embedding 344 | new_size = int(num_patches ** 0.5) 345 | # class_token and dist_token are kept unchanged 346 | if orig_size != new_size: 347 | logger.info("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 348 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 349 | # only the position tokens are interpolated 350 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 351 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 352 | pos_tokens = torch.nn.functional.interpolate( 353 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 354 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 355 | if num_extra_tokens > 0: 356 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 357 | else: 358 | new_pos_embed = pos_tokens 359 | checkpoint_model['pos_embed'] = new_pos_embed 360 | 361 | # delete relative_coords_table since we always re-init it 362 | relative_coords_table_keys = [k for k in checkpoint_model.keys() if "relative_coords_table" in k] 363 | for k in 
relative_coords_table_keys: 364 | del checkpoint_model[k] 365 | 366 | return checkpoint_model 367 | 368 | 369 | # Copyright (c) Meta Platforms, Inc. and affiliates. 370 | # All rights reserved. 371 | 372 | # This source code is licensed under the license found in the 373 | # LICENSE file in the root directory of this source tree. 374 | # -------------------------------------------------------- 375 | # LARS optimizer, implementation from MoCo v3: 376 | # https://github.com/facebookresearch/moco-v3 377 | # -------------------------------------------------------- 378 | 379 | 380 | class LARS(torch.optim.Optimizer): 381 | """ 382 | LARS optimizer, no rate scaling or weight decay for parameters <= 1D. 383 | """ 384 | def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001): 385 | defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, trust_coefficient=trust_coefficient) 386 | super().__init__(params, defaults) 387 | 388 | @torch.no_grad() 389 | def step(self): 390 | for g in self.param_groups: 391 | for p in g['params']: 392 | dp = p.grad 393 | 394 | if dp is None: 395 | continue 396 | 397 | if p.ndim > 1: # if not normalization gamma/beta or bias 398 | dp = dp.add(p, alpha=g['weight_decay']) 399 | param_norm = torch.norm(p) 400 | update_norm = torch.norm(dp) 401 | one = torch.ones_like(param_norm) 402 | q = torch.where(param_norm > 0., 403 | torch.where(update_norm > 0, 404 | (g['trust_coefficient'] * param_norm / update_norm), one), 405 | one) 406 | dp = dp.mul(q) 407 | 408 | param_state = self.state[p] 409 | if 'mu' not in param_state: 410 | param_state['mu'] = torch.zeros_like(p) 411 | mu = param_state['mu'] 412 | mu.mul_(g['momentum']).add_(dp) 413 | p.add_(mu, alpha=-g['lr']) --------------------------------------------------------------------------------
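Since the dumped files end with the standalone LARS class in utils.py, a brief usage sketch follows. This is an illustrative example rather than code from the repo: it assumes the root utils.py is importable from the working directory and uses a toy linear-probe head with random stand-in tensors; the project's actual linear-probing entry point is main_linear.py.

```python
# Hedged sketch (not from the repo): exercising the LARS optimizer defined in utils.py
# on a toy linear probe with random stand-in features.
import torch
import torch.nn as nn

from utils import LARS  # the class at the end of utils.py above

head = nn.Linear(768, 1000)  # e.g. a linear probe on ViT-Base features
optimizer = LARS(head.parameters(), lr=1.6, weight_decay=0.0, momentum=0.9)

features = torch.randn(32, 768)          # stand-in for frozen backbone outputs
labels = torch.randint(0, 1000, (32,))   # random labels, illustration only

loss = nn.functional.cross_entropy(head(features), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()  # trust-ratio scaling and weight decay apply only to tensors with ndim > 1
```

As in the MoCo v3 implementation it is adapted from, LARS.step() leaves biases and other 1-D parameters on plain momentum SGD without weight decay, mirroring the no-decay grouping used in optimizer.py's get_pretrain_param_groups.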