├── .gitignore ├── README.md ├── ada_main.py ├── adavit_ckpt ├── deit-s-h-l-tmlp_flops_dict.pth └── t2t-19-h-l-tmlp_flops_dict.pth ├── assets └── adavit_approach.png ├── block_flops_dict.py ├── models ├── __init__.py ├── ada_t2t_vit.py ├── ada_transformer_block.py ├── deit.py ├── losses.py ├── token_performer.py ├── token_transformer.py └── transformer_block.py ├── saver.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | test_case.py 2 | summary.csv 3 | output/ 4 | output_ada/ 5 | data/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaViT: Adaptive Vision Transformers for Efficient Image Recognition 2 | 3 | Lingchen Meng*1, Hengduo Li*2, Bor-Chun Chen3, Shiyi Lan2, Zuxuan Wu1, Yu-Gang Jiang1, Ser-Nam Lim3
4 | 1Shanghai Key Lab of Intelligent Information Processing, School of Computer Science, Fudan University, 2University of Maryland, 3Meta AI
5 | \* Equal contribution 6 | 7 | 8 | This repository is an official implementation of the [AdaViT](https://arxiv.org/abs/2111.15668). 9 | Our codes are based on the [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) and [T2T-ViT](https://github.com/yitu-opensource/T2T-ViT) 10 | 11 | ## Abstract 12 | Built on top of self-attention mechanisms, vision transformers have demonstrated remarkable performance on a variety of tasks recently. While achieving excellent performance, they still require relatively intensive computational cost that scales up drastically as the numbers of patches, self-attention heads and transformer blocks increase. In this paper, we argue that due to the large variations among images, their 13 | need for modeling long-range dependencies between patches differ. To this end, we introduce AdaViT, an adaptive computation framework that learns to derive usage policies on which patches, self-attention heads and transformer blocks to use throughout the backbone on a per-input basis, aiming to improve inference efficiency of vision transformers with a minimal drop of accuracy for image recognition. Optimized jointly with a transformer backbone in an end-to-end manner, a light-weight decision network is attached to the backbone to produce decisions on-the-fly. Extensive experiments on ImageNet demonstrate that our method obtains more than $2\times$ improvement on efficiency compared to state-of-the-art vision transformers with only $0.8\%$ drop of accuracy, achieving good efficiency/accuracy trade-offs conditioned on different computational budgets. We further conduct quantitative and qualitative analysis on learned usage polices and provide more insights on the redundancy in vision transformers. 14 | 15 | 16 | 17 | ## Model Zoo 18 | We have put our model checkpoints here. 19 | | Model | Top1 Acc | MACs | Download| 20 | | :--- | :---: | :---: | :---: | 21 | | Ada-T2T-ViT-19 | 81.1 | 3.9G | [link](https://drive.google.com/file/d/1bSh9E2HDM66L5FAbrTqh6slSWaYPeKNR/view?usp=sharing)| 22 | | Ada-DeiT-S | 77.3 | 2.3G | [link](https://drive.google.com/file/d/1vkD6w9J8sf64IPhTBgyfvsTvUlZw6TNa/view?usp=sharing)| 23 | 24 | 25 | ## Eval our model 26 | Download our AdaViT with T2T-ViT-19 from [google drive](https://drive.google.com/file/d/1bSh9E2HDM66L5FAbrTqh6slSWaYPeKNR/view?usp=sharing) and perform the command below. You can expect to get the Acc about 81.1 with 3.9 GFLOPS. 
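The evaluation entry point is `ada_main.py`, which expects the ImageNet root passed as `/path/to/your/imagenet` to contain `train/` and `val/` (or `validation/`) sub-folders. If you want to inspect a checkpoint outside of the provided script, the snippet below is a minimal, illustrative sketch (not part of the released code) of how `ada_main.py` constructs the adaptive model and unpacks its outputs; the `ada_*` keyword arguments mirror the `--ada-*` command-line flags and the checkpoint path is a placeholder. The official evaluation commands follow right after this sketch.

```python
import torch
from timm.models import create_model

import models  # registers the ada_step_* architectures with timm
from utils import ada_load_state_dict

# Build Ada-T2T-ViT-19 with head/layer/token policies enabled (sketch only;
# the keyword arguments are assumed to match those passed in ada_main.py).
model = create_model(
    'ada_step_t2t_vit_19_lnorm',
    num_classes=1000,
    img_size=224,
    ada_head=True,
    ada_layer=True,
    ada_token=True,
    ada_token_with_mlp=True,
)
ada_load_state_dict('/path/to/your/checkpoint', model, use_qkv=False, strict=False)
model.eval()

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))

# As in validate() of ada_main.py, adaptive models return the logits together
# with the head / layer / token usage policies.
if isinstance(out, (tuple, list)):
    logits, head_select, layer_select, token_select = out[:4]
else:
    logits = out
print(logits.shape)
```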
27 | 28 | ```sh 29 | python3 ada_main.py /path/to/your/imagenet \ 30 | --model ada_step_t2t_vit_19_lnorm \ 31 | --ada-head --ada-layer --ada-token-with-mlp \ 32 | --flops-dict adavit_ckpt/t2t-19-h-l-tmlp_flops_dict.pth \ 33 | --eval_checkpoint /path/to/your/checkpoint 34 | 35 | python3 ada_main.py /path/to/your/imagenet \ 36 | --model ada_step_deit_small_patch16_224 \ 37 | --ada-head --ada-layer --ada-token-with-mlp \ 38 | --flops-dict adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth \ 39 | --eval_checkpoint /path/to/your/checkpoint 40 | ``` 41 | 42 | ## Citation 43 | ``` 44 | @inproceedings{meng2022adavit, 45 | title={AdaViT: Adaptive Vision Transformers for Efficient Image Recognition}, 46 | author={Meng, Lingchen and Li, Hengduo and Chen, Bor-Chun and Lan, Shiyi and Wu, Zuxuan and Jiang, Yu-Gang and Lim, Ser-Nam}, 47 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 48 | pages={12309--12318}, 49 | year={2022} 50 | } 51 | ``` 52 | -------------------------------------------------------------------------------- /ada_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import yaml 4 | import os 5 | import logging 6 | from collections import OrderedDict 7 | from contextlib import suppress 8 | import datetime 9 | from time import gmtime, strftime 10 | import models 11 | from utils import ada_load_state_dict 12 | from models.losses import AdaHeadLoss, AdaLoss, TeacherLoss 13 | 14 | import torch 15 | import torch.nn as nn 16 | import torchvision.utils 17 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 18 | 19 | from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 20 | from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model 21 | from timm.utils import * 22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy 23 | from timm.optim import create_optimizer 24 | from timm.scheduler import create_scheduler 25 | from timm.utils import ApexScaler, NativeScaler 26 | 27 | import pdb 28 | from saver import MyCheckpointSaver 29 | 30 | torch.backends.cudnn.benchmark = True 31 | _logger = logging.getLogger('train') 32 | 33 | # The first arg parser parses out only the --config argument, this argument is used to 34 | # load a yaml file containing key-values that override the defaults for the main parser below 35 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 36 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 37 | help='YAML config file specifying default arguments') 38 | 39 | parser = argparse.ArgumentParser(description='T2T-ViT Training and Evaluating') 40 | 41 | parser.add_argument('--head-ratio', type=float, default=2.0, 42 | help='') 43 | parser.add_argument('--layer-ratio', type=float, default=2.0, 44 | help='') 45 | parser.add_argument('--attn-ratio', type=float, default=0., 46 | help='') 47 | parser.add_argument('--hidden-ratio', type=float, default=0., 48 | help='') 49 | parser.add_argument('--pred-ratio', type=float, default=0., 50 | help='') 51 | parser.add_argument('--head-target-ratio', type=float, default=0.5, 52 | help='') 53 | parser.add_argument('--layer-target-ratio', type=float, default=0.5, 54 | help='') 55 | parser.add_argument('--head-diverse-ratio', type=float, default=0., 56 | help='') 57 | parser.add_argument('--layer-diverse-ratio', 
type=float, default=0., 58 | help='') 59 | parser.add_argument('--head-select-tau', type=float, default=5., 60 | help='') 61 | parser.add_argument('--layer-select-tau', type=float, default=5., 62 | help='') 63 | parser.add_argument('--token-select-tau', type=float, default=5., 64 | help='') 65 | parser.add_argument('--head-entropy-weight', type=float, default=0., 66 | help='') 67 | parser.add_argument('--layer-entropy-weight', type=float, default=0., 68 | help='') 69 | parser.add_argument('--head-minimal-weight', type=float, default=0., 70 | help='') 71 | parser.add_argument('--head-minimal', type=float, default=0., 72 | help='') 73 | parser.add_argument('--layer-minimal-weight', type=float, default=0., 74 | help='') 75 | parser.add_argument('--layer-minimal', type=float, default=0., 76 | help='') 77 | parser.add_argument('--token-ratio', type=float, default=2., 78 | help='') 79 | parser.add_argument('--token-target-ratio', type=float, default=0.5, 80 | help='') 81 | parser.add_argument('--token-minimal', type=float, default=0., 82 | help='') 83 | parser.add_argument('--token-minimal-weight', type=float, default=0., 84 | help='') 85 | 86 | parser.add_argument('--inner-loop', type=int, default=-1, 87 | help='') 88 | 89 | # Dataset / Model parameters 90 | parser.add_argument('data', metavar='DIR', 91 | help='path to dataset') 92 | parser.add_argument('--norm-policy', action='store_true', default=False, dest='norm_policy', 93 | help='') 94 | parser.add_argument('--keep-layers', type=int, default=1, 95 | help='use layers to make selection decision') 96 | parser.add_argument('--ada-layer', action='store_true', default=False, dest='ada_layer', 97 | help='') 98 | parser.add_argument('--ada-block', action='store_true', default=False, dest='ada_block', 99 | help='') 100 | parser.add_argument('--ada-head', action='store_true', default=False, dest='ada_head', 101 | help='') 102 | parser.add_argument('--ada-head-v2', action='store_true', default=False, dest='ada_head_v2', 103 | help='') 104 | parser.add_argument('--dyna-data', action='store_true', default=False, dest='dyna_data', 105 | help='') 106 | parser.add_argument('--ada-head-attn', action='store_true', default=False, dest='ada_head_attn', 107 | help='') 108 | parser.add_argument('--head-slowfast', action='store_true', default=False, dest='head_slowfast', 109 | help='') 110 | 111 | parser.add_argument('--flops-dict', type=str, default='', dest='flops_dict', 112 | help='') 113 | parser.add_argument('--ada-token', action='store_true', default=False, dest='ada_token', 114 | help='ada-token on self-attn') 115 | parser.add_argument('--ada-token-nonstep', action='store_true', default=False, dest='ada_token_nonstep', 116 | help='using nonstep option for token selection, i.e. 
generate policies for all layers at once at the beginning') 117 | parser.add_argument('--ada-token-with-mlp', action='store_true', default=False, dest='ada_token_with_mlp', 118 | help='ada-token on both self-attn and ffn') 119 | parser.add_argument('--ada-token-start-layer', type=int, default=0) 120 | parser.add_argument('--ada-token-pre-softmax', action='store_true', default=True, dest='ada_token_pre_softmax') 121 | parser.add_argument('--ada-token-post-softmax', action='store_false', default=True, dest='ada_token_pre_softmax') 122 | parser.add_argument('--ada-token-detach-attn', action='store_true', default=True, dest='ada_token_detach_attn') 123 | parser.add_argument('--ada-token-no-detach-attn', action='store_false', default=True, dest='ada_token_detach_attn') 124 | parser.add_argument('--ada-token-detach-attn-at-mlp', action='store_true', default=False, dest='ada_token_detach_attn_at_mlp', 125 | help='whether detaching attn_policy at MLP, when using dynamic token selection (which overrides --ada-token-detach-attn at MLP)') 126 | 127 | parser.add_argument('--no-head-select-bias', action='store_false', default=True, dest='head_select_bias', 128 | help='') 129 | parser.add_argument('--no-layer-select-bias', action='store_false', default=True, dest='layer_select_bias', 130 | help='') 131 | parser.add_argument('--model', default='T2t_vit_14', type=str, metavar='MODEL', 132 | help='Name of model to train (default: "countception"') 133 | parser.add_argument('--freeze-bn', action='store_true', default=False, 134 | help='freeze bn in training') 135 | parser.add_argument('--pretrained', action='store_true', default=False, 136 | help='Start with pretrained version of specified network (if avail)') 137 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 138 | help='Initialize model from this checkpoint (default: none)') 139 | parser.add_argument('--pretrain-path', default='', type=str, 140 | help='Load pretrain file path') 141 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 142 | help='Resume full model and optimizer state from checkpoint (default: none)') 143 | parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH', 144 | help='path to eval checkpoint (default: none)') 145 | parser.add_argument('--use-full-head', action='store_true', default=False, 146 | help='use full model param') 147 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 148 | help='prevent resume of optimizer state when resuming model') 149 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N', 150 | help='number of label classes (default: 1000)') 151 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 152 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). 
Model default if None.') 153 | parser.add_argument('--img-size', type=int, default=224, metavar='N', 154 | help='Image patch size (default: None => model default)') 155 | parser.add_argument('--crop-pct', default=None, type=float, 156 | metavar='N', help='Input image center crop percent (for validation only)') 157 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 158 | help='Override mean pixel value of dataset') 159 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 160 | help='Override std deviation of of dataset') 161 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 162 | help='Image resize interpolation type (overrides model)') 163 | parser.add_argument('-b', '--batch-size', type=int, default=64, metavar='N', 164 | help='input batch size for training (default: 64)') 165 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N', 166 | help='ratio of validation batch size to training batch size (default: 1)') 167 | 168 | # Optimizer parameters 169 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER', 170 | help='Optimizer (default: "adamw"') 171 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 172 | help='Optimizer Epsilon (default: None, use opt default)') 173 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 174 | help='Optimizer Betas (default: None, use opt default)') 175 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 176 | help='Optimizer momentum (default: 0.9)') 177 | parser.add_argument('--weight-decay', type=float, default=0.05, 178 | help='weight decay (default: 0.005 for adamw)') 179 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 180 | help='Clip gradient norm (default: None, no clipping)') 181 | 182 | # Learning rate schedule parameters 183 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 184 | help='LR scheduler (default: "cosine"') 185 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR', 186 | help='learning rate (default: 0.01)') 187 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 188 | help='learning rate noise on/off epoch percentages') 189 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 190 | help='learning rate noise limit percent (default: 0.67)') 191 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 192 | help='learning rate noise std-dev (default: 1.0)') 193 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 194 | help='learning rate cycle len multiplier (default: 1.0)') 195 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 196 | help='learning rate cycle limit') 197 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 198 | help='warmup learning rate (default: 0.0001)') 199 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR', 200 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 201 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 202 | help='number of epochs to train (default: 2)') 203 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 204 | help='manual epoch number (useful on restarts)') 205 | parser.add_argument('--decay-epochs', type=float, 
default=30, metavar='N', 206 | help='epoch interval to decay LR') 207 | parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N', 208 | help='epochs to warmup LR, if scheduler supports') 209 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 210 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 211 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 212 | help='patience epochs for Plateau LR scheduler (default: 10') 213 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 214 | help='LR decay rate (default: 0.1)') 215 | parser.add_argument('--ada-lr-scaling', action='store_true', default=False, 216 | help='rescale the lr of ada subnetworks') 217 | parser.add_argument('--ada-token-lr-scale', type=float, default=1.0, help='') 218 | parser.add_argument('--ada-layer-lr-scale', type=float, default=1.0, help='') 219 | parser.add_argument('--ada-head-lr-scale', type=float, default=1.0, help='') 220 | 221 | # Augmentation & regularization parameters 222 | parser.add_argument('--no-aug', action='store_true', default=False, 223 | help='Disable all training augmentation, override other train aug args') 224 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 225 | help='Random resize scale (default: 0.08 1.0)') 226 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', 227 | help='Random resize aspect ratio (default: 0.75 1.33)') 228 | parser.add_argument('--hflip', type=float, default=0.5, 229 | help='Horizontal flip training aug probability') 230 | parser.add_argument('--vflip', type=float, default=0., 231 | help='Vertical flip training aug probability') 232 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 233 | help='Color jitter factor (default: 0.4)') 234 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME', 235 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 236 | parser.add_argument('--aug-splits', type=int, default=0, 237 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 238 | parser.add_argument('--jsd', action='store_true', default=False, 239 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 240 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT', 241 | help='Random erase prob (default: 0.25)') 242 | parser.add_argument('--remode', type=str, default='pixel', 243 | help='Random erase mode (default: "const")') 244 | parser.add_argument('--recount', type=int, default=1, 245 | help='Random erase count (default: 1)') 246 | parser.add_argument('--resplit', action='store_true', default=False, 247 | help='Do not random erase first (clean) augmentation split') 248 | parser.add_argument('--mixup', type=float, default=0.8, 249 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 250 | parser.add_argument('--cutmix', type=float, default=1.0, 251 | help='cutmix alpha, cutmix enabled if > 0. 
(default: 0.)') 252 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 253 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 254 | parser.add_argument('--mixup-prob', type=float, default=1.0, 255 | help='Probability of performing mixup or cutmix when either/both is enabled') 256 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 257 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 258 | parser.add_argument('--mixup-mode', type=str, default='batch', 259 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') 260 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 261 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 262 | parser.add_argument('--smoothing', type=float, default=0.1, 263 | help='Label smoothing (default: 0.1)') 264 | parser.add_argument('--train-interpolation', type=str, default='random', 265 | help='Training interpolation (random, bilinear, bicubic default: "random")') 266 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 267 | help='Dropout rate (default: 0.0)') 268 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 269 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 270 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT', 271 | help='Drop path rate (default: None)') 272 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 273 | help='Drop block rate (default: None)') 274 | 275 | # Batch norm parameters (only works with gen_efficientnet based models currently) 276 | parser.add_argument('--bn-tf', action='store_true', default=False, 277 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') 278 | parser.add_argument('--bn-momentum', type=float, default=None, 279 | help='BatchNorm momentum override (if not None)') 280 | parser.add_argument('--bn-eps', type=float, default=None, 281 | help='BatchNorm epsilon override (if not None)') 282 | parser.add_argument('--sync-bn', action='store_true', 283 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 284 | parser.add_argument('--dist-bn', type=str, default='', 285 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 286 | parser.add_argument('--split-bn', action='store_true', 287 | help='Enable separate BN layers per augmentation split.') 288 | 289 | # Model Exponential Moving Average 290 | parser.add_argument('--model-ema', action='store_true', default=True, 291 | help='Enable tracking moving average of model weights') 292 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema', 293 | help='Enable tracking moving average of model weights') 294 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 295 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 296 | parser.add_argument('--model-ema-decay', type=float, default=0.99996, 297 | help='decay factor for model weights moving average (default: 0.9998)') 298 | 299 | # Misc 300 | parser.add_argument('--seed', type=int, default=42, metavar='S', 301 | help='random seed (default: 42)') 302 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', 303 | help='how many batches to wait before logging training status') 304 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 305 | help='how many batches to wait before writing recovery checkpoint') 306 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N', 307 | help='how many training processes to use (default: 1)') 308 | parser.add_argument('--num-gpu', type=int, default=1, 309 | help='Number of GPUS to use') 310 | parser.add_argument('--save-images', action='store_true', default=False, 311 | help='save images of input bathes every log interval for debugging') 312 | parser.add_argument('--amp', action='store_true', default=False, 313 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 314 | parser.add_argument('--apex-amp', action='store_true', default=False, 315 | help='Use NVIDIA Apex AMP mixed precision') 316 | parser.add_argument('--native-amp', action='store_true', default=False, 317 | help='Use Native Torch AMP mixed precision') 318 | parser.add_argument('--channels-last', action='store_true', default=False, 319 | help='Use channels_last memory layout') 320 | parser.add_argument('--pin-mem', action='store_true', default=False, 321 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 322 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 323 | help='disable fast prefetcher') 324 | parser.add_argument('--output', default='', type=str, metavar='PATH', 325 | help='path to output folder (default: none, current dir)') 326 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 327 | help='Best metric (default: "top1"') 328 | parser.add_argument('--tta', type=int, default=0, metavar='N', 329 | help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') 330 | parser.add_argument("--local_rank", default=0, type=int) 331 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 332 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 333 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 334 | 335 | # Random Baseline settings 336 | parser.add_argument('--random-policy', action='store_true', default=False, dest='random_policy') 337 | parser.add_argument('--random-head', action='store_true', default=False, dest='random_head') 338 | parser.add_argument('--random-layer', action='store_true', default=False, dest='random_layer') 339 | parser.add_argument('--random-token', action='store_true', default=False, dest='random_token') 340 | parser.add_argument('--eval_random_baseline', action='store_true', default=False, 341 | help='if True, evaluate random baselines given certain computational budget') 342 | parser.add_argument('--eval_random_layer', action='store_true', default=False, 343 | help='if True, test randomly generated policies for layer selection') 344 | parser.add_argument('--eval_random_head', action='store_true', default=False, 345 | help='if True, test randomly generated policies for head selection') 346 | parser.add_argument('--eval_random_token', action='store_true', default=False, 347 | help='if True, test randomly generated policies for token selection') 348 | parser.add_argument('--eval_random_layer_ratio', type=float, default=1.0, 349 | help='ratio of kept layers in random policies') 350 | parser.add_argument('--eval_random_head_ratio', type=float, default=1.0, 351 | help='ratio of kept layers in random policies') 352 | parser.add_argument('--eval_random_token_ratio', type=float, default=1.0, 353 | help='ratio of kept layers in random policies') 354 | parser.add_argument('--dev', action='store_true', default=False, 355 | help='skip some time-consuming steps when developing, e.g. loading full training set') 356 | parser.add_argument('--print-head-option', action='store_true') 357 | try: 358 | from apex import amp 359 | from apex.parallel import DistributedDataParallel as ApexDDP 360 | from apex.parallel import convert_syncbn_model 361 | 362 | has_apex = True 363 | except ImportError: 364 | has_apex = False 365 | 366 | has_native_amp = False 367 | try: 368 | if getattr(torch.cuda.amp, 'autocast') is not None: 369 | has_native_amp = True 370 | except AttributeError: 371 | pass 372 | 373 | def _parse_args(): 374 | # Do we have a config file to parse? 375 | args_config, remaining = config_parser.parse_known_args() 376 | if args_config.config: 377 | with open(args_config.config, 'r') as f: 378 | cfg = yaml.safe_load(f) 379 | parser.set_defaults(**cfg) 380 | 381 | # The main arg parser parses the rest of the args, the usual 382 | # defaults will have been overridden if config file specified. 
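# Illustrative example (not shipped with the repo): the YAML file passed via
# `-c/--config` is read with yaml.safe_load and fed to parser.set_defaults, so its
# keys are simply argparse dest names. A hypothetical config could look like:
#
#   model: ada_step_deit_small_patch16_224
#   ada_head: true
#   ada_layer: true
#   ada_token_with_mlp: true
#   flops_dict: adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth
#   batch_size: 128
#
# Options given explicitly on the command line still take precedence, because
# parse_args runs after set_defaults.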
383 | args = parser.parse_args(remaining) 384 | if args.ada_block : 385 | args.ada_layer = True 386 | if args.ada_head_attn : 387 | args.ada_head = True 388 | if args.ada_token_with_mlp : 389 | args.ada_token = True 390 | if args.ada_token_nonstep: 391 | args.ada_token = True 392 | if args.ada_head_v2 : 393 | args.ada_head = True 394 | 395 | # Cache the args as a text string to save them in the output dir later 396 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 397 | return args, args_text 398 | 399 | 400 | def main(): 401 | setup_default_logging() 402 | args, args_text = _parse_args() 403 | 404 | args.prefetcher = not args.no_prefetcher 405 | args.distributed = False 406 | if 'WORLD_SIZE' in os.environ: 407 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 408 | if args.distributed and args.num_gpu > 1: 409 | _logger.warning( 410 | 'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.') 411 | args.num_gpu = 1 412 | 413 | args.device = 'cuda:0' 414 | args.world_size = 1 415 | args.rank = 0 # global rank 416 | if args.distributed: 417 | args.num_gpu = 1 418 | args.device = 'cuda:%d' % args.local_rank 419 | torch.cuda.set_device(args.local_rank) 420 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 421 | args.world_size = torch.distributed.get_world_size() 422 | args.rank = torch.distributed.get_rank() 423 | assert args.rank >= 0 424 | 425 | if args.distributed: 426 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 427 | % (args.rank, args.world_size)) 428 | else: 429 | _logger.info('Training with a single process on %d GPUs.' % args.num_gpu) 430 | 431 | torch.manual_seed(args.seed + args.rank) 432 | 433 | model = create_model( 434 | args.model, 435 | pretrained=args.pretrained, 436 | num_classes=args.num_classes, 437 | drop_rate=args.drop, 438 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 439 | drop_path_rate=args.drop_path, 440 | drop_block_rate=args.drop_block, 441 | global_pool=args.gp, 442 | bn_tf=args.bn_tf, 443 | bn_momentum=args.bn_momentum, 444 | bn_eps=args.bn_eps, 445 | checkpoint_path=args.initial_checkpoint, 446 | img_size=args.img_size, 447 | keep_layers=args.keep_layers, 448 | head_select_bias=args.head_select_bias, 449 | layer_select_bias=args.layer_select_bias, 450 | norm_policy = args.norm_policy, 451 | ada_layer = args.ada_layer, 452 | ada_block = args.ada_block, 453 | ada_head = args.ada_head, 454 | ada_head_v2 = args.ada_head_v2, 455 | dyna_data = args.dyna_data, 456 | ada_token = args.ada_token, 457 | ada_token_nonstep = args.ada_token_nonstep, 458 | ada_token_with_mlp = args.ada_token_with_mlp, 459 | ada_head_attn = args.ada_head_attn, 460 | head_slowfast = args.head_slowfast, 461 | ada_token_start_layer = args.ada_token_start_layer, 462 | ada_token_pre_softmax = args.ada_token_pre_softmax, 463 | ada_token_detach_attn = args.ada_token_detach_attn, 464 | ada_token_detach_attn_at_mlp = args.ada_token_detach_attn_at_mlp, 465 | layer_select_tau = args.layer_select_tau, 466 | head_select_tau = args.head_select_tau, 467 | token_select_tau = args.token_select_tau, 468 | ) 469 | 470 | def set_model_attr(model, name, value): 471 | if hasattr(model, name) : 472 | setattr(model, name, value) 473 | if args.random_policy : 474 | model.apply(lambda x : set_model_attr(x, 'random_policy', True)) 475 | model.apply(lambda x : set_model_attr(x, 'random_layer_ratio', args.layer_target_ratio)) 476 | 
model.apply(lambda x : set_model_attr(x, 'random_head_ratio', args.head_target_ratio)) 477 | model.apply(lambda x : set_model_attr(x, 'random_token_ratio', args.token_target_ratio)) 478 | 479 | if args.random_head : 480 | model.apply(lambda x : set_model_attr(x, 'random_head', True)) 481 | model.apply(lambda x : set_model_attr(x, 'random_head_ratio', args.head_target_ratio)) 482 | if args.random_layer : 483 | model.apply(lambda x : set_model_attr(x, 'random_layer', True)) 484 | model.apply(lambda x : set_model_attr(x, 'random_layer_ratio', args.layer_target_ratio)) 485 | if args.random_token : 486 | model.apply(lambda x : set_model_attr(x, 'random_token', True)) 487 | model.apply(lambda x : set_model_attr(x, 'random_token_ratio', args.token_target_ratio)) 488 | 489 | if args.pretrain_path : 490 | _logger.info('load pretrain model {}'.format(args.pretrain_path)) 491 | # model.load_state_dict(torch.load(args.pretrain_path, map_location='cpu'), strict=False) 492 | ada_load_state_dict(args.pretrain_path, model, use_qkv=False, strict=False) 493 | 494 | if args.local_rank == 0: 495 | _logger.info('Model %s created, param count: %d' % 496 | (args.model, sum([m.numel() for m in model.parameters()]))) 497 | 498 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) 499 | 500 | num_aug_splits = 0 501 | if args.aug_splits > 0: 502 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 503 | num_aug_splits = args.aug_splits 504 | 505 | if args.split_bn: 506 | assert num_aug_splits > 1 or args.resplit 507 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 508 | 509 | use_amp = None 510 | if args.amp: 511 | # for backwards compat, `--amp` arg tries apex before native amp 512 | if has_apex: 513 | args.apex_amp = True 514 | elif has_native_amp: 515 | args.native_amp = True 516 | if args.apex_amp and has_apex: 517 | use_amp = 'apex' 518 | elif args.native_amp and has_native_amp: 519 | use_amp = 'native' 520 | elif args.apex_amp or args.native_amp: 521 | _logger.warning("Neither APEX or native Torch AMP is available, using float32. " 522 | "Install NVIDA apex or upgrade to PyTorch 1.6") 523 | 524 | if args.num_gpu > 1: 525 | if use_amp == 'apex': 526 | _logger.warning( 527 | 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.') 528 | use_amp = None 529 | model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() 530 | assert not args.channels_last, "Channels last not supported with DP, use DDP." 
531 | else: 532 | model.cuda() 533 | if args.channels_last: 534 | model = model.to(memory_format=torch.channels_last) 535 | 536 | if not args.ada_lr_scaling: 537 | optimizer = create_optimizer(args, model) 538 | else: 539 | ada_params_names = ['token_select', 'head_select', 'layer_select'] 540 | no_wd_params_names = list(model.no_weight_decay()) 541 | model_params = [kv for kv in model.named_parameters() if kv[1].requires_grad] 542 | # ada_params_name_check = [kv[0] for kv in model.named_parameters() if any(sub in kv[0] for sub in ada_params_names)] # just for sanity check 543 | ada_params = [kv for kv in model_params if any(sub in kv[0] for sub in ada_params_names)] 544 | base_params = [kv for kv in model_params if not any(sub in kv[0] for sub in ada_params_names)] 545 | 546 | base_params_wd, base_params_no_wd = [], [] 547 | ada_params_token_wd, ada_params_token_no_wd, ada_params_head_wd, ada_params_head_no_wd, ada_params_layer_wd, ada_params_layer_no_wd = [], [], [], [], [], [] 548 | 549 | for name, param in base_params: 550 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names: 551 | base_params_no_wd.append(param) 552 | else: 553 | base_params_wd.append(param) 554 | for name, param in ada_params: 555 | if 'token_select' in name: 556 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names: 557 | ada_params_token_no_wd.append(param) 558 | else: 559 | ada_params_token_wd.append(param) 560 | elif 'head_select' in name: 561 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names: 562 | ada_params_head_no_wd.append(param) 563 | else: 564 | ada_params_head_wd.append(param) 565 | elif 'layer_select' in name: 566 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names: 567 | ada_params_layer_no_wd.append(param) 568 | else: 569 | ada_params_layer_wd.append(param) 570 | 571 | all_params = [ 572 | {'params': ada_params_token_wd, 'lr': args.lr * args.ada_token_lr_scale, 'weight_decay': args.weight_decay}, 573 | {'params': ada_params_token_no_wd, 'lr': args.lr * args.ada_token_lr_scale, 'weight_decay': 0.}, 574 | {'params': ada_params_head_wd, 'lr': args.lr * args.ada_head_lr_scale, 'weight_decay': args.weight_decay}, 575 | {'params': ada_params_head_no_wd, 'lr': args.lr * args.ada_head_lr_scale, 'weight_decay': 0.}, 576 | {'params': ada_params_layer_wd, 'lr': args.lr * args.ada_layer_lr_scale, 'weight_decay': args.weight_decay}, 577 | {'params': ada_params_layer_no_wd, 'lr': args.lr * args.ada_layer_lr_scale, 'weight_decay': 0.}, 578 | {'params': base_params_wd, 'lr': args.lr, 'weight_decay': args.weight_decay}, 579 | {'params': base_params_no_wd, 'lr': args.lr, 'weight_decay': 0.}, 580 | ] 581 | optimizer = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=args.weight_decay) 582 | 583 | amp_autocast = suppress # do nothing 584 | loss_scaler = None 585 | if use_amp == 'apex': 586 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 587 | loss_scaler = ApexScaler() 588 | if args.local_rank == 0: 589 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 590 | elif use_amp == 'native': 591 | amp_autocast = torch.cuda.amp.autocast 592 | loss_scaler = NativeScaler() 593 | if args.local_rank == 0: 594 | _logger.info('Using native Torch AMP. Training in mixed precision.') 595 | else: 596 | if args.local_rank == 0: 597 | _logger.info('AMP not enabled. 
Training in float32.') 598 | 599 | # optionally resume from a checkpoint 600 | resume_epoch = None 601 | if args.resume: 602 | resume_epoch = resume_checkpoint( 603 | model, args.resume, 604 | optimizer=None if args.no_resume_opt else optimizer, 605 | loss_scaler=None if args.no_resume_opt else loss_scaler, 606 | log_info=args.local_rank == 0) 607 | 608 | model_ema = None 609 | if args.model_ema: 610 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 611 | model_ema = ModelEma( 612 | model, 613 | decay=args.model_ema_decay, 614 | device='cpu' if args.model_ema_force_cpu else '', 615 | resume=args.resume) 616 | 617 | if args.distributed: 618 | if args.sync_bn: 619 | assert not args.split_bn 620 | try: 621 | if has_apex and use_amp != 'native': 622 | # Apex SyncBN preferred unless native amp is activated 623 | model = convert_syncbn_model(model) 624 | else: 625 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 626 | if args.local_rank == 0: 627 | _logger.info( 628 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 629 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 630 | except Exception as e: 631 | _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1') 632 | if has_apex and use_amp != 'native': 633 | # Apex DDP preferred unless native amp is activated 634 | if args.local_rank == 0: 635 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 636 | model = ApexDDP(model, delay_allreduce=True) 637 | else: 638 | if args.local_rank == 0: 639 | _logger.info("Using native Torch DistributedDataParallel.") 640 | model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1 641 | # NOTE: EMA model does not need to be wrapped by DDP 642 | 643 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 644 | start_epoch = 0 645 | if args.start_epoch is not None: 646 | # a specified start_epoch will always override the resume epoch 647 | start_epoch = args.start_epoch 648 | elif resume_epoch is not None: 649 | start_epoch = resume_epoch 650 | if lr_scheduler is not None and start_epoch > 0: 651 | lr_scheduler.step(start_epoch) 652 | 653 | if args.local_rank == 0: 654 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 655 | 656 | train_dir = os.path.join(args.data, 'train' if not args.dev else 'val') 657 | if not os.path.exists(train_dir): 658 | _logger.error('Training folder does not exist at: {}'.format(train_dir)) 659 | exit(1) 660 | dataset_train = Dataset(train_dir) 661 | 662 | collate_fn = None 663 | mixup_fn = None 664 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 665 | print('mixup activte', mixup_active) 666 | if mixup_active: 667 | mixup_args = dict( 668 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 669 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 670 | label_smoothing=args.smoothing, num_classes=args.num_classes) 671 | if args.prefetcher: 672 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 673 | collate_fn = FastCollateMixup(**mixup_args) 674 | else: 675 | mixup_fn = Mixup(**mixup_args) 676 | 677 | if num_aug_splits > 1: 678 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 679 | 680 | train_interpolation = args.train_interpolation 681 | if args.no_aug or not train_interpolation: 682 | train_interpolation = data_config['interpolation'] 683 | loader_train = create_loader( 684 | dataset_train, 685 | input_size=data_config['input_size'], 686 | batch_size=args.batch_size, 687 | is_training=True, 688 | use_prefetcher=args.prefetcher, 689 | no_aug=args.no_aug, 690 | re_prob=args.reprob, 691 | re_mode=args.remode, 692 | re_count=args.recount, 693 | re_split=args.resplit, 694 | scale=args.scale, 695 | ratio=args.ratio, 696 | hflip=args.hflip, 697 | vflip=args.vflip, 698 | color_jitter=args.color_jitter, 699 | auto_augment=args.aa, 700 | num_aug_splits=num_aug_splits, 701 | interpolation=train_interpolation, 702 | mean=data_config['mean'], 703 | std=data_config['std'], 704 | num_workers=args.workers, 705 | distributed=args.distributed, 706 | collate_fn=collate_fn, 707 | pin_memory=args.pin_mem, 708 | use_multi_epochs_loader=args.use_multi_epochs_loader 709 | ) 710 | 711 | eval_dir = os.path.join(args.data, 'val') 712 | if not os.path.isdir(eval_dir): 713 | eval_dir = os.path.join(args.data, 'validation') 714 | if not os.path.isdir(eval_dir): 715 | _logger.error('Validation folder does not exist at: {}'.format(eval_dir)) 716 | exit(1) 717 | dataset_eval = Dataset(eval_dir) 718 | 719 | loader_eval = create_loader( 720 | dataset_eval, 721 | input_size=data_config['input_size'], 722 | batch_size=args.validation_batch_size_multiplier * args.batch_size, 723 | is_training=False, 724 | use_prefetcher=args.prefetcher, 725 | interpolation=data_config['interpolation'], 726 | mean=data_config['mean'], 727 | std=data_config['std'], 728 | num_workers=args.workers, 729 | distributed=args.distributed, 730 | crop_pct=data_config['crop_pct'], 731 | pin_memory=args.pin_mem, 732 | ) 733 | 734 | if args.jsd: 735 | assert num_aug_splits > 1 # JSD only valid with aug splits set 736 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() 737 | elif mixup_active: 738 | # smoothing is handled with mixup target transform 739 | train_loss_fn = SoftTargetCrossEntropy().cuda() 740 | elif args.smoothing: 741 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() 742 | else: 743 | train_loss_fn = nn.CrossEntropyLoss().cuda() 744 | # train_loss_fn = AdaHeadLoss(train_loss_fn, target_ratio=args.head_target_ratio,head_loss_ratio=args.head_ratio) 745 | train_loss_fn = AdaLoss(train_loss_fn, 746 | head_target_ratio=args.head_target_ratio, head_loss_ratio=args.head_ratio, 747 | layer_target_ratio=args.layer_target_ratio, layer_loss_ratio=args.layer_ratio, 748 | head_diverse_ratio=args.head_diverse_ratio, layer_diverse_ratio=args.layer_diverse_ratio, 749 | head_entropy_weight=args.head_entropy_weight, layer_entropy_weight=args.layer_entropy_weight, 750 
| head_minimal_weight=args.head_minimal_weight, head_minimal=args.head_minimal, 751 | layer_minimal_weight=args.layer_minimal_weight, layer_minimal=args.layer_minimal, 752 | token_target_ratio=args.token_target_ratio, token_loss_ratio=args.token_ratio, token_minimal=args.token_minimal, token_minimal_weight=args.token_minimal_weight) 753 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 754 | 755 | eval_metric = args.eval_metric 756 | best_metric = None 757 | best_epoch = None 758 | 759 | flops_dict = None 760 | if args.flops_dict : 761 | flops_dict = torch.load(args.flops_dict) 762 | if args.eval_random_baseline: 763 | assert args.eval_checkpoint, "Please provide path to the checkpoint when evaluating baselines." 764 | ada_load_state_dict(args.eval_checkpoint, model, use_qkv=False, strict=False) 765 | # set random policy configuration 766 | model.random_layer, model.random_head, model.random_token = args.eval_random_layer, args.eval_random_head, args.eval_random_token 767 | model.random_layer_ratio, model.random_head_ratio, model.random_token_ratio = \ 768 | args.eval_random_layer_ratio, args.eval_random_head_ratio, args.eval_random_token_ratio 769 | val_metrics = validate(model, loader_eval, validate_loss_fn, args, print_head_option=args.print_head_option, flops_dict=flops_dict) 770 | print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%") 771 | return 772 | 773 | if args.eval_checkpoint: # evaluate the model 774 | ada_load_state_dict(args.eval_checkpoint, model, use_qkv=False, strict=False) 775 | if args.use_full_head : 776 | model.apply(lambda m: setattr(m, 'use_full_linear', True)) 777 | val_metrics = validate(model, loader_eval, validate_loss_fn, args, print_head_option=args.print_head_option, flops_dict=flops_dict) 778 | print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%") 779 | if 'gflops' in val_metrics : 780 | print(f"avg flops is: {val_metrics['gflops']:.1f} GFLOPS") 781 | return 782 | 783 | saver = None 784 | output_dir = '' 785 | if args.local_rank == 0: 786 | output_base = args.output if args.output else './output' 787 | exp_name = '-'.join([ 788 | datetime.datetime.now().strftime("%Y%m%d-%H%M%S"), 789 | args.model, 790 | ]) 791 | output_dir = get_outdir(output_base, 'train', exp_name) 792 | decreasing = True if eval_metric == 'loss' else False 793 | saver = MyCheckpointSaver( 794 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 795 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) 796 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 797 | f.write(args_text) 798 | 799 | try: # train the model 800 | for epoch in range(start_epoch, num_epochs): 801 | if args.distributed: 802 | loader_train.sampler.set_epoch(epoch) 803 | 804 | train_metrics = train_epoch( 805 | epoch, model, loader_train, optimizer, train_loss_fn, args, 806 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, 807 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, total_epochs=num_epochs) 808 | 809 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 810 | if args.local_rank == 0: 811 | _logger.info("Distributing BatchNorm running means and vars") 812 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 813 | 814 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, flops_dict=flops_dict) 815 | 816 | if model_ema is not None and not args.model_ema_force_cpu: 817 | if 
args.distributed and args.dist_bn in ('broadcast', 'reduce'): 818 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 819 | ema_eval_metrics = validate( 820 | model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)', flops_dict=flops_dict) 821 | eval_metrics = ema_eval_metrics 822 | 823 | if lr_scheduler is not None: 824 | # step LR for next epoch 825 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 826 | 827 | update_summary( 828 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), 829 | write_header=best_metric is None) 830 | 831 | if saver is not None: 832 | # save proper checkpoint with eval metric 833 | save_metric = eval_metrics[eval_metric] 834 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) 835 | 836 | except KeyboardInterrupt: 837 | pass 838 | if best_metric is not None: 839 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 840 | 841 | def freeze_bn(model): 842 | for module in model.modules(): 843 | if isinstance(module, (nn.LayerNorm, nn.GroupNorm)): 844 | module.eval() 845 | 846 | def train_epoch( 847 | epoch, model, loader, optimizer, loss_fn, args, 848 | lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress, 849 | loss_scaler=None, model_ema=None, mixup_fn=None, total_epochs=0): 850 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 851 | if args.prefetcher and loader.mixup_enabled: 852 | loader.mixup_enabled = False 853 | elif mixup_fn is not None: 854 | mixup_fn.mixup_enabled = False 855 | 856 | base_model = model.module if hasattr(model, 'module') else model 857 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 858 | batch_time_m = AverageMeter() 859 | data_time_m = AverageMeter() 860 | losses_m = AverageMeter() 861 | top1_m = AverageMeter() 862 | top5_m = AverageMeter() 863 | meta_loss_m = {} 864 | 865 | model.train() 866 | if args.freeze_bn : 867 | freeze_bn(model) 868 | 869 | end = time.time() 870 | last_idx = len(loader) - 1 871 | 872 | def update_meta_loss_m(meta_loss) : 873 | for k, v in meta_loss.items(): 874 | if k not in meta_loss_m : 875 | meta_loss_m[k] = AverageMeter() 876 | meta_loss_m[k].update(v.item(), input.size(0)) 877 | 878 | num_updates = epoch * len(loader) 879 | total_num_updates = total_epochs * len(loader) 880 | for batch_idx, (input, target) in enumerate(loader): 881 | last_batch = batch_idx == last_idx 882 | data_time_m.update(time.time() - end) 883 | if not args.prefetcher: 884 | input, target = input.cuda(), target.cuda() 885 | if mixup_fn is not None: 886 | input, target = mixup_fn(input, target) 887 | if args.channels_last: 888 | input = input.contiguous(memory_format=torch.channels_last) 889 | 890 | with amp_autocast(): 891 | outputs = model(input, training=True, ret_attn_list=False) 892 | loss, meta_loss = loss_fn(outputs, target) 893 | 894 | if not args.distributed: 895 | losses_m.update(loss.item(), input.size(0)) 896 | update_meta_loss_m(meta_loss) 897 | 898 | optimizer.zero_grad() 899 | if loss_scaler is not None: 900 | loss_scaler( 901 | loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order) 902 | else: 903 | loss.backward(create_graph=second_order) 904 | if args.clip_grad is not None: 905 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad) 906 | optimizer.step() 907 | 908 | torch.cuda.synchronize() 909 | if model_ema is not None: 910 | model_ema.update(model) 911 
| num_updates += 1 912 | 913 | batch_time_m.update(time.time() - end) 914 | if last_batch or batch_idx % args.log_interval == 0: 915 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 916 | lr = sum(lrl) / len(lrl) 917 | 918 | if args.distributed: 919 | reduced_loss = reduce_tensor(loss.data, args.world_size) 920 | losses_m.update(reduced_loss.item(), input.size(0)) 921 | for k, v in meta_loss.items(): 922 | meta_loss[k] = reduce_tensor(v.data, args.world_size) 923 | update_meta_loss_m(meta_loss) 924 | 925 | eta_seconds = batch_time_m.avg * (total_num_updates - num_updates) 926 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 927 | 928 | if args.local_rank == 0: 929 | _logger.info( 930 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 931 | 'Current Time: {} ' 932 | 'ETA : {} ' 933 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) ' 934 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 935 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 936 | 'LR: {lr:.3e} ' 937 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 938 | epoch, 939 | batch_idx, len(loader), 940 | 100. * batch_idx / last_idx, 941 | strftime("%Y-%m-%d %H:%M:%S", gmtime()), 942 | eta_string, 943 | loss=losses_m, 944 | batch_time=batch_time_m, 945 | rate=input.size(0) * args.world_size / batch_time_m.val, 946 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 947 | lr=lr, 948 | data_time=data_time_m)) 949 | 950 | if args.save_images and output_dir: 951 | torchvision.utils.save_image( 952 | input, 953 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 954 | padding=0, 955 | normalize=True) 956 | 957 | if saver is not None and args.recovery_interval and ( 958 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 959 | saver.save_recovery(epoch, batch_idx=batch_idx) 960 | 961 | if lr_scheduler is not None: 962 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 963 | 964 | end = time.time() 965 | # end for 966 | 967 | if hasattr(optimizer, 'sync_lookahead'): 968 | optimizer.sync_lookahead() 969 | 970 | return OrderedDict(loss= losses_m.avg, **{x : meta_loss_m[x].avg for x in meta_loss}) 971 | 972 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='', ret_head_option=False, print_head_option=False, flops_dict=None): 973 | batch_time_m = AverageMeter() 974 | losses_m = AverageMeter() 975 | top1_m = AverageMeter() 976 | top5_m = AverageMeter() 977 | head_m = AverageMeter() 978 | layer_m = AverageMeter() 979 | token_m = AverageMeter() 980 | flops_m = AverageMeter() 981 | 982 | analyse_m = [[] for _ in range(1000)] 983 | 984 | head_option = None 985 | layer_option = None 986 | token_option = None 987 | 988 | model.eval() 989 | 990 | end = time.time() 991 | last_idx = len(loader) - 1 992 | with torch.no_grad(): 993 | for batch_idx, (input, target) in enumerate(loader): 994 | last_batch = batch_idx == last_idx 995 | if not args.prefetcher: 996 | input = input.cuda() 997 | target = target.cuda() 998 | if args.channels_last: 999 | input = input.contiguous(memory_format=torch.channels_last) 1000 | 1001 | with amp_autocast(): 1002 | output = model(input) 1003 | if isinstance(output, (tuple, list)): 1004 | output, head_select, layer_select, token_select = output[:4] 1005 | # output = output[0] 1006 | # head_select = output[1] 1007 | else : 1008 | head_select = None 1009 | layer_select = None 1010 | token_select = None 1011 | 1012 | # augmentation reduction 1013 | reduce_factor = args.tta 1014 | if reduce_factor > 1: 1015 | output = 
output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 1016 | target = target[0:target.size(0):reduce_factor] 1017 | 1018 | loss = loss_fn(output, target) 1019 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 1020 | 1021 | if args.distributed: 1022 | reduced_loss = reduce_tensor(loss.data, args.world_size) 1023 | acc1 = reduce_tensor(acc1, args.world_size) 1024 | acc5 = reduce_tensor(acc5, args.world_size) 1025 | if head_select is not None : 1026 | head_select = reduce_tensor(head_select, args.world_size) 1027 | if layer_select is not None : 1028 | layer_select = reduce_tensor(layer_select, args.world_size) 1029 | if token_select is not None : 1030 | token_select = reduce_tensor(token_select, args.world_size) 1031 | else: 1032 | reduced_loss = loss.data 1033 | 1034 | torch.cuda.synchronize() 1035 | 1036 | losses_m.update(reduced_loss.item(), input.size(0)) 1037 | top1_m.update(acc1.item(), output.size(0)) 1038 | top5_m.update(acc5.item(), output.size(0)) 1039 | if head_select is not None : 1040 | head_m.update(head_select.mean().item(), output.size(0)) 1041 | if head_option is None : 1042 | head_option = AverageMeter() 1043 | head_option.update(head_select.mean(0).cpu(), output.size(0)) 1044 | if layer_select is not None : 1045 | layer_m.update(layer_select.mean().item(), output.size(0)) 1046 | if layer_option is None : 1047 | layer_option = AverageMeter() 1048 | layer_option.update(layer_select.mean(0).cpu(), output.size(0)) 1049 | if token_select is not None : 1050 | token_m.update(token_select.mean().item(), output.size(0)) 1051 | if token_option is None : 1052 | token_option = AverageMeter() 1053 | token_option.update(token_select.mean((0,-1)), output.size(0)) 1054 | 1055 | if flops_dict is not None : 1056 | bs = output.size(0) 1057 | from block_flops_dict import batch_select_flops 1058 | if 'deit' in args.model: 1059 | b_flops = batch_select_flops(bs, flops_dict, head_select, layer_select, token_select, block_num=12,base_flops=0.06) 1060 | else : 1061 | b_flops = batch_select_flops(bs, flops_dict, head_select, layer_select, token_select, block_num=19, base_flops=0.33) 1062 | flops_m.update(b_flops.mean(), bs) 1063 | def deal_none(x): 1064 | return x if x is not None else [None] * bs 1065 | head_select = deal_none(head_select) 1066 | layer_select = deal_none(layer_select) 1067 | token_select = deal_none(token_select) 1068 | for c, b, ht, lt, tt in zip(target, b_flops, head_select, layer_select, token_select): 1069 | lt = lt.cpu() if lt is not None else None 1070 | ht = ht.cpu() if ht is not None else None 1071 | tt = tt.cpu() if tt is not None else None 1072 | analyse_m[c].append(dict(flops=b.cpu(), layer_select=lt, head_select=ht, token_select=tt)) 1073 | 1074 | batch_time_m.update(time.time() - end) 1075 | end = time.time() 1076 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 1077 | log_name = 'Test' + log_suffix 1078 | _logger.info( 1079 | '{0}: [{1:>4d}/{2}] ' 1080 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 1081 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 1082 | 'Head avg :{head_select.val:>7.4f} ({head_select.avg:>6.4f}) ' 1083 | 'Layer avg :{layer_m.val:>7.4f} ({layer_m.avg:>6.4f}) ' 1084 | 'Token avg :{token_m.val:>7.4f} ({token_m.avg:>6.4f}) ' 1085 | 'Flops avg :{flops_m.val:>7.4f} ({flops_m.avg:>6.4f}) ' 1086 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 1087 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 1088 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 1089 | loss=losses_m, top1=top1_m, 
top5=top5_m, 1090 | head_select=head_m, layer_m=layer_m, token_m=token_m, flops_m=flops_m)) 1091 | 1092 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg), ('head', head_m.avg), ('layer', layer_m.avg)]) 1093 | if flops_dict is not None : 1094 | metrics.update(gflops=flops_m.avg) 1095 | if head_option is not None : 1096 | if ret_head_option : 1097 | metrics.update(head_option=head_option.avg) 1098 | if print_head_option : 1099 | for i, val in enumerate(head_option.avg) : 1100 | if args.rank == 0 : 1101 | print('{}: {}'.format(i+args.keep_layers, ','.join(['{:.2f}'.format(float(x)) for x in val]))) 1102 | 1103 | if layer_option is not None : 1104 | if ret_head_option : 1105 | metrics.update(layer_option=layer_option.avg) 1106 | if print_head_option : 1107 | for i, val in enumerate(layer_option.avg) : 1108 | if args.rank == 0 : 1109 | print('{}: {}'.format(i+args.keep_layers, ','.join(['{:.2f}'.format(float(x)) for x in val]))) 1110 | 1111 | if token_option is not None : 1112 | if ret_head_option : 1113 | metrics.update(token_option=token_option.avg) 1114 | if print_head_option : 1115 | if args.rank == 0 : 1116 | print(' '.join(['{:.2f}'.format(float(x)) for x in token_option.avg])) 1117 | 1118 | if analyse_m is not None : 1119 | torch.save(analyse_m, '{}_analyse.{}.pth'.format(args.model, top1_m.avg)) 1120 | return metrics 1121 | 1122 | 1123 | if __name__ == '__main__': 1124 | main() 1125 | -------------------------------------------------------------------------------- /adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/AdaViT/9857f92c8428045e01a8d988eab5060b1ac4a18b/adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth -------------------------------------------------------------------------------- /adavit_ckpt/t2t-19-h-l-tmlp_flops_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/AdaViT/9857f92c8428045e01a8d988eab5060b1ac4a18b/adavit_ckpt/t2t-19-h-l-tmlp_flops_dict.pth -------------------------------------------------------------------------------- /assets/adavit_approach.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MengLcool/AdaViT/9857f92c8428045e01a8d988eab5060b1ac4a18b/assets/adavit_approach.png -------------------------------------------------------------------------------- /block_flops_dict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import fvcore 3 | from fvcore.nn import FlopCountAnalysis 4 | import argparse 5 | import models 6 | from models.ada_transformer_block import StepAdaBlock 7 | from timm.models import create_model 8 | 9 | parser = argparse.ArgumentParser(description='T2T-ViT Training and Evaluating') 10 | parser.add_argument('--model', default='t2t-19', type=str, metavar='MODEL', 11 | help='Name of model to train (default: "countception"') 12 | parser.add_argument('--ada-layer', action='store_true', default=False, dest='ada_layer', 13 | help='') 14 | parser.add_argument('--ada-head', action='store_true', default=False, dest='ada_head', 15 | help='') 16 | 17 | parser.add_argument('--ada-token', action='store_true', default=False, dest='ada_token', 18 | help='ada-token on self-attn') 19 | parser.add_argument('--ada-token-with-mlp', action='store_true', default=False, dest='ada_token_with_mlp', 20 | 
help='ada-token on both self-attn and ffn') 21 | parser.add_argument('--keep-layers', type=int, default=1, 22 | help='use layers to make selection decision') 23 | 24 | def _parse_args(): 25 | 26 | args = parser.parse_args() 27 | if args.ada_token_with_mlp : 28 | args.ada_token = True 29 | 30 | return args 31 | 32 | def get_flops_dict(dim, num_heads, mlp_ratio, debug=False, **kwargs): 33 | block = StepAdaBlock(dim, num_heads, mlp_ratio, **kwargs) 34 | inputs = torch.rand((1,197,dim)) 35 | 36 | num_tokens = 197 37 | num_layers = 2 38 | 39 | block.apply(lambda x : setattr(x, 'count_flops', True)) 40 | flops_dict = torch.zeros(num_heads+1, num_tokens+1, 2, 2) 41 | for t in range(1, num_tokens+1) : 42 | for h in range(num_heads+1) : 43 | block.apply(lambda x : setattr(x, 'h_ratio', h/num_heads)) 44 | block.apply(lambda x : setattr(x, 't_ratio', (t)/197)) 45 | 46 | if debug : 47 | block.apply(lambda x : setattr(x, 'h_ratio', 5/7)) 48 | block.apply(lambda x : setattr(x, 't_ratio', (197)/197)) 49 | block.apply(lambda x : setattr(x, 'l_ratio', [1,1])) 50 | flops = FlopCountAnalysis(block, inputs).total() / (1000**3) 51 | print('flops', flops) 52 | exit() 53 | 54 | def fill_dict(l_select): 55 | block.apply(lambda x : setattr(x, 'l_ratio', l_select)) 56 | 57 | xx = block(inputs) 58 | 59 | # flops = FlopCountAnalysis(block, inputs).total() / (1000**3) 60 | # print('flops', h, t, l_select, flops) 61 | # flops_dict[h,t,l_select[0],l_select[1]] = flops 62 | fill_dict([0,0]) 63 | fill_dict([0,1]) 64 | fill_dict([1,0]) 65 | fill_dict([1,1]) 66 | return flops_dict 67 | 68 | def select_flops(flops_dict, head_select, layer_select, token_select, block_num, base_flops=0.33): 69 | ''' 70 | head_select : None or tensor (ada_block_h, h_num) 71 | layer_select : None or tensor (ada_block_l, 2) 72 | token_select : None or tensor (ada_block_t, 197) 73 | ''' 74 | 75 | h, t, _, _ = flops_dict.shape 76 | h = h-1 77 | t = t-1 78 | if head_select is None : 79 | head_select = [h] * block_num 80 | else : 81 | ada_h = head_select.shape[0] 82 | head_select = [h] * (block_num - ada_h) + head_select.sum(-1).int().tolist() 83 | if layer_select is None : 84 | layer_select = [[1,1]] * block_num 85 | else : 86 | ada_l = layer_select.shape[0] 87 | layer_select = [[1,1]] * (block_num-ada_l) + layer_select.int().tolist() 88 | if token_select is None : 89 | token_select = [t] * block_num 90 | else : 91 | ada_t = token_select.shape[0] 92 | token_select = [t] * (block_num - ada_t) + token_select.sum(-1).int().tolist() 93 | 94 | flops = base_flops 95 | for h, t, l in zip(head_select, token_select, layer_select) : 96 | flops += flops_dict[h, t, l[0], l[1]] 97 | return flops 98 | 99 | def batch_select_flops(bs, flops_dict, head_select, layer_select, token_select, block_num=19, base_flops=0.33): 100 | def batch_select(x): 101 | return x if x is not None else [None] * bs 102 | head_select = batch_select(head_select) 103 | layer_select = batch_select(layer_select) 104 | token_select = batch_select(token_select) 105 | 106 | batch_flops = [] 107 | for h, l, t in zip(head_select, layer_select, token_select): 108 | batch_flops.append(select_flops(flops_dict, h, l, t, block_num, base_flops)) 109 | 110 | return torch.tensor(batch_flops) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ada_t2t_vit import * 2 | from .deit import * -------------------------------------------------------------------------------- 
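# --- Editor's note: a minimal, hypothetical usage sketch for the FLOPs helpers defined
# in block_flops_dict.py above. select_flops()/batch_select_flops() index a 4-D lookup
# table as flops_dict[active_heads, active_tokens, attn_on, mlp_on]; here a synthetic
# random table stands in for the shipped *_flops_dict.pth files, so the numbers are
# illustrative only and the tensor shapes are assumptions based on the code above.
import torch
from block_flops_dict import batch_select_flops

num_heads, num_tokens, block_num = 7, 197, 19                  # T2T-ViT-19 configuration
flops_dict = torch.rand(num_heads + 1, num_tokens + 1, 2, 2)   # placeholder per-block GFLOPs table

bs, ada_blocks = 4, block_num - 1                              # policies start after keep_layers=1
head_select  = torch.randint(0, 2, (bs, ada_blocks, num_heads)).float()
layer_select = torch.randint(0, 2, (bs, ada_blocks, 2)).float()
token_select = torch.randint(0, 2, (bs, ada_blocks, num_tokens)).float()

per_sample_gflops = batch_select_flops(bs, flops_dict, head_select, layer_select,
                                       token_select, block_num=block_num, base_flops=0.33)
print(per_sample_gflops.shape)   # torch.Size([4]): one GFLOPs estimate per image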
/models/ada_t2t_vit.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from re import L 3 | import torch 4 | import torch.nn as nn 5 | 6 | from timm.models.helpers import load_pretrained 7 | from timm.models.registry import register_model 8 | from timm.models.layers import trunc_normal_ 9 | from timm.models.vision_transformer import PatchEmbed 10 | import numpy as np 11 | from .token_transformer import Token_transformer 12 | from .token_performer import Token_performer 13 | from .transformer_block import Block, get_sinusoid_encoding 14 | from .ada_transformer_block import StepAdaBlock, get_random_policy 15 | 16 | 17 | def _cfg(url='', **kwargs): 18 | return { 19 | 'url': url, 20 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 21 | 'crop_pct': .9, 'interpolation': 'bicubic', 22 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 23 | 'classifier': 'head', 24 | **kwargs 25 | } 26 | 27 | default_cfgs = { 28 | 'T2t_vit_7': _cfg(), 29 | 'T2t_vit_10': _cfg(), 30 | 'T2t_vit_12': _cfg(), 31 | 'T2t_vit_14': _cfg(), 32 | 'T2t_vit_19': _cfg(), 33 | 'T2t_vit_24': _cfg(), 34 | 'T2t_vit_t_14': _cfg(), 35 | 'T2t_vit_t_19': _cfg(), 36 | 'T2t_vit_t_24': _cfg(), 37 | 'T2t_vit_14_resnext': _cfg(), 38 | 'T2t_vit_14_wide': _cfg(), 39 | } 40 | 41 | 42 | class T2T_module(nn.Module): 43 | """ 44 | Tokens-to-Token encoding module 45 | """ 46 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64): 47 | super().__init__() 48 | 49 | if tokens_type == 'transformer': 50 | print('adopt transformer encoder for tokens-to-token') 51 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 52 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 53 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 54 | 55 | self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 56 | self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 57 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 58 | 59 | elif tokens_type == 'performer': 60 | print('adopt performer encoder for tokens-to-token') 61 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 62 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 63 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 64 | 65 | #self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5) 66 | #self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5) 67 | self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5) 68 | self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5) 69 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 70 | 71 | elif tokens_type == 'convolution': # just for comparison with conolution, not our model 72 | # for this tokens type, you need change forward as three convolution operation 73 | print('adopt convolution layers for tokens-to-token') 74 | self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution 75 | self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution 76 | self.project = nn.Conv2d(token_dim, embed_dim, 
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution 77 | 78 | self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately 79 | 80 | def forward(self, x): 81 | # step0: soft split 82 | x = self.soft_split0(x).transpose(1, 2) 83 | 84 | # iteration1: re-structurization/reconstruction 85 | x = self.attention1(x) 86 | B, new_HW, C = x.shape 87 | x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 88 | # iteration1: soft split 89 | x = self.soft_split1(x).transpose(1, 2) 90 | 91 | # iteration2: re-structurization/reconstruction 92 | x = self.attention2(x) 93 | B, new_HW, C = x.shape 94 | x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 95 | # iteration2: soft split 96 | x = self.soft_split2(x).transpose(1, 2) 97 | 98 | # final tokens 99 | x = self.project(x) 100 | 101 | return x 102 | 103 | 104 | class AdaStepT2T_ViT(nn.Module): 105 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, 106 | norm_policy=False, use_t2t = True, patch_size=16, 107 | ada_head = True, ada_head_v2=False, dyna_data=False, ada_head_attn=False, head_slowfast=False, 108 | ada_layer=False, 109 | ada_block=False, head_select_tau=5., layer_select_tau=5., token_select_tau=5., 110 | ada_token = False, ada_token_with_mlp=False, ada_token_start_layer=0, ada_token_pre_softmax=True, ada_token_detach_attn=True, ada_token_detach_attn_at_mlp=False, 111 | keep_layers = 1, 112 | head_select_bias=True, layer_select_bias=True, 113 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 114 | drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64, **kwargs): 115 | super().__init__() 116 | self.num_classes = num_classes 117 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 118 | 119 | self.use_t2t = use_t2t 120 | if use_t2t : 121 | self.tokens_to_token = T2T_module( 122 | img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim) 123 | num_patches = self.tokens_to_token.num_patches 124 | else : 125 | self.patch_embed = PatchEmbed( 126 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 127 | num_patches = self.patch_embed.num_patches 128 | 129 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 130 | self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) 131 | self.pos_drop = nn.Dropout(p=drop_rate) 132 | 133 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 134 | assert keep_layers >=1, 'keep layers must >=1' 135 | self.keep_layers = keep_layers 136 | self.head_dim = embed_dim // num_heads 137 | print('ada head, ada layer, ada token', ada_head, ada_layer, ada_token) 138 | 139 | self.blocks = nn.ModuleList([ 140 | StepAdaBlock( 141 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 142 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 143 | is_token_select=ada_token and i >= ada_token_start_layer, ada_token_with_mlp=ada_token_with_mlp, 144 | ada_token_pre_softmax = ada_token_pre_softmax, ada_token_detach_attn=ada_token_detach_attn, dyna_data=dyna_data, ada_head_v2=ada_head_v2, ada_token_detach_attn_at_mlp=ada_token_detach_attn_at_mlp, 145 | ada_head= ada_head and 
i>=keep_layers, 146 | ada_layer= ada_layer and i >=keep_layers, 147 | norm_policy=norm_policy, 148 | only_head_attn=ada_head_attn, head_slowfast=head_slowfast) 149 | for i in range(depth)]) 150 | self.norm = norm_layer(embed_dim) 151 | 152 | # Classifier head 153 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 154 | 155 | trunc_normal_(self.cls_token, std=.02) 156 | self.apply(self._init_weights) 157 | 158 | def _init_weights(self, m): 159 | if isinstance(m, nn.Linear): 160 | trunc_normal_(m.weight, std=.02) 161 | if isinstance(m, nn.Linear) and m.bias is not None: 162 | nn.init.constant_(m.bias, 0) 163 | elif isinstance(m, nn.LayerNorm): 164 | nn.init.constant_(m.bias, 0) 165 | nn.init.constant_(m.weight, 1.0) 166 | 167 | @torch.jit.ignore 168 | def no_weight_decay(self): 169 | return {'cls_token'} 170 | 171 | def get_classifier(self): 172 | return self.head 173 | 174 | def reset_classifier(self, num_classes, global_pool=''): 175 | self.num_classes = num_classes 176 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 177 | 178 | def forward_features(self, x): 179 | B = x.shape[0] 180 | if self.use_t2t : 181 | x = self.tokens_to_token(x) 182 | else : 183 | x = self.patch_embed(x) 184 | 185 | cls_tokens = self.cls_token.expand(B, -1, -1) 186 | x = torch.cat((cls_tokens, x), dim=1) 187 | x = x + self.pos_embed 188 | x = self.pos_drop(x) 189 | 190 | attn_list = [] 191 | hidden_list = [] 192 | token_select_list = [] 193 | head_select_list = [] 194 | layer_select_list = [] 195 | head_select_logits_list = [] 196 | layer_select_logits_list = [] 197 | 198 | def filter_append(target_list, element): 199 | if element is not None : 200 | target_list.append(element) 201 | 202 | for blk in self.blocks : 203 | x, attn, this_head_select, this_layer_select, this_token_select, this_head_select_logits, this_layer_select_logits = blk(x) 204 | attn_list.append(attn) 205 | hidden_list.append(x) 206 | filter_append(head_select_list, this_head_select) 207 | filter_append(layer_select_list, this_layer_select) 208 | filter_append(token_select_list, this_token_select) 209 | filter_append(head_select_logits_list, this_head_select_logits) 210 | filter_append(layer_select_logits_list, this_layer_select_logits) 211 | 212 | def convert_list_to_tensor(list_convert): 213 | if len(list_convert) : 214 | result = torch.stack(list_convert, dim=1) 215 | else : 216 | result = None 217 | return result 218 | 219 | head_select = convert_list_to_tensor(head_select_list) 220 | if head_select is not None : 221 | head_select = head_select.squeeze(-1) 222 | layer_select = convert_list_to_tensor(layer_select_list) 223 | token_select = convert_list_to_tensor(token_select_list) 224 | head_select_logits = convert_list_to_tensor(head_select_logits_list) 225 | layer_select_logits = convert_list_to_tensor(layer_select_logits_list) 226 | a = [head_select, layer_select, token_select, head_select_logits, layer_select_logits] 227 | x = self.norm(x) 228 | 229 | return x[:, 0], head_select, layer_select, token_select, attn_list, hidden_list, dict(head_select_logits=head_select_logits, layer_select_logits=layer_select_logits) 230 | 231 | def zero_classification_grad(self): 232 | for blk in self.blocks[self.keep_layers:] : 233 | blk.zero_grad() 234 | self.norm.zero_grad() 235 | self.head.zero_grad() 236 | 237 | def forward(self, x, training=False, ret_attn_list=False): 238 | x, head_select, layer_select, token_select, attn_list, hidden_list, select_logtis = 
self.forward_features(x) 239 | x = self.head(x) 240 | if ret_attn_list : 241 | return x, head_select, layer_select, token_select, attn_list, hidden_list, select_logtis 242 | return x, head_select, layer_select, token_select, select_logtis 243 | 244 | 245 | class T2T_GroupNorm(nn.GroupNorm): 246 | def __init__(self, num_groups: int, num_channels: int, **kwargs) -> None: 247 | super().__init__(num_groups, num_channels, **kwargs) 248 | 249 | def forward(self, input: torch.Tensor) -> torch.Tensor: 250 | if input.dim() > 2 : 251 | input = input.transpose(1,-1) 252 | result = super().forward(input) 253 | if input.dim() > 2 : 254 | result = result.transpose(1,-1) 255 | return result 256 | 257 | 258 | def t2t_group_norm(dim, num_groups): 259 | return T2T_GroupNorm(num_groups, dim) 260 | 261 | @register_model 262 | def ada_step_t2t_vit_14_lnorm(pretrained=False, **kwargs): # adopt performer for tokens to token 263 | if pretrained: 264 | kwargs.setdefault('qk_scale', 384 ** -0.5) 265 | model = AdaStepT2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) 266 | model.default_cfg = default_cfgs['T2t_vit_14'] 267 | if pretrained: 268 | load_pretrained( 269 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 270 | return model 271 | 272 | @register_model 273 | def ada_step_t2t_vit_19_lnorm(pretrained=False, **kwargs): # adopt performer for tokens to token 274 | if pretrained: 275 | kwargs.setdefault('qk_scale', 448 ** -0.5) 276 | model = AdaStepT2T_ViT(tokens_type='performer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) 277 | model.default_cfg = default_cfgs['T2t_vit_19'] 278 | if pretrained: 279 | load_pretrained( 280 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 281 | return model 282 | 283 | @register_model 284 | def ada_step_deit_tiny_patch16_224(pretrained=False, **kwargs): 285 | model = AdaStepT2T_ViT( 286 | use_t2t=False, 287 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 288 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 289 | model.default_cfg = _cfg() 290 | return model 291 | 292 | 293 | @register_model 294 | def ada_step_deit_small_patch16_224(pretrained=False, **kwargs): 295 | model = AdaStepT2T_ViT( 296 | use_t2t=False, 297 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 298 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 299 | model.default_cfg = _cfg() 300 | return model 301 | 302 | @register_model 303 | def ada_step_deit_base_patch16_224(pretrained=False, **kwargs): # adopt performer for tokens to token 304 | model = AdaStepT2T_ViT( 305 | use_t2t=False, 306 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 307 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 308 | model.default_cfg = _cfg() 309 | return model 310 | -------------------------------------------------------------------------------- /models/ada_transformer_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from timm.models.layers import DropPath 5 | from torch.nn.modules.normalization import LayerNorm 6 | 7 | def _gumbel_sigmoid( 8 | logits, tau=1, hard=False, eps=1e-10, training = True, threshold = 0.5 9 | ): 10 | if training : 11 | # ~Gumbel(0,1)` 12 | gumbels1 = ( 13 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) 14 | .exponential_() 15 | .log() 16 | ) 
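        # Editor's note: -log(E) with E ~ Exponential(1) is a Gumbel(0,1) sample, so
        # gumbels1 and gumbels2 are two independent Gumbel draws; their difference
        # follows a Logistic(0,1) distribution, which is why a sigmoid (rather than a
        # softmax) yields the relaxed Bernoulli sample used for the binary usage policies.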
17 | gumbels2 = ( 18 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format) 19 | .exponential_() 20 | .log() 21 | ) 22 | # Difference of two` gumbels because we apply a sigmoid 23 | gumbels1 = (logits + gumbels1 - gumbels2) / tau 24 | y_soft = gumbels1.sigmoid() 25 | else : 26 | y_soft = logits.sigmoid() 27 | 28 | if hard: 29 | # Straight through. 30 | y_hard = torch.zeros_like( 31 | logits, memory_format=torch.legacy_contiguous_format 32 | ).masked_fill(y_soft > threshold, 1.0) 33 | ret = y_hard - y_soft.detach() + y_soft 34 | else: 35 | ret = y_soft 36 | return ret 37 | 38 | def get_random_policy(policy, ratio): 39 | random_p = torch.empty_like(policy).fill_(ratio).bernoulli() + policy * 0.0 # add policy * 0.0 into the loop of loss calculation to avoid the DDP issue 40 | return random_p 41 | 42 | class SimpleTokenSelect(nn.Module): 43 | def __init__(self, dim_in, tau=5, is_hard=True, threshold=0.5, bias=True, pre_softmax=True, mask_filled_value=float('-inf'), ada_token_nonstep=False, ada_token_detach_attn=True): 44 | super().__init__() 45 | self.count_flops = False 46 | self.ada_token_nonstep = ada_token_nonstep # if using nonstep, no mlp_head is needed in each of these layers 47 | if not ada_token_nonstep: 48 | self.mlp_head = nn.Linear(dim_in, 1, bias=bias) 49 | self.norm = nn.Identity() 50 | self.is_hard = is_hard 51 | self.tau = tau 52 | self.threshold = threshold 53 | self.add_noise = True 54 | self.pre_softmax = pre_softmax 55 | self.mask_filled_value = mask_filled_value 56 | self.ada_token_detach_attn = ada_token_detach_attn 57 | self.random_policy = False 58 | self.random_token = False 59 | self.random_token_ratio = 1. 60 | 61 | def set_tau(self, tau): 62 | self.tau = tau 63 | 64 | def forward(self, x, attn, attn_pre_softmax, token_select=None): 65 | b, l = x.shape[:2] 66 | 67 | if not self.ada_token_nonstep: 68 | # generate token policy step by step in each layer, including the first (couple of) blocks 69 | logits = self.mlp_head(self.norm(x[:,1:])) 70 | token_select = _gumbel_sigmoid(logits, self.tau, self.is_hard, threshold=self.threshold, training=self.training) 71 | if self.random_policy or self.random_token: 72 | token_select = get_random_policy(token_select, self.random_token_ratio) 73 | token_select = torch.cat([token_select.new_ones(b,1,1), token_select], dim=1) 74 | # token_select = token_select.unsqueeze(-1) #(b,l,1) 75 | token_select = token_select.transpose(1,2) #(b,1,l) 76 | else: 77 | if token_select is None: 78 | # when token_select is not given in non-step setting, 79 | # it means this layer is in the first (couple of) trans blocks before head/layer policy generation, 80 | # and thus we do not drop any token in this/these layers as well for consistency 81 | token_select = torch.ones((b, 1, l), device=x.device) 82 | else: 83 | token_select = token_select[:, None, :] 84 | 85 | if self.count_flops : 86 | return attn, token_select.squeeze(1) 87 | 88 | attn_policy = torch.bmm(token_select.transpose(-1,-2), token_select) #(b,l,l) 89 | attn_policy = attn_policy.unsqueeze(1) #(b,1,l,l) 90 | if self.ada_token_detach_attn : 91 | attn_policy = attn_policy.detach() 92 | 93 | # use pre_softmax during inference in both pre-softmax or pre-softmax training 94 | if self.pre_softmax or not self.training : 95 | eye_mat = attn.new_zeros((l,l)) 96 | eye_mat = eye_mat.fill_diagonal_(1) #(1,1,l,l) 97 | attn = attn_pre_softmax * attn_policy + attn_pre_softmax.new_zeros(attn_pre_softmax.shape).masked_fill_((1 - attn_policy - eye_mat)>0, self.mask_filled_value) 98 | 
attn = attn.softmax(-1) 99 | assert not torch.isnan(attn).any(), 'token select pre softmax nan !' 100 | else : 101 | attn = nn.functional.normalize(attn * attn_policy, 1, -1) 102 | 103 | return attn, token_select.squeeze(1) 104 | 105 | 106 | class BlockHeadSelect(nn.Module): 107 | def __init__(self, dim_in, num_heads, tau=5, is_hard=True, threshold=0.5, bias=True): 108 | super().__init__() 109 | self.mlp_head = nn.Linear(dim_in, num_heads, bias=bias) 110 | # self.norm = Lay rNorm(dim_in) 111 | self.is_hard = is_hard 112 | self.tau = tau 113 | self.threshold = threshold 114 | self.add_noise = True 115 | self.head_dim = dim_in // num_heads 116 | self.random_policy = False 117 | self.random_head = False 118 | self.random_head_ratio = 1. 119 | 120 | def set_tau(self, tau): 121 | self.tau = tau 122 | 123 | def forward(self, x): 124 | ''' 125 | ret : tensor(b, dim, 1) 126 | ''' 127 | bsize = x.shape[0] 128 | logits = self.mlp_head(x) 129 | sample = _gumbel_sigmoid(logits, self.tau, self.is_hard, threshold=self.threshold, training=self.training) 130 | if self.random_policy or self.random_head: 131 | sample = get_random_policy(sample, self.random_head_ratio) 132 | sample = sample.unsqueeze(-1) #(b,h,1) 133 | 134 | width_select = sample.expand(-1,-1,self.head_dim) 135 | width_select = width_select.reshape(bsize, -1, 1) 136 | 137 | return sample, width_select, logits 138 | 139 | 140 | class BlockLayerSelect(nn.Module): 141 | def __init__(self, dim_in, num_sub_layer, tau=5, is_hard=True, threshold=0.5, bias=True): 142 | super().__init__() 143 | self.mlp_head = nn.Linear(dim_in, num_sub_layer, bias=bias) 144 | # self.norm = LayerNorm(dim_in) 145 | self.is_hard = is_hard 146 | self.tau = tau 147 | self.threshold = threshold 148 | self.add_noise = True 149 | self.random_policy = False 150 | self.random_layer = False 151 | self.random_layer_ratio = 1. 152 | 153 | def set_tau(self, tau): 154 | self.tau = tau 155 | 156 | def forward(self, x): 157 | logits = self.mlp_head(x) 158 | sample = _gumbel_sigmoid(logits, self.tau, self.is_hard, threshold=self.threshold, training=self.training) 159 | if self.random_policy or self.random_layer: 160 | sample = get_random_policy(sample, self.random_layer_ratio) 161 | sample = sample #(b,2) 162 | 163 | return sample, logits 164 | 165 | 166 | class DynaLinear(nn.Linear): 167 | def __init__(self, in_features, out_features, num_heads=6, bias=True, dyna_dim=[True, True], dyna_data=False): 168 | super(DynaLinear, self).__init__( 169 | in_features, out_features, bias=bias) 170 | self.in_features_max = in_features 171 | self.out_features_max = out_features 172 | 173 | self.num_heads = num_heads 174 | self.width_mult = 1. 
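        # Editor's note (added for clarity): DynaLinear.forward() below has three paths.
        # (1) use_full_linear: behaves as a plain nn.Linear. (2) dyna_data=True: keeps the
        # full weight and instead masks the activations indicated by width_select.
        # (3) default: gates or slices the weight itself -- with width_select via a
        # per-sample grouped conv1d, or with width_specify by truncating the in/out
        # features to the requested width.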
175 | self.dyna_dim = dyna_dim 176 | self.in_features = in_features 177 | self.out_features = out_features 178 | self.use_full_linear = False 179 | self.dyna_data = dyna_data # if dyna_data is False, dyna weights 180 | self.count_flops = False 181 | 182 | def forward(self, input, width_select=None, width_specify=None): 183 | """ 184 | input : tensor (B,L,C) 185 | width_select : tensor(B,1,dims) or (B,dims,1) 186 | """ 187 | if self.use_full_linear : 188 | return super().forward(input) 189 | 190 | if self.count_flops : 191 | if width_select is not None : 192 | assert width_select.shape[0] == 1 193 | width_specify = int(width_select.sum().item()) 194 | width_select = None 195 | 196 | if self.dyna_data and width_select is not None : 197 | # only support input shape of (b,l,c) 198 | assert input.dim() == 3 199 | assert width_select.dim() == 3 200 | assert width_select.shape[1] == 1 or width_select.shape[2] == 1 201 | # if output is static, then input is dynamic 202 | if width_select.shape[1] == 1 : 203 | input_mask = width_select 204 | else : 205 | input_mask = 1 206 | if width_select.shape[2] == 1 : 207 | output_mask = width_select[...,0].unsqueeze(1) #(b,1,c) 208 | else : 209 | output_mask = 1 210 | input = input * input_mask 211 | result = super().forward(input) * output_mask 212 | return result 213 | 214 | if width_select is not None: 215 | weight = self.weight * width_select 216 | b, n, c = input.shape 217 | input = input.transpose(1,2).reshape(1,-1,n) 218 | weight = weight.view(-1,c,1) 219 | if self.bias is None : 220 | bias = self.bias 221 | elif width_select.shape[-1] == 1: 222 | bias = self.bias * width_select.squeeze() 223 | bias = bias.view(-1) 224 | else : 225 | bias = self.bias.unsqueeze(0).expand(b,-1) 226 | bias = bias.reshape(-1) 227 | result = nn.functional.conv1d(input, weight, bias, groups=b) 228 | result = result.view(b,-1,n).transpose(1,2) #(b,n,c) 229 | return result 230 | 231 | if width_specify is not None : 232 | if self.dyna_dim[0] : 233 | self.in_features = width_specify 234 | if self.dyna_dim[1] : 235 | self.out_features = width_specify 236 | weight = self.weight[:self.out_features, :self.in_features] 237 | if self.bias is not None: 238 | bias = self.bias[:self.out_features] 239 | else: 240 | bias = self.bias 241 | 242 | return nn.functional.linear(input, weight, bias) 243 | 244 | 245 | class AdaAttention(nn.Module): 246 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ada_token=False, ada_token_nonstep=False, ada_token_pre_softmax=True, ada_token_detach_attn=True, dyna_data=False, ada_token_threshold=0.6): 247 | super().__init__() 248 | self.count_flops = False 249 | self.t_ratio = 1 250 | 251 | self.num_heads = num_heads 252 | # head_dim = dim // num_heads 253 | self.head_dim = dim // num_heads 254 | self.scale = qk_scale or self.head_dim ** -0.5 255 | 256 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 257 | self.query = DynaLinear(dim, dim, bias=qkv_bias, dyna_dim=[False, True], dyna_data=dyna_data) 258 | self.key = DynaLinear(dim, dim, bias=qkv_bias, dyna_dim=[False, True], dyna_data=dyna_data) 259 | self.value = DynaLinear(dim, dim, bias=qkv_bias, dyna_dim=[False, True], dyna_data=dyna_data) 260 | 261 | self.attn_drop = nn.Dropout(attn_drop) 262 | self.proj = DynaLinear(dim, dim, dyna_dim=[True, True], dyna_data=dyna_data) 263 | self.proj_drop = nn.Dropout(proj_drop) 264 | 265 | if ada_token : 266 | self.token_select = SimpleTokenSelect(dim, pre_softmax=ada_token_pre_softmax, 
ada_token_detach_attn=ada_token_detach_attn, ada_token_nonstep=ada_token_nonstep, threshold=ada_token_threshold) 267 | else : 268 | self.token_select = None 269 | 270 | def forward_count_flops(self, x, width_select=None, head_select=None, token_select=None, only_head_attn=False): 271 | width_specify=None 272 | B, N, C = x.shape 273 | width_select_qk = width_select 274 | if only_head_attn : 275 | assert head_select is not None 276 | width_select = None 277 | 278 | if self.token_select is not None : 279 | token_active = int(x.shape[1] * self.t_ratio) 280 | else : 281 | token_active = x.shape[1] 282 | 283 | q = self.query(x[:,:token_active], width_select=width_select_qk, width_specify=width_specify).reshape(B, token_active, -1, C//self.num_heads).permute(0,2,1,3) 284 | k = self.key(x[:,:token_active], width_select=width_select_qk, width_specify=width_specify).reshape(B, token_active, -1, C//self.num_heads).permute(0,2,1,3) 285 | v = self.value(x[:,:token_active], width_select=width_select, width_specify=width_specify).reshape(B, token_active, -1, C//self.num_heads).permute(0,2,1,3) 286 | 287 | attn = (q @ k.transpose(-2, -1)) * self.scale 288 | attn_pre_softmax = attn 289 | attn = attn.softmax(dim=-1) 290 | 291 | attn_origin = attn 292 | 293 | if self.token_select is not None : 294 | attn, token_select = self.token_select(x, attn, attn_pre_softmax, token_select=token_select) 295 | 296 | attn = self.attn_drop(attn) 297 | if only_head_attn : 298 | v[:,:attn.shape[1]] = (attn @ v[:,:attn.shape[1]]) 299 | x = v.transpose(1, 2).reshape(B, token_active, -1) 300 | else : 301 | x = (attn @ v).transpose(1, 2).reshape(B, token_active, -1) 302 | if width_select is not None : 303 | width_select = width_select.transpose(-1,-2) 304 | x = self.proj(x, width_select, width_specify=width_specify) 305 | 306 | x[:, :token_active] = self.proj_drop(x[:, :token_active]) 307 | 308 | return x, attn_origin, token_select 309 | 310 | def forward(self, x, mask=None, value_mask_fill=-1e4, head_mask=None, width_select=None, head_select=None, token_select=None, width_specify=None, token_keep_ratio=None, only_head_attn=False, 311 | random_token_select_ratio=1.0): 312 | 313 | if self.count_flops : 314 | return self.forward_count_flops(x, width_select, head_select, token_select, only_head_attn) 315 | B, N, C = x.shape 316 | if only_head_attn : 317 | assert head_select is not None 318 | width_select = None 319 | 320 | q = self.query(x, width_select=width_select, width_specify=width_specify).reshape(B, N, -1, C//self.num_heads).permute(0,2,1,3) 321 | k = self.key(x, width_select=width_select, width_specify=width_specify).reshape(B, N, -1, C//self.num_heads).permute(0,2,1,3) 322 | v = self.value(x, width_select=width_select, width_specify=width_specify).reshape(B, N, -1, C//self.num_heads).permute(0,2,1,3) 323 | 324 | attn = (q @ k.transpose(-2, -1)) * self.scale 325 | if mask is not None : 326 | mask = mask.view(B, 1, N, 1).expand_as(attn) 327 | attn[~mask] = value_mask_fill 328 | attn_pre_softmax = attn 329 | attn = attn.softmax(dim=-1) 330 | if only_head_attn : 331 | head_select = head_select.view(*head_select.shape, *([1]*(4-head_select.dim()))) 332 | eye_mat = attn.new_zeros(attn.shape[-2:]) 333 | eye_mat.fill_diagonal_(1).view(1,1,*eye_mat.shape) #(1,1,l,l) 334 | attn = attn * head_select + eye_mat * (1 - head_select) 335 | 336 | attn_origin = attn 337 | if head_mask is not None : 338 | attn = attn * head_mask 339 | 340 | if self.token_select is not None : 341 | attn, token_select = self.token_select(x, attn, 
attn_pre_softmax, token_select=token_select) 342 | if only_head_attn and not self.training : 343 | head_select = head_select.view(*head_select.shape, *([1]*(4-head_select.dim()))) 344 | eye_mat = attn.new_zeros(attn.shape[-2:]) 345 | eye_mat.fill_diagonal_(1).view(1,1,*eye_mat.shape) #(1,1,l,l) 346 | attn = attn * head_select + eye_mat * (1 - head_select) 347 | 348 | if random_token_select_ratio != 1.0: 349 | # test random baseline with predefined token select ratio 350 | token_select = (torch.rand((B, N), device=x.device) < random_token_select_ratio).float() 351 | token_select[:, 0] = 1.0 # CLS token is always kept 352 | attn_policy = torch.bmm(token_select[:, None, :].transpose(-1,-2), token_select[:, None, :]) #(b,l,l) 353 | attn = nn.functional.normalize(attn * attn_policy.unsqueeze(1), 1, -1) 354 | 355 | attn = self.attn_drop(attn) 356 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 357 | if width_select is not None : 358 | width_select = width_select.transpose(-1,-2) 359 | x = self.proj(x, width_select, width_specify=width_specify) 360 | 361 | if token_select is not None : 362 | x = x * token_select.unsqueeze(-1) 363 | x = self.proj_drop(x) 364 | 365 | return x, attn_origin, token_select 366 | 367 | 368 | class AdaMlp(nn.Module): 369 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., dyna_data=False): 370 | super().__init__() 371 | out_features = out_features or in_features 372 | hidden_features = hidden_features or in_features 373 | self.act = act_layer() 374 | self.fc1 = DynaLinear(in_features, hidden_features, dyna_dim=[True, False], dyna_data=dyna_data) 375 | self.fc2 = DynaLinear(hidden_features, out_features, dyna_dim=[False, False], dyna_data=dyna_data) 376 | self.drop = nn.Dropout(drop) 377 | 378 | def forward(self, x, mask=None, width_select=None, width_specify=None): 379 | if mask is not None : 380 | assert mask.shape[:2] == x.shape[:2] 381 | if mask.dim() == 2: 382 | mask = mask.unsqueeze(-1) 383 | if mask.dtype != x.dtype : 384 | mask = mask.type_as(x) 385 | else : 386 | mask = x.new_ones(x.shape[:2]).unsqueeze(-1) 387 | x = self.fc1(x, width_select=width_select, width_specify=width_specify) 388 | x = x * mask 389 | x = self.act(x) 390 | x = self.drop(x) 391 | width_select = None 392 | x = self.fc2(x, width_select=width_select, width_specify=width_specify) 393 | x = x * mask 394 | x = self.drop(x) 395 | return x 396 | 397 | 398 | class StepAdaBlock(nn.Module): 399 | 400 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 401 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, 402 | ada_head=False, ada_layer=False, is_token_select=False, ada_token_pre_softmax=True, ada_token_detach_attn=True, 403 | ada_token_with_mlp=False, ada_token_detach_attn_at_mlp=True, 404 | dyna_data=False, ada_head_v2=False, only_head_attn=False, head_slowfast=False, 405 | norm_policy=False): 406 | super().__init__() 407 | self.count_flops = False 408 | self.h_ratio, self.t_ratio = 1., 1. 
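        # Editor's note: h_ratio / t_ratio / l_ratio are only consumed by
        # forward_count_flops(); block_flops_dict.get_flops_dict() sets them externally
        # via block.apply(lambda m: setattr(m, ...)) together with count_flops=True, so
        # that a FLOPs counter (fvcore's FlopCountAnalysis in this repo) can profile one
        # block at a given head/token/sub-layer configuration. They default to full usage.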
409 | self.l_ratio = [1, 1] 410 | self.is_token_select = is_token_select 411 | 412 | self.norm_policy = None 413 | if norm_policy and (ada_head or ada_layer): 414 | self.norm_policy = norm_layer(dim) 415 | self.norm1 = norm_layer(dim) 416 | self.attn = AdaAttention( 417 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ada_token=is_token_select, ada_token_pre_softmax=ada_token_pre_softmax, ada_token_detach_attn=ada_token_detach_attn, 418 | dyna_data=dyna_data) 419 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 420 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 421 | self.norm2 = norm_layer(dim) 422 | mlp_hidden_dim = int(dim * mlp_ratio) 423 | self.mlp_ratio = mlp_ratio 424 | self.ada_head_v2 = ada_head_v2 425 | self.mlp = AdaMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, dyna_data=dyna_data) 426 | self.ada_token_with_mlp = ada_token_with_mlp 427 | self.ada_token_detach_attn_at_mlp = ada_token_detach_attn_at_mlp 428 | self.only_head_attn = only_head_attn and ada_head 429 | self.head_slowfast = head_slowfast 430 | 431 | self.head_select = None 432 | self.layer_select = None 433 | if ada_head : 434 | self.head_select = BlockHeadSelect(dim, num_heads) 435 | if ada_layer : 436 | self.layer_select = BlockLayerSelect(dim, 2) 437 | 438 | def forward_count_flops(self, x) : 439 | if self.norm_policy is not None : 440 | policy_token = self.norm_policy(x)[:,0] 441 | else : 442 | policy_token = x[:,0] 443 | if self.layer_select is not None : 444 | sub_layer_select, layer_logits = self.layer_select(policy_token) 445 | else : 446 | sub_layer_select, layer_logits = None, None 447 | if self.head_select is not None : 448 | head_select, width_select, head_logits = self.head_select(policy_token) 449 | else : 450 | head_select, width_select, head_logits = None, None, None 451 | 452 | def mask_policy(policy, ratio) : 453 | if policy is not None : 454 | policy = torch.zeros_like(policy) 455 | policy[:, :int(policy.shape[1] * ratio)] = 1 456 | return policy 457 | h_ratio, l_ratio, t_ratio = [1, 1, 1] 458 | if head_select is not None : 459 | h_ratio = self.h_ratio 460 | head_select = mask_policy(head_select, h_ratio) 461 | if width_select is not None : 462 | width_select = mask_policy(width_select, h_ratio) 463 | if sub_layer_select is not None : 464 | l_ratio = self.l_ratio 465 | sub_layer_select[:,0] = l_ratio[0] 466 | sub_layer_select[:,1] = l_ratio[1] 467 | if self.is_token_select : 468 | t_ratio = self.t_ratio 469 | 470 | if width_select is not None : 471 | # TODO 472 | if width_select.sum() == 0 : 473 | return x 474 | width_select_attn = width_select #(b,c,1) 475 | if self.only_head_attn : 476 | assert head_select is not None 477 | width_select_mlp = None 478 | elif self.ada_head_v2 : 479 | bs = width_select.shape[0] 480 | width_select_mlp = width_select.expand(-1,-1,int(self.mlp_ratio)).reshape(bs,-1,1) 481 | else : 482 | width_select_mlp = width_select.transpose(-1,-2) 483 | else : 484 | width_select_attn, width_select_mlp = [None] * 2 485 | 486 | if sub_layer_select is None or sub_layer_select[0,0] : 487 | attn_x, attn_origin, token_select = self.attn(self.norm1(x), width_select=width_select_attn, 488 | head_select=head_select, only_head_attn=self.only_head_attn) 489 | x[:,:int(x.shape[1] * t_ratio),:attn_x.shape[-1]] = x[:,:int(x.shape[1] * t_ratio),:attn_x.shape[-1]] + attn_x 490 | elif sub_layer_select[0,0] == 0: 491 | attn_x = 0 492 | x = 
x + attn_x 493 | 494 | x = self.norm2(x) 495 | if self.only_head_attn : 496 | mlp_x = x 497 | else : 498 | mlp_x = x 499 | 500 | if sub_layer_select is not None and sub_layer_select[0,1] == 0 : 501 | x = x + 0 502 | else : 503 | if self.ada_token_with_mlp : 504 | token_active = int(x.shape[1] * t_ratio) 505 | else : 506 | token_active = x.shape[1] 507 | 508 | mlp_x = self.mlp(mlp_x[:,:token_active], width_select=width_select_mlp) 509 | x[:,:token_active] = x[:,:token_active] + mlp_x 510 | 511 | return x 512 | 513 | def forward(self, x, mask=None, head_mask=None, width_specify=None, # only_head_attn=False, head_slowfast=False, 514 | random_token_select_ratio=1.0): 515 | """ 516 | width_select : (b,c,1) 517 | """ 518 | if self.count_flops : 519 | return self.forward_count_flops(x) 520 | 521 | if self.norm_policy is not None : 522 | policy_token = self.norm_policy(x)[:,0] 523 | else : 524 | policy_token = x[:,0] 525 | if self.layer_select is not None : 526 | sub_layer_select, layer_logits = self.layer_select(policy_token) 527 | else : 528 | sub_layer_select, layer_logits = None, None 529 | if self.head_select is not None : 530 | head_select, width_select, head_logits = self.head_select(policy_token) 531 | else : 532 | head_select, width_select, head_logits = None, None, None 533 | 534 | # start 535 | if self.only_head_attn : 536 | assert head_select is not None 537 | width_select = None 538 | if width_select is not None : 539 | # TODO 540 | width_select_attn = width_select #(b,c,1) 541 | if self.ada_head_v2 : 542 | bs = width_select.shape[0] 543 | width_select_mlp = width_select.expand(-1,-1,int(self.mlp_ratio)).reshape(bs,-1,1) 544 | else : 545 | width_select_mlp = width_select.transpose(-1,-2) 546 | else : 547 | width_select_attn, width_select_mlp = [None] * 2 548 | 549 | attn_x, attn_origin, token_select = self.attn(self.norm1(x), mask=mask, head_mask=head_mask, width_select=width_select_attn, 550 | width_specify=width_specify, head_select=head_select, only_head_attn=self.only_head_attn, 551 | random_token_select_ratio=random_token_select_ratio) 552 | 553 | if sub_layer_select is None : 554 | x = x + self.drop_path(attn_x) 555 | mlp_x = self.mlp(self.norm2(x), width_select=width_select_mlp, width_specify=width_specify) 556 | if self.ada_token_with_mlp and token_select is not None : 557 | t_select = token_select.unsqueeze(-1) 558 | if self.ada_token_detach_attn_at_mlp: 559 | t_select = t_select.detach() 560 | mlp_x = t_select * mlp_x 561 | x = x + self.drop_path(mlp_x) 562 | else : 563 | x = x + sub_layer_select[:,0][:,None,None] * attn_x 564 | mlp_x = self.mlp(self.norm2(x), width_select=width_select_mlp, width_specify=width_specify) 565 | if self.ada_token_with_mlp and token_select is not None : 566 | t_select = token_select.unsqueeze(-1) 567 | if self.ada_token_detach_attn_at_mlp: 568 | t_select = t_select.detach() 569 | mlp_x = t_select * mlp_x 570 | x = x + sub_layer_select[:,1][:,None,None] * mlp_x 571 | return x, attn_origin, head_select, sub_layer_select, token_select, head_logits, layer_logits 572 | -------------------------------------------------------------------------------- /models/deit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
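# Editor's note: this file keeps the unmodified DeiT model definitions from the original
# facebookresearch/deit release. In this repository they appear to serve as the dense
# (non-adaptive) baselines and as teacher backbones for the distillation loss in
# models/losses.py, rather than being wrapped by the AdaViT policy blocks.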
3 | import torch 4 | import torch.nn as nn 5 | from functools import partial 6 | 7 | from timm.models.vision_transformer import VisionTransformer, _cfg 8 | from timm.models.registry import register_model 9 | from timm.models.layers import trunc_normal_ 10 | 11 | 12 | __all__ = [ 13 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224', 14 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224', 15 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384', 16 | 'deit_base_distilled_patch16_384', 17 | ] 18 | 19 | 20 | class DistilledVisionTransformer(VisionTransformer): 21 | def __init__(self, *args, **kwargs): 22 | super().__init__(*args, **kwargs) 23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 24 | num_patches = self.patch_embed.num_patches 25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() 27 | 28 | trunc_normal_(self.dist_token, std=.02) 29 | trunc_normal_(self.pos_embed, std=.02) 30 | self.head_dist.apply(self._init_weights) 31 | 32 | def forward_features(self, x): 33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 34 | # with slight modifications to add the dist_token 35 | B = x.shape[0] 36 | x = self.patch_embed(x) 37 | 38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 39 | dist_token = self.dist_token.expand(B, -1, -1) 40 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 41 | 42 | x = x + self.pos_embed 43 | x = self.pos_drop(x) 44 | 45 | for blk in self.blocks: 46 | x = blk(x) 47 | 48 | x = self.norm(x) 49 | return x[:, 0], x[:, 1] 50 | 51 | def forward(self, x): 52 | x, x_dist = self.forward_features(x) 53 | x = self.head(x) 54 | x_dist = self.head_dist(x_dist) 55 | if self.training: 56 | return x, x_dist 57 | else: 58 | # during inference, return the average of both classifier predictions 59 | return (x + x_dist) / 2 60 | 61 | 62 | @register_model 63 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 64 | model = VisionTransformer( 65 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 67 | model.default_cfg = _cfg() 68 | if pretrained: 69 | checkpoint = torch.hub.load_state_dict_from_url( 70 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 71 | map_location="cpu", check_hash=True 72 | ) 73 | model.load_state_dict(checkpoint["model"]) 74 | return model 75 | 76 | 77 | @register_model 78 | def deit_small_patch16_224(pretrained=False, **kwargs): 79 | model = VisionTransformer( 80 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 82 | model.default_cfg = _cfg() 83 | if pretrained: 84 | checkpoint = torch.hub.load_state_dict_from_url( 85 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 86 | map_location="cpu", check_hash=True 87 | ) 88 | model.load_state_dict(checkpoint["model"]) 89 | return model 90 | 91 | 92 | @register_model 93 | def deit_base_patch16_224(pretrained=False, **kwargs): 94 | model = VisionTransformer( 95 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 96 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 97 | model.default_cfg = _cfg() 98 | if pretrained: 99 | 
checkpoint = torch.hub.load_state_dict_from_url( 100 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 101 | map_location="cpu", check_hash=True 102 | ) 103 | model.load_state_dict(checkpoint["model"]) 104 | return model 105 | 106 | 107 | @register_model 108 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 109 | model = DistilledVisionTransformer( 110 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 111 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 112 | model.default_cfg = _cfg() 113 | if pretrained: 114 | checkpoint = torch.hub.load_state_dict_from_url( 115 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth", 116 | map_location="cpu", check_hash=True 117 | ) 118 | model.load_state_dict(checkpoint["model"]) 119 | return model 120 | 121 | 122 | @register_model 123 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 124 | model = DistilledVisionTransformer( 125 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 126 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 127 | model.default_cfg = _cfg() 128 | if pretrained: 129 | checkpoint = torch.hub.load_state_dict_from_url( 130 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth", 131 | map_location="cpu", check_hash=True 132 | ) 133 | model.load_state_dict(checkpoint["model"]) 134 | return model 135 | 136 | 137 | @register_model 138 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 139 | model = DistilledVisionTransformer( 140 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 141 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 142 | model.default_cfg = _cfg() 143 | if pretrained: 144 | checkpoint = torch.hub.load_state_dict_from_url( 145 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth", 146 | map_location="cpu", check_hash=True 147 | ) 148 | model.load_state_dict(checkpoint["model"]) 149 | return model 150 | 151 | 152 | @register_model 153 | def deit_base_patch16_384(pretrained=False, **kwargs): 154 | model = VisionTransformer( 155 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 156 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 157 | model.default_cfg = _cfg() 158 | if pretrained: 159 | checkpoint = torch.hub.load_state_dict_from_url( 160 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth", 161 | map_location="cpu", check_hash=True 162 | ) 163 | model.load_state_dict(checkpoint["model"]) 164 | return model 165 | 166 | 167 | @register_model 168 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 169 | model = DistilledVisionTransformer( 170 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 171 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 172 | model.default_cfg = _cfg() 173 | if pretrained: 174 | checkpoint = torch.hub.load_state_dict_from_url( 175 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", 176 | map_location="cpu", check_hash=True 177 | ) 178 | model.load_state_dict(checkpoint["model"]) 179 | return model -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | from numpy.lib.arraysetops import isin 
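# Editor's note (overview added for readability): this module combines a standard
# classification loss with usage losses on the learned policies. AdaLoss penalises the
# deviation of the mean head/layer/token keep-ratios from their target ratios (plus
# optional diversity, minimal-usage and entropy terms), and TeacherLoss adds attention,
# hidden-state and soft-prediction distillation from a dense teacher model.
# binaray_entropy(p) below computes the binary entropy H(p) = -p*log(p) - (1-p)*log(1-p).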
2 | from timm import loss 3 | from timm.data.transforms_factory import transforms_imagenet_train 4 | import torch 5 | from torch.functional import Tensor 6 | import torch.nn as nn 7 | 8 | def binaray_entropy(prob, eps=1e-7): 9 | neg_entro = prob * prob.clamp(min=eps).log() + (1-prob) * (1-prob).clamp(min=eps).log() 10 | return - neg_entro 11 | 12 | class BgLoss(nn.Module): 13 | def __init__(self, base_criterion): 14 | super().__init__() 15 | self.base_criterion = base_criterion 16 | 17 | def forward(self, outputs, y) : 18 | assert isinstance(outputs, tuple) and len(outputs) == 2, 'err {} {}'.format(type(outputs), len(outputs)) 19 | cls_pred, bg_pred = outputs 20 | if y.dim() == 2 : 21 | bsize, c = y.shape 22 | y_min = y.min(-1, keepdim=True)[0] 23 | y = torch.cat([y_min, y], dim=1) 24 | y_bg = torch.cat([y.new_ones(bsize,1), y.new_ones(bsize,c) * y_min], dim=1) 25 | y = nn.functional.normalize(y, 1, -1) 26 | y_bg = nn.functional.normalize(y_bg, 1, -1) 27 | else : 28 | y = y+1 29 | y_bg = y.new_zeros(y.shape) 30 | base_loss = self.base_criterion(cls_pred, y) 31 | bg_loss = self.base_criterion(bg_pred, y_bg) 32 | 33 | return (base_loss + bg_loss).mean() 34 | 35 | 36 | class AdaHeadLoss(nn.Module): 37 | def __init__(self, base_criterion, target_ratio=0.5, head_loss_ratio=2., diverse_ratio=0.1): 38 | super().__init__() 39 | self.base_criterion = base_criterion 40 | self.target_ratio = target_ratio 41 | self.head_loss_ratio = head_loss_ratio 42 | self.diverse_ratio = diverse_ratio 43 | 44 | def forward(self, outputs, y): 45 | ''' 46 | head_select: (b, num_layers, num_head) 47 | ''' 48 | assert len(outputs) >= 2 49 | x, head_select = outputs[:2] 50 | base_loss = self.base_criterion(x, y) 51 | head_mean = head_select.mean() 52 | flops_loss = (head_mean - self.target_ratio).abs().mean() 53 | 54 | head_mean = head_select.mean(0) # (num_layers, num_head) 55 | diverse_loss = (head_mean - self.target_ratio).abs().mean() 56 | 57 | head_loss = flops_loss + self.diverse_ratio * diverse_loss 58 | 59 | loss = base_loss + self.head_loss_ratio * head_loss 60 | 61 | return loss, dict(base_loss=base_loss, head_loss=head_loss) 62 | 63 | 64 | class AdaLoss(nn.Module): 65 | def __init__(self, base_criterion, head_target_ratio=0.5, layer_target_ratio=0.5, head_loss_ratio=2.,layer_loss_ratio=2., head_diverse_ratio=0.1, layer_diverse_ratio=0.1, 66 | head_entropy_weight=0.1, layer_entropy_weight=0.1, 67 | head_minimal_weight=0., head_minimal=0., 68 | layer_minimal_weight=0., layer_minimal=0., 69 | token_target_ratio=0.5, token_loss_ratio=2., token_minimal=0.1, token_minimal_weight=1.): 70 | super().__init__() 71 | self.base_criterion = base_criterion 72 | self.head_target_ratio = head_target_ratio 73 | self.layer_target_ratio = layer_target_ratio 74 | 75 | self.head_loss_ratio = head_loss_ratio 76 | self.layer_loss_ratio = layer_loss_ratio 77 | 78 | self.head_diverse_ratio = head_diverse_ratio 79 | self.layer_diverse_ratio = layer_diverse_ratio 80 | 81 | self.head_entropy_weight = head_entropy_weight 82 | self.layer_entropy_weight = layer_entropy_weight 83 | 84 | self.head_minimal_weight = head_minimal_weight 85 | self.head_minimal = head_minimal 86 | 87 | self.layer_minimal_weight = layer_minimal_weight 88 | self.layer_minimal = layer_minimal 89 | 90 | self.token_target_ratio = token_target_ratio 91 | self.token_loss_ratio = token_loss_ratio 92 | self.token_minimal = token_minimal 93 | self.token_minimal_weight = token_minimal_weight 94 | 95 | def forward(self, outputs, y): 96 | ''' 97 | head_select: (b, num_layers, 
num_head) 98 | ''' 99 | assert len(outputs) >= 3 100 | x, head_select, layer_select, token_select = outputs[:4] 101 | logits_set = outputs[-1] 102 | 103 | base_loss = self.base_criterion(x, y) 104 | layer_loss = self._get_layer_loss(x, layer_select, logits_set) 105 | head_loss = self._get_head_loss(x, head_select, logits_set, layer_select) 106 | token_loss = self._get_token_loss(x, token_select) 107 | 108 | loss = base_loss + self.head_loss_ratio * head_loss + self.layer_loss_ratio * layer_loss + self.token_loss_ratio * token_loss 109 | 110 | return loss, dict(base_loss=base_loss, head_loss=head_loss, layer_loss=layer_loss, token_loss=token_loss) 111 | 112 | def _get_head_loss(self, x, head_select, logits_set, layer_select): 113 | eps = 1e-6 114 | if head_select is not None : 115 | if layer_select is not None : 116 | block_select = layer_select.sum(-1, keepdim=True) 117 | head_select_mask = (block_select > 0).type_as(block_select) 118 | head_select_mask = head_select_mask.expand(-1,-1,head_select.shape[-1]) 119 | assert head_select.shape == head_select_mask.shape 120 | else : 121 | head_select_mask = head_select.new_ones(head_select.shape) 122 | head_mean = (head_select * head_select_mask).sum() / (head_select_mask.sum() + eps) 123 | head_flops_loss = (head_mean - self.head_target_ratio).abs().mean() 124 | 125 | if self.head_diverse_ratio > 0 : 126 | # head_mean = head_select.mean(0) # (num_layers, num_head) 127 | head_mean = (head_select * head_select_mask).sum(0) / (head_select_mask.sum(0) + eps) 128 | head_diverse_loss = (head_mean - self.head_target_ratio).abs().mean() 129 | else : 130 | head_diverse_loss = 0 131 | 132 | if self.head_minimal_weight > 0 : 133 | # head_per_layer = head_select.sum(-1) #(b, num_layers) 134 | # head_minimal_loss = (1 - head_per_layer).clamp(min=0.).sum(-1).mean() 135 | # head_mean = head_select.mean(0) # (num_layers, num_head) 136 | head_mean = (head_select * head_select_mask).sum(0) / (head_select_mask.sum(0) + eps) 137 | head_minimal_loss = (self.head_minimal - head_mean).clamp(min=0.).sum() 138 | else : 139 | head_minimal_loss = 0 140 | 141 | if self.head_entropy_weight > 0 : 142 | head_select_logits = logits_set['head_select_logits'] 143 | head_entropy = binaray_entropy(head_select_logits.sigmoid()).mean() 144 | else : 145 | head_entropy = 0 146 | 147 | head_loss = head_flops_loss + self.head_diverse_ratio * head_diverse_loss - self.head_entropy_weight * head_entropy \ 148 | + self.head_minimal_weight * head_minimal_loss 149 | else : 150 | head_loss = x.new_zeros(1).mean() 151 | 152 | return head_loss 153 | 154 | def _get_layer_loss(self, x, layer_select, logits_set): 155 | if layer_select is not None : 156 | layer_mean = layer_select.mean() 157 | layer_flops_loss = (layer_mean - self.layer_target_ratio).abs().mean() 158 | 159 | if self.layer_diverse_ratio > 0 : 160 | layer_mean = layer_select.mean((0,-1)) 161 | layer_diverse_loss = (layer_mean - self.layer_target_ratio).abs().mean() 162 | else : 163 | layer_diverse_loss = 0 164 | 165 | if self.layer_entropy_weight > 0 : 166 | layer_select_logits = logits_set['layer_select_logits'] 167 | layer_entropy = binaray_entropy(layer_select_logits.sigmoid()).mean() 168 | else : 169 | layer_entropy = 0 170 | 171 | if self.layer_minimal_weight > 0 : 172 | layer_mean = layer_select.mean(0) #(num_layers, 2) 173 | layer_minimal_loss = (self.layer_minimal - layer_mean).clamp(min=0.).sum() 174 | else : 175 | layer_minimal_loss = 0 176 | 177 | layer_loss = layer_flops_loss + self.layer_diverse_ratio * layer_diverse_loss 
- self.layer_entropy_weight * layer_entropy \ 178 | + self.layer_minimal_weight * layer_minimal_loss 179 | else : 180 | layer_loss = x.new_zeros(1).mean() 181 | 182 | return layer_loss 183 | 184 | def _get_token_loss(self, x, token_select): 185 | """ 186 | token_select : tensor (b, num_layer, l) 187 | 188 | """ 189 | if token_select is not None : 190 | token_mean = token_select.mean() 191 | # token_flops_loss = (token_mean - self.token_target_ratio).abs().mean() 192 | # token_flops_loss = (token_mean - self.token_target_ratio).clamp(min=0.).mean() 193 | token_flops_loss = ((token_mean - self.token_target_ratio)**2).mean() 194 | 195 | if self.token_minimal_weight > 0 : 196 | token_mean = token_select.mean(-1) 197 | token_minimal_loss = (self.token_minimal - token_mean).clamp(min=0.).sum() 198 | else : 199 | token_minimal_loss = 0 200 | 201 | token_loss = token_flops_loss + self.token_minimal_weight * token_minimal_loss 202 | else : 203 | token_loss = x.new_zeros(1).mean() 204 | 205 | return token_loss 206 | 207 | def soft_cross_entropy(predicts, targets): 208 | student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1) 209 | targets_prob = torch.nn.functional.softmax(targets, dim=-1) 210 | return (- targets_prob * student_likelihood).sum(-1).mean() 211 | 212 | # TODO : hard or soft distill loss 213 | class TeacherLoss(nn.Module): 214 | def __init__(self, teacher_model, base_criterion, kd_ratio=1., tau=5., attn_ratio=.5, hidden_ratio=.1, pred_ratio=.5, keep_layers=1): 215 | super().__init__() 216 | self.kd_ratio = kd_ratio 217 | print('self tau', tau) 218 | self.tau = tau 219 | self.base_criterion = base_criterion 220 | self.teacher_model = teacher_model 221 | self.mse_loss = nn.MSELoss(reduction='none') 222 | self.l1_loss = nn.L1Loss(reduction='none') 223 | self.attn_ratio = attn_ratio 224 | self.hidden_ratio = hidden_ratio 225 | # self.layer_ratio = layer_ratio 226 | self.pred_ratio = pred_ratio 227 | self.teacher_model.eval() 228 | self.keep_layers = keep_layers 229 | 230 | def forward(self, x, outputs, y): 231 | assert len(outputs) >= 5 232 | logits, head_select, layer_select, token_select, attn_list, hidden_list = outputs[:6] 233 | base_loss, meta_loss = self.base_criterion(outputs, y) 234 | 235 | with torch.no_grad(): 236 | logits_teacher, attn_list_teacher, hidden_list_teacher = self.teacher_model(x, ret_attn_list=True) 237 | 238 | attn_loss = x.new_zeros(1).mean() 239 | hidden_loss = x.new_zeros(1).mean() 240 | if head_select is not None : 241 | head_select_start = len(attn_list) - head_select.shape[1] 242 | for i, (attn_s, attn_t, hidden_s, hidden_t) in enumerate(zip(attn_list, attn_list_teacher, hidden_list, hidden_list_teacher)) : 243 | assert attn_s.dim() == 4 and attn_t.dim() ==4 244 | if i >= self.keep_layers and layer_select is not None : 245 | this_select = layer_select[:,i-self.keep_layers].detach() #(b,2) 246 | this_select = (this_select > 0.5).float().unsqueeze(-1) 247 | attn_select = this_select[:,0] 248 | hidden_select = this_select[:,1] 249 | else : 250 | this_select = 1. 251 | attn_select = 1. 252 | hidden_select = 1. 
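            # The learned usage policies gate the distillation below: head_select zeroes the
            # attention maps of deactivated heads in both student and teacher before the L1
            # attention loss, while attn_select / hidden_select zero out, per sample, the losses
            # of blocks whose attention / MLP part was skipped by the layer policy.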
253 | 254 | if head_select is not None and i >= head_select_start : 255 | attn_mask = head_select.detach()[:,i-head_select_start] #(b,head,1) 256 | if attn_mask.dim() == 2 : 257 | attn_mask = attn_mask.unsqueeze(-1) 258 | else : 259 | attn_mask = 1 260 | 261 | attn_s = attn_s[:,:,0] * attn_mask #(b,head,len) 262 | # TODO : normalize teacher attn 263 | attn_t = attn_t[:,:,0] * attn_mask 264 | hidden_s = hidden_s[:,0] #(b, c) 265 | hidden_t = hidden_t[:,0] 266 | # attn_s = torch.where(attn_s <=1e-2, attn_s.new_zeros(attn_s.shape), attn_s) 267 | # attn_t = torch.where(attn_t <=1e-2, attn_t.new_zeros(attn_t.shape), attn_t) 268 | # TODO : how to design attn loss 269 | if self.attn_ratio > 0 : 270 | this_attn_loss = self.l1_loss(attn_s, attn_t).sum(-1) 271 | this_attn_loss = (attn_select * this_attn_loss).mean() 272 | attn_loss +=this_attn_loss 273 | if self.hidden_ratio > 0 : 274 | this_hidden_loss = self.mse_loss(hidden_s, hidden_t) 275 | this_hidden_loss = (hidden_select * this_hidden_loss).mean() 276 | hidden_loss += this_hidden_loss 277 | 278 | if self.pred_ratio > 0 : 279 | T = self.tau 280 | pred_loss = soft_cross_entropy(logits/T, logits_teacher/T) 281 | else : 282 | pred_loss = 0 283 | 284 | loss = base_loss + self.attn_ratio*attn_loss+ self.hidden_ratio*hidden_loss + self.pred_ratio*pred_loss 285 | meta_loss.update(attn_loss=attn_loss, hidden_loss=hidden_loss, pred_loss=pred_loss) 286 | return loss, meta_loss -------------------------------------------------------------------------------- /models/token_performer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrow from t2t-vit(https://github.com/yitu-opensource/T2T-ViT 3 | Take Performer as T2T Transformer 4 | """ 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | 9 | class Token_performer(nn.Module): 10 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1): 11 | super().__init__() 12 | self.emb = in_dim * head_cnt # we use 1, so it is no need here 13 | self.kqv = nn.Linear(dim, 3 * self.emb) 14 | self.dp = nn.Dropout(dp1) 15 | self.proj = nn.Linear(self.emb, self.emb) 16 | self.head_cnt = head_cnt 17 | self.norm1 = nn.LayerNorm(dim) 18 | self.norm2 = nn.LayerNorm(self.emb) 19 | self.epsilon = 1e-8 # for stable in division 20 | 21 | self.mlp = nn.Sequential( 22 | nn.Linear(self.emb, 1 * self.emb), 23 | nn.GELU(), 24 | nn.Linear(1 * self.emb, self.emb), 25 | nn.Dropout(dp2), 26 | ) 27 | 28 | self.m = int(self.emb * kernel_ratio) 29 | self.w = torch.randn(self.m, self.emb) 30 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) 31 | 32 | def prm_exp(self, x): 33 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 34 | # and Simo Ryu (https://github.com/cloneofsimo) 35 | # ==== positive random features for gaussian kernels ==== 36 | # x = (B, T, hs) 37 | # w = (m, hs) 38 | # return : x : B, T, m 39 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 40 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 41 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 42 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w) 43 | 44 | return torch.exp(wtx - xd) / math.sqrt(self.m) 45 | 46 | def single_attn(self, x): 47 | k, q, v = torch.split(self.kqv(x), self.emb, dim=-1) 48 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m) 49 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) 50 | kptv = 
torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m) 51 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 52 | # skip connection 53 | y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection 54 | 55 | return y 56 | 57 | def forward(self, x): 58 | x = self.single_attn(self.norm1(x)) 59 | x = x + self.mlp(self.norm2(x)) 60 | return x 61 | 62 | -------------------------------------------------------------------------------- /models/token_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrow from t2t-vit(https://github.com/yitu-opensource/T2T-ViT) 3 | Take the standard Transformer as T2T Transformer 4 | """ 5 | import torch.nn as nn 6 | from timm.models.layers import DropPath 7 | from .transformer_block import Mlp 8 | 9 | class Attention(nn.Module): 10 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 11 | super().__init__() 12 | self.num_heads = num_heads 13 | self.in_dim = in_dim 14 | head_dim = dim // num_heads 15 | self.scale = qk_scale or head_dim ** -0.5 16 | 17 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias) 18 | self.attn_drop = nn.Dropout(attn_drop) 19 | self.proj = nn.Linear(in_dim, in_dim) 20 | self.proj_drop = nn.Dropout(proj_drop) 21 | 22 | def forward(self, x): 23 | B, N, C = x.shape 24 | 25 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4) 26 | q, k, v = qkv[0], qkv[1], qkv[2] 27 | 28 | attn = (q * self.scale) @ k.transpose(-2, -1) 29 | attn = attn.softmax(dim=-1) 30 | attn = self.attn_drop(attn) 31 | 32 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim) 33 | x = self.proj(x) 34 | x = self.proj_drop(x) 35 | 36 | # skip connection 37 | x = v.squeeze(1) + x # because the original x has different size with current x, use v to do skip connection 38 | 39 | return x 40 | 41 | class Token_transformer(nn.Module): 42 | 43 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 44 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 45 | super().__init__() 46 | self.norm1 = norm_layer(dim) 47 | self.attn = Attention( 48 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 49 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 50 | self.norm2 = norm_layer(in_dim) 51 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop) 52 | 53 | def forward(self, x): 54 | x = self.attn(self.norm1(x)) 55 | x = x + self.drop_path(self.mlp(self.norm2(x))) 56 | return x 57 | -------------------------------------------------------------------------------- /models/transformer_block.py: -------------------------------------------------------------------------------- 1 | """ 2 | Borrow from timm(https://github.com/rwightman/pytorch-image-models) 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from timm.models.layers import DropPath 8 | 9 | class Mlp(nn.Module): 10 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 11 | super().__init__() 12 | out_features = out_features or in_features 13 | hidden_features = hidden_features or in_features 14 | self.fc1 = nn.Linear(in_features, hidden_features) 15 | self.act = act_layer() 16 | self.fc2 = nn.Linear(hidden_features, out_features) 17 | self.drop = nn.Dropout(drop) 18 | 19 | def forward(self, x): 20 | x = self.fc1(x) 21 | x = self.act(x) 22 | x = self.drop(x) 23 | x = self.fc2(x) 24 | x = self.drop(x) 25 | return x 26 | 27 | class Attention(nn.Module): 28 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 29 | super().__init__() 30 | self.num_heads = num_heads 31 | head_dim = dim // num_heads 32 | 33 | self.scale = qk_scale or head_dim ** -0.5 34 | 35 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 36 | self.attn_drop = nn.Dropout(attn_drop) 37 | self.proj = nn.Linear(dim, dim) 38 | self.proj_drop = nn.Dropout(proj_drop) 39 | 40 | def forward(self, x, ret_attn=False): 41 | B, N, C = x.shape 42 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 43 | q, k, v = qkv[0], qkv[1], qkv[2] 44 | 45 | attn = (q @ k.transpose(-2, -1)) * self.scale 46 | attn = attn.softmax(dim=-1) 47 | attn = self.attn_drop(attn) 48 | 49 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 50 | x = self.proj(x) 51 | x = self.proj_drop(x) 52 | if ret_attn : 53 | return x, attn 54 | return x 55 | 56 | class Block(nn.Module): 57 | 58 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 59 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 60 | super().__init__() 61 | if norm_layer == nn.GroupNorm : 62 | self.norm1 = norm_layer(num_heads, dim) 63 | self.norm2 = norm_layer(num_heads, dim) 64 | else: 65 | self.norm1 = norm_layer(dim) 66 | self.norm2 = norm_layer(dim) 67 | self.attn = Attention( 68 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 69 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 70 | mlp_hidden_dim = int(dim * mlp_ratio) 71 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 72 | 73 | def forward(self, x, ret_attn=False): 74 | x_attn, attn = self.attn(self.norm1(x), ret_attn=ret_attn) 75 | x = x + self.drop_path(x_attn) 76 | x = x + self.drop_path(self.mlp(self.norm2(x))) 77 | if ret_attn : 78 | return x, attn 79 | return x 80 | 81 | 82 | def get_sinusoid_encoding(n_position, d_hid): 83 | ''' Sinusoid position encoding table ''' 84 | 85 | def get_position_angle_vec(position): 86 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 87 | 88 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 89 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 90 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 91 | 92 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 93 | -------------------------------------------------------------------------------- /saver.py: -------------------------------------------------------------------------------- 1 | """ Checkpoint Saver 2 | 3 | Track top-n training checkpoints and maintain recovery checkpoints on specified intervals. 4 | 5 | Hacked together by / Copyright 2020 Ross Wightman 6 | """ 7 | 8 | import glob 9 | import operator 10 | import os 11 | import logging 12 | 13 | import torch 14 | 15 | from timm.utils.model import unwrap_model, get_state_dict 16 | from shutil import copyfile 17 | 18 | _logger = logging.getLogger(__name__) 19 | 20 | 21 | class MyCheckpointSaver: 22 | def __init__( 23 | self, 24 | model, 25 | optimizer, 26 | args=None, 27 | model_ema=None, 28 | amp_scaler=None, 29 | checkpoint_prefix='checkpoint', 30 | recovery_prefix='recovery', 31 | checkpoint_dir='', 32 | recovery_dir='', 33 | decreasing=False, 34 | max_history=10, 35 | unwrap_fn=unwrap_model): 36 | 37 | # objects to save state_dicts of 38 | self.model = model 39 | self.optimizer = optimizer 40 | self.args = args 41 | self.model_ema = model_ema 42 | self.amp_scaler = amp_scaler 43 | 44 | # state 45 | self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness 46 | self.best_epoch = None 47 | self.best_metric = None 48 | self.curr_recovery_file = '' 49 | self.last_recovery_file = '' 50 | 51 | # config 52 | self.checkpoint_dir = checkpoint_dir 53 | self.recovery_dir = recovery_dir 54 | self.save_prefix = checkpoint_prefix 55 | self.recovery_prefix = recovery_prefix 56 | self.extension = '.pth.tar' 57 | self.decreasing = decreasing # a lower metric is better if True 58 | self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs 59 | self.max_history = max_history 60 | self.unwrap_fn = unwrap_fn 61 | assert self.max_history >= 1 62 | 63 | def save_checkpoint(self, epoch, metric=None): 64 | assert epoch >= 0 65 | tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension) 66 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) 67 | self._save(tmp_save_path, epoch, metric) 68 | if os.path.exists(last_save_path): 69 | os.remove(last_save_path) # required for Windows support. 
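        # Write-then-rename: _save() above serialized the full state to 'tmp.pth.tar'; renaming it
        # onto 'last.pth.tar' below means an interrupted save never leaves a truncated 'last'
        # checkpoint, and the ranked top-n checkpoints are then created as copies of this file.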
70 | os.rename(tmp_save_path, last_save_path) 71 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None 72 | if (len(self.checkpoint_files) < self.max_history 73 | or metric is None or self.cmp(metric, worst_file[1])): 74 | if len(self.checkpoint_files) >= self.max_history: 75 | self._cleanup_checkpoints(1) 76 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension 77 | save_path = os.path.join(self.checkpoint_dir, filename) 78 | copyfile(last_save_path, save_path) 79 | self.checkpoint_files.append((save_path, metric)) 80 | self.checkpoint_files = sorted( 81 | self.checkpoint_files, key=lambda x: x[1], 82 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better 83 | 84 | checkpoints_str = "Current checkpoints:\n" 85 | for c in self.checkpoint_files: 86 | checkpoints_str += ' {}\n'.format(c) 87 | _logger.info(checkpoints_str) 88 | 89 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): 90 | self.best_epoch = epoch 91 | self.best_metric = metric 92 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension) 93 | if os.path.exists(best_save_path): 94 | os.remove(best_save_path) 95 | copyfile(last_save_path, best_save_path) 96 | 97 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) 98 | 99 | def _save(self, save_path, epoch, metric=None): 100 | save_state = { 101 | 'epoch': epoch, 102 | 'arch': type(self.model).__name__.lower(), 103 | 'state_dict': get_state_dict(self.model, self.unwrap_fn), 104 | 'optimizer': self.optimizer.state_dict(), 105 | 'version': 2, # version < 2 increments epoch before save 106 | } 107 | if self.args is not None: 108 | save_state['arch'] = self.args.model 109 | save_state['args'] = self.args 110 | if self.amp_scaler is not None: 111 | save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict() 112 | if self.model_ema is not None: 113 | save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn) 114 | if metric is not None: 115 | save_state['metric'] = metric 116 | torch.save(save_state, save_path) 117 | 118 | def _cleanup_checkpoints(self, trim=0): 119 | trim = min(len(self.checkpoint_files), trim) 120 | delete_index = self.max_history - trim 121 | if delete_index <= 0 or len(self.checkpoint_files) <= delete_index: 122 | return 123 | to_delete = self.checkpoint_files[delete_index:] 124 | for d in to_delete: 125 | try: 126 | _logger.debug("Cleaning checkpoint: {}".format(d)) 127 | os.remove(d[0]) 128 | except Exception as e: 129 | _logger.error("Exception '{}' while deleting checkpoint".format(e)) 130 | self.checkpoint_files = self.checkpoint_files[:delete_index] 131 | 132 | def save_recovery(self, epoch, batch_idx=0): 133 | assert epoch >= 0 134 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension 135 | save_path = os.path.join(self.recovery_dir, filename) 136 | self._save(save_path, epoch) 137 | if os.path.exists(self.last_recovery_file): 138 | try: 139 | _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file)) 140 | os.remove(self.last_recovery_file) 141 | except Exception as e: 142 | _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file)) 143 | self.last_recovery_file = self.curr_recovery_file 144 | self.curr_recovery_file = save_path 145 | 146 | def find_recovery(self): 147 | recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) 148 | files = glob.glob(recovery_path 
+ '*' + self.extension) 149 | files = sorted(files) 150 | if len(files): 151 | return files[0] 152 | else: 153 | return '' 154 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Borrow from t2t-vit(https://github.com/yitu-opensource/T2T-ViT 3 | - load_for_transfer_learning: load pretrained paramters to model in transfer learning 4 | - get_mean_and_std: calculate the mean and std value of dataset. 5 | - msr_init: net parameter initialization. 6 | - progress_bar: progress bar mimic xlua.progress. 7 | ''' 8 | import os 9 | import sys 10 | import time 11 | import torch 12 | 13 | import torch.nn as nn 14 | import torch.nn.init as init 15 | import logging 16 | import os 17 | from collections import OrderedDict 18 | 19 | _logger = logging.getLogger(__name__) 20 | 21 | def convert_qkv(origin_dict): 22 | model_dict = OrderedDict() 23 | for k,v in origin_dict.items(): 24 | if 'qkv' in k : 25 | #bias 26 | if v.dim() == 1 : 27 | dim = v.shape[-1] // 3 28 | tmp_bias = v 29 | b_q, b_k, b_v = [None] *3 if tmp_bias is None else [tmp_bias[i*dim:(i+1)*dim] for i in range(3)] 30 | print('bias q k v', b_q.shape, b_k.shape, b_v.shape) 31 | model_dict[k.replace('qkv', 'query')] = b_q 32 | model_dict[k.replace('qkv', 'key')] = b_k 33 | model_dict[k.replace('qkv', 'value')] = b_v 34 | else : 35 | dim = v.shape[-1] 36 | tmp_weight = v 37 | w_q, w_k, w_v = [tmp_weight[i*dim:(i+1)*dim] for i in range(3)] 38 | model_dict[k.replace('qkv', 'query')] = w_q 39 | model_dict[k.replace('qkv', 'key')] = w_k 40 | model_dict[k.replace('qkv', 'value')] = w_v 41 | else : 42 | model_dict[k] = v 43 | 44 | return model_dict 45 | 46 | def ada_load_state_dict(checkpoint_path, model, use_qkv=False, strict=True): 47 | if not os.path.isfile(checkpoint_path) : 48 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 49 | raise FileNotFoundError() 50 | 51 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 52 | if 'state_dict_ema' in checkpoint : 53 | checkpoint = checkpoint['state_dict_ema'] 54 | elif 'state_dict' in checkpoint : 55 | checkpoint = checkpoint['state_dict'] 56 | elif 'model' in checkpoint : 57 | checkpoint = checkpoint['model'] # load deit type model 58 | 59 | if not use_qkv : 60 | checkpoint = convert_qkv(checkpoint) 61 | 62 | new_state_dict = OrderedDict() 63 | for k, v in checkpoint.items(): 64 | # strip `module.` prefix 65 | name = k[7:] if k.startswith('module') else k 66 | new_state_dict[name] = v 67 | checkpoint = new_state_dict 68 | 69 | info = model.load_state_dict(checkpoint, strict=strict) 70 | if not strict : 71 | print('state dict load info', info) 72 | 73 | 74 | def deit_load_state_dict(checkpoint_path, model, strict=True): 75 | if not os.path.isfile(checkpoint_path) : 76 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path)) 77 | raise FileNotFoundError() 78 | 79 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 80 | checkpoint = checkpoint['model'] 81 | 82 | new_state_dict = OrderedDict() 83 | for k, v in checkpoint.items(): 84 | # strip `module.` prefix 85 | name = k[7:] if k.startswith('module') else k 86 | new_state_dict[name] = v 87 | checkpoint = new_state_dict 88 | 89 | info = model.load_state_dict(checkpoint, strict=strict) 90 | if not strict : 91 | print('state dict load info', info) 92 | 93 | 94 | def resize_pos_embed(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1) 95 | # Rescale the grid of 
position embeddings when loading from state_dict. Adapted from
96 |     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
97 |     import math                          # local imports: math and torch.nn.functional are
98 |     import torch.nn.functional as F      # not imported at the top of this file
99 |     ntok_new = posemb_new.shape[1]
100 |     posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]  # posemb_tok is the cls token, posemb_grid the patch tokens
101 |     ntok_new -= 1
102 |     gs_old = int(math.sqrt(len(posemb_grid)))  # e.g. 14
103 |     gs_new = int(math.sqrt(ntok_new))  # e.g. 24
104 |     _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
105 |     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)  # [1, 196, dim] -> [1, 14, 14, dim] -> [1, dim, 14, 14]
106 |     posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic')  # [1, dim, 14, 14] -> [1, dim, 24, 24]
107 |     posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)  # [1, dim, 24, 24] -> [1, 24*24, dim]
108 |     posemb = torch.cat([posemb_tok, posemb_grid], dim=1)  # [1, 24*24+1, dim]
109 |     return posemb
110 | 
111 | 
112 | def load_state_dict(checkpoint_path, model, use_ema=False, num_classes=1000, del_posemb=False):
113 |     if checkpoint_path and os.path.isfile(checkpoint_path):
114 |         checkpoint = torch.load(checkpoint_path, map_location='cpu')
115 |         state_dict_key = 'state_dict'
116 |         if isinstance(checkpoint, dict):
117 |             if use_ema and 'state_dict_ema' in checkpoint:
118 |                 state_dict_key = 'state_dict_ema'
119 |         if state_dict_key and state_dict_key in checkpoint:
120 |             new_state_dict = OrderedDict()
121 |             for k, v in checkpoint[state_dict_key].items():
122 |                 # strip `module.` prefix
123 |                 name = k[7:] if k.startswith('module') else k
124 |                 new_state_dict[name] = v
125 |             state_dict = new_state_dict
126 |         else:
127 |             state_dict = checkpoint
128 |         _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
129 |         if num_classes != 1000:
130 |             # discard the fully connected head when the target number of classes differs from the pretrained model
131 |             del state_dict['head' + '.weight']
132 |             del state_dict['head' + '.bias']
133 | 
134 |         if del_posemb:
135 |             del state_dict['pos_embed']
136 |         elif model.pos_embed.shape != state_dict['pos_embed'].shape:
137 |             # grid sizes differ: resize the position embedding by interpolation
138 |             old_posemb = state_dict['pos_embed']
139 |             new_posemb = resize_pos_embed(old_posemb, model.pos_embed)
140 |             state_dict['pos_embed'] = new_posemb
141 | 
142 |         return state_dict
143 |     else:
144 |         _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
145 |         raise FileNotFoundError()
146 | 
147 | 
148 | 
149 | def load_for_transfer_learning(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
150 |     state_dict = load_state_dict(checkpoint_path, model, use_ema=use_ema, num_classes=num_classes)
151 |     model.load_state_dict(state_dict, strict=strict)
152 | 
153 | 
154 | def get_mean_and_std(dataset):
155 |     '''Compute the per-channel mean and std of a dataset.'''
156 |     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
157 |     mean = torch.zeros(3)
158 |     std = torch.zeros(3)
159 |     print('==> Computing mean and std..')
160 |     for inputs, targets in dataloader:
161 |         for i in range(3):
162 |             mean[i] += inputs[:,i,:,:].mean()
163 |             std[i] += inputs[:,i,:,:].std()
164 |     mean.div_(len(dataset))
165 |     std.div_(len(dataset))
166 |     return mean, std
167 | 
168 | def init_params(net):
169 |     '''Init layer parameters.'''
170 |     for m in net.modules():
171 |         if
isinstance(m, nn.Conv2d): 172 | init.kaiming_normal(m.weight, mode='fan_out') 173 | if m.bias: 174 | init.constant(m.bias, 0) 175 | elif isinstance(m, nn.BatchNorm2d): 176 | init.constant(m.weight, 1) 177 | init.constant(m.bias, 0) 178 | elif isinstance(m, nn.Linear): 179 | init.normal(m.weight, std=1e-3) 180 | if m.bias: 181 | init.constant(m.bias, 0) 182 | 183 | 184 | # _, term_width = os.popen('stty size', 'r').read().split() 185 | # term_width = int(term_width) 186 | term_width = 100 # since we're not using transfer_learning.py or progress bar, just set a constant to make no-terminal job submission happy 187 | 188 | TOTAL_BAR_LENGTH = 65. 189 | last_time = time.time() 190 | begin_time = last_time 191 | def progress_bar(current, total, msg=None): 192 | global last_time, begin_time 193 | if current == 0: 194 | begin_time = time.time() # Reset for new bar. 195 | 196 | cur_len = int(TOTAL_BAR_LENGTH*current/total) 197 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1 198 | 199 | sys.stdout.write(' [') 200 | for i in range(cur_len): 201 | sys.stdout.write('=') 202 | sys.stdout.write('>') 203 | for i in range(rest_len): 204 | sys.stdout.write('.') 205 | sys.stdout.write(']') 206 | 207 | cur_time = time.time() 208 | step_time = cur_time - last_time 209 | last_time = cur_time 210 | tot_time = cur_time - begin_time 211 | 212 | L = [] 213 | L.append(' Step: %s' % format_time(step_time)) 214 | L.append(' | Tot: %s' % format_time(tot_time)) 215 | if msg: 216 | L.append(' | ' + msg) 217 | 218 | msg = ''.join(L) 219 | sys.stdout.write(msg) 220 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3): 221 | sys.stdout.write(' ') 222 | 223 | # Go back to the center of the bar. 224 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2): 225 | sys.stdout.write('\b') 226 | sys.stdout.write(' %d/%d ' % (current+1, total)) 227 | 228 | if current < total-1: 229 | sys.stdout.write('\r') 230 | else: 231 | sys.stdout.write('\n') 232 | sys.stdout.flush() 233 | 234 | def format_time(seconds): 235 | days = int(seconds / 3600/24) 236 | seconds = seconds - days*3600*24 237 | hours = int(seconds / 3600) 238 | seconds = seconds - hours*3600 239 | minutes = int(seconds / 60) 240 | seconds = seconds - minutes*60 241 | secondsf = int(seconds) 242 | seconds = seconds - secondsf 243 | millis = int(seconds*1000) 244 | 245 | f = '' 246 | i = 1 247 | if days > 0: 248 | f += str(days) + 'D' 249 | i += 1 250 | if hours > 0 and i <= 2: 251 | f += str(hours) + 'h' 252 | i += 1 253 | if minutes > 0 and i <= 2: 254 | f += str(minutes) + 'm' 255 | i += 1 256 | if secondsf > 0 and i <= 2: 257 | f += str(secondsf) + 's' 258 | i += 1 259 | if millis > 0 and i <= 2: 260 | f += str(millis) + 'ms' 261 | i += 1 262 | if f == '': 263 | f = '0ms' 264 | return f 265 | --------------------------------------------------------------------------------