├── .gitignore
├── README.md
├── ada_main.py
├── adavit_ckpt
│   ├── deit-s-h-l-tmlp_flops_dict.pth
│   └── t2t-19-h-l-tmlp_flops_dict.pth
├── assets
│   └── adavit_approach.png
├── block_flops_dict.py
├── models
│   ├── __init__.py
│   ├── ada_t2t_vit.py
│   ├── ada_transformer_block.py
│   ├── deit.py
│   ├── losses.py
│   ├── token_performer.py
│   ├── token_transformer.py
│   └── transformer_block.py
├── saver.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | test_case.py
2 | summary.csv
3 | output/
4 | output_ada/
5 | data/
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaViT: Adaptive Vision Transformers for Efficient Image Recognition
2 |
3 | Lingchen Meng\*<sup>1</sup>, Hengduo Li\*<sup>2</sup>, Bor-Chun Chen<sup>3</sup>, Shiyi Lan<sup>2</sup>, Zuxuan Wu<sup>1</sup>, Yu-Gang Jiang<sup>1</sup>, Ser-Nam Lim<sup>3</sup>
4 | <sup>1</sup>Shanghai Key Lab of Intelligent Information Processing, School of Computer Science, Fudan University, <sup>2</sup>University of Maryland, <sup>3</sup>Meta AI
5 | \* Equal contribution
6 |
7 |
8 | This repository is the official implementation of [AdaViT](https://arxiv.org/abs/2111.15668).
9 | Our code is based on [pytorch-image-models](https://github.com/rwightman/pytorch-image-models) and [T2T-ViT](https://github.com/yitu-opensource/T2T-ViT).
10 |
11 | ## Abstract
12 | Built on top of self-attention mechanisms, vision transformers have demonstrated remarkable performance on a variety of tasks recently. While achieving excellent performance, they still require relatively intensive computational cost that scales up drastically as the numbers of patches, self-attention heads and transformer blocks increase. In this paper, we argue that due to the large variations among images, their
13 | need for modeling long-range dependencies between patches differs. To this end, we introduce AdaViT, an adaptive computation framework that learns to derive usage policies on which patches, self-attention heads and transformer blocks to use throughout the backbone on a per-input basis, aiming to improve inference efficiency of vision transformers with a minimal drop of accuracy for image recognition. Optimized jointly with a transformer backbone in an end-to-end manner, a light-weight decision network is attached to the backbone to produce decisions on-the-fly. Extensive experiments on ImageNet demonstrate that our method obtains more than $2\times$ improvement on efficiency compared to state-of-the-art vision transformers with only $0.8\%$ drop of accuracy, achieving good efficiency/accuracy trade-offs conditioned on different computational budgets. We further conduct quantitative and qualitative analysis on learned usage policies and provide more insights into the redundancy in vision transformers.
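
To make the adaptive mechanism concrete, here is a minimal, illustrative sketch of a per-input usage-policy head based on a straight-through Gumbel-softmax. The class name and layout are hypothetical and are not the modules shipped in `models/`; in the actual model, separate policies are produced for patches, self-attention heads and transformer blocks and are optimized jointly with the backbone.

```python
import torch.nn as nn
import torch.nn.functional as F


class UsagePolicyHead(nn.Module):
    """Illustrative only: maps block-input features to a hard keep/skip
    decision that stays differentiable during training via Gumbel-softmax."""

    def __init__(self, dim, tau=5.0):
        super().__init__()
        self.proj = nn.Linear(dim, 2)  # logits for [skip, keep]
        self.tau = tau                 # temperature, cf. the --*-select-tau flags

    def forward(self, x):
        # x: (batch, num_tokens, dim) -> one decision per input image
        logits = self.proj(x.mean(dim=1))
        decision = F.gumbel_softmax(logits, tau=self.tau, hard=True)
        return decision[:, 1]  # 1.0 = execute this block/head, 0.0 = skip it
```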
14 |
15 |
16 |
17 | ## Model Zoo
18 | Pretrained model checkpoints are available below.
19 | | Model | Top1 Acc | MACs | Download|
20 | | :--- | :---: | :---: | :---: |
21 | | Ada-T2T-ViT-19 | 81.1 | 3.9G | [link](https://drive.google.com/file/d/1bSh9E2HDM66L5FAbrTqh6slSWaYPeKNR/view?usp=sharing)|
22 | | Ada-DeiT-S | 77.3 | 2.3G | [link](https://drive.google.com/file/d/1vkD6w9J8sf64IPhTBgyfvsTvUlZw6TNa/view?usp=sharing)|
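
To load a downloaded checkpoint programmatically instead of through `ada_main.py`, the following minimal sketch mirrors how `ada_main.py` creates and restores the model. Run it from the repository root so that `models` and `utils` are importable; the checkpoint path is a placeholder, and additional `create_model` keyword arguments may be needed depending on your configuration.

```python
import models  # registers the ada_* architectures with timm
from timm.models import create_model
from utils import ada_load_state_dict

# Ada-DeiT-S with head/layer/token selection enabled, matching the eval command below.
model = create_model(
    'ada_step_deit_small_patch16_224',
    num_classes=1000,
    ada_head=True, ada_layer=True,
    ada_token=True, ada_token_with_mlp=True)
ada_load_state_dict('/path/to/your/checkpoint', model, use_qkv=False, strict=False)
model = model.cuda().eval()
```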
23 |
24 |
25 | ## Eval our model
26 | Download our AdaViT model built on T2T-ViT-19 from [Google Drive](https://drive.google.com/file/d/1bSh9E2HDM66L5FAbrTqh6slSWaYPeKNR/view?usp=sharing) and run the command below. You should obtain a top-1 accuracy of about 81.1% at roughly 3.9 GFLOPs.
27 |
28 | ```sh
29 | python3 ada_main.py /path/to/your/imagenet \
30 | --model ada_step_t2t_vit_19_lnorm \
31 | --ada-head --ada-layer --ada-token-with-mlp \
32 | --flops-dict adavit_ckpt/t2t-19-h-l-tmlp_flops_dict.pth \
33 | --eval_checkpoint /path/to/your/checkpoint
34 |
35 | python3 ada_main.py /path/to/your/imagenet \
36 | --model ada_step_deit_small_patch16_224 \
37 | --ada-head --ada-layer --ada-token-with-mlp \
38 | --flops-dict adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth \
39 | --eval_checkpoint /path/to/your/checkpoint
40 | ```
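
The reported MACs are averaged from the per-block FLOPs lookup tables in `adavit_ckpt/`. The sketch below is a simplified version of the accounting done in `validate()` in `ada_main.py`; the helper name `average_gflops` is ours, while `batch_select_flops` is the function actually used by the repository.

```python
import torch
from block_flops_dict import batch_select_flops


def average_gflops(flops_dict_path, head_select, layer_select, token_select,
                   block_num=12, base_flops=0.06):
    """Mean per-image GFLOPs given per-sample selection policies returned by
    the model. DeiT-S uses block_num=12, base_flops=0.06; T2T-ViT-19 uses
    block_num=19, base_flops=0.33 (assumes head_select is not None)."""
    flops_dict = torch.load(flops_dict_path)
    batch_size = head_select.size(0)
    per_sample = batch_select_flops(batch_size, flops_dict, head_select,
                                    layer_select, token_select,
                                    block_num=block_num, base_flops=base_flops)
    return per_sample.mean()
```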
41 |
42 | ## Citation
43 | ```
44 | @inproceedings{meng2022adavit,
45 | title={AdaViT: Adaptive Vision Transformers for Efficient Image Recognition},
46 | author={Meng, Lingchen and Li, Hengduo and Chen, Bor-Chun and Lan, Shiyi and Wu, Zuxuan and Jiang, Yu-Gang and Lim, Ser-Nam},
47 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
48 | pages={12309--12318},
49 | year={2022}
50 | }
51 | ```
52 |
--------------------------------------------------------------------------------
/ada_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import yaml
4 | import os
5 | import logging
6 | from collections import OrderedDict
7 | from contextlib import suppress
8 | import datetime
9 | from time import gmtime, strftime
10 | import models
11 | from utils import ada_load_state_dict
12 | from models.losses import AdaHeadLoss, AdaLoss, TeacherLoss
13 |
14 | import torch
15 | import torch.nn as nn
16 | import torchvision.utils
17 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
18 |
19 | from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
20 | from timm.models import load_checkpoint, create_model, resume_checkpoint, convert_splitbn_model
21 | from timm.utils import *
22 | from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
23 | from timm.optim import create_optimizer
24 | from timm.scheduler import create_scheduler
25 | from timm.utils import ApexScaler, NativeScaler
26 |
27 | import pdb
28 | from saver import MyCheckpointSaver
29 |
30 | torch.backends.cudnn.benchmark = True
31 | _logger = logging.getLogger('train')
32 |
33 | # The first arg parser parses out only the --config argument, this argument is used to
34 | # load a yaml file containing key-values that override the defaults for the main parser below
35 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
36 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
37 | help='YAML config file specifying default arguments')
38 |
39 | parser = argparse.ArgumentParser(description='T2T-ViT Training and Evaluating')
40 |
41 | parser.add_argument('--head-ratio', type=float, default=2.0,
42 | help='')
43 | parser.add_argument('--layer-ratio', type=float, default=2.0,
44 | help='')
45 | parser.add_argument('--attn-ratio', type=float, default=0.,
46 | help='')
47 | parser.add_argument('--hidden-ratio', type=float, default=0.,
48 | help='')
49 | parser.add_argument('--pred-ratio', type=float, default=0.,
50 | help='')
51 | parser.add_argument('--head-target-ratio', type=float, default=0.5,
52 | help='')
53 | parser.add_argument('--layer-target-ratio', type=float, default=0.5,
54 | help='')
55 | parser.add_argument('--head-diverse-ratio', type=float, default=0.,
56 | help='')
57 | parser.add_argument('--layer-diverse-ratio', type=float, default=0.,
58 | help='')
59 | parser.add_argument('--head-select-tau', type=float, default=5.,
60 | help='')
61 | parser.add_argument('--layer-select-tau', type=float, default=5.,
62 | help='')
63 | parser.add_argument('--token-select-tau', type=float, default=5.,
64 | help='')
65 | parser.add_argument('--head-entropy-weight', type=float, default=0.,
66 | help='')
67 | parser.add_argument('--layer-entropy-weight', type=float, default=0.,
68 | help='')
69 | parser.add_argument('--head-minimal-weight', type=float, default=0.,
70 | help='')
71 | parser.add_argument('--head-minimal', type=float, default=0.,
72 | help='')
73 | parser.add_argument('--layer-minimal-weight', type=float, default=0.,
74 | help='')
75 | parser.add_argument('--layer-minimal', type=float, default=0.,
76 | help='')
77 | parser.add_argument('--token-ratio', type=float, default=2.,
78 | help='')
79 | parser.add_argument('--token-target-ratio', type=float, default=0.5,
80 | help='')
81 | parser.add_argument('--token-minimal', type=float, default=0.,
82 | help='')
83 | parser.add_argument('--token-minimal-weight', type=float, default=0.,
84 | help='')
85 |
86 | parser.add_argument('--inner-loop', type=int, default=-1,
87 | help='')
88 |
89 | # Dataset / Model parameters
90 | parser.add_argument('data', metavar='DIR',
91 | help='path to dataset')
92 | parser.add_argument('--norm-policy', action='store_true', default=False, dest='norm_policy',
93 | help='')
94 | parser.add_argument('--keep-layers', type=int, default=1,
95 | help='use layers to make selection decision')
96 | parser.add_argument('--ada-layer', action='store_true', default=False, dest='ada_layer',
97 | help='')
98 | parser.add_argument('--ada-block', action='store_true', default=False, dest='ada_block',
99 | help='')
100 | parser.add_argument('--ada-head', action='store_true', default=False, dest='ada_head',
101 | help='')
102 | parser.add_argument('--ada-head-v2', action='store_true', default=False, dest='ada_head_v2',
103 | help='')
104 | parser.add_argument('--dyna-data', action='store_true', default=False, dest='dyna_data',
105 | help='')
106 | parser.add_argument('--ada-head-attn', action='store_true', default=False, dest='ada_head_attn',
107 | help='')
108 | parser.add_argument('--head-slowfast', action='store_true', default=False, dest='head_slowfast',
109 | help='')
110 |
111 | parser.add_argument('--flops-dict', type=str, default='', dest='flops_dict',
112 | help='')
113 | parser.add_argument('--ada-token', action='store_true', default=False, dest='ada_token',
114 | help='ada-token on self-attn')
115 | parser.add_argument('--ada-token-nonstep', action='store_true', default=False, dest='ada_token_nonstep',
116 | help='using nonstep option for token selection, i.e. generate policies for all layers at once at the beginning')
117 | parser.add_argument('--ada-token-with-mlp', action='store_true', default=False, dest='ada_token_with_mlp',
118 | help='ada-token on both self-attn and ffn')
119 | parser.add_argument('--ada-token-start-layer', type=int, default=0)
120 | parser.add_argument('--ada-token-pre-softmax', action='store_true', default=True, dest='ada_token_pre_softmax')
121 | parser.add_argument('--ada-token-post-softmax', action='store_false', default=True, dest='ada_token_pre_softmax')
122 | parser.add_argument('--ada-token-detach-attn', action='store_true', default=True, dest='ada_token_detach_attn')
123 | parser.add_argument('--ada-token-no-detach-attn', action='store_false', default=True, dest='ada_token_detach_attn')
124 | parser.add_argument('--ada-token-detach-attn-at-mlp', action='store_true', default=False, dest='ada_token_detach_attn_at_mlp',
125 | help='whether detaching attn_policy at MLP, when using dynamic token selection (which overrides --ada-token-detach-attn at MLP)')
126 |
127 | parser.add_argument('--no-head-select-bias', action='store_false', default=True, dest='head_select_bias',
128 | help='')
129 | parser.add_argument('--no-layer-select-bias', action='store_false', default=True, dest='layer_select_bias',
130 | help='')
131 | parser.add_argument('--model', default='T2t_vit_14', type=str, metavar='MODEL',
132 |                     help='Name of model to train (default: "T2t_vit_14")')
133 | parser.add_argument('--freeze-bn', action='store_true', default=False,
134 | help='freeze bn in training')
135 | parser.add_argument('--pretrained', action='store_true', default=False,
136 | help='Start with pretrained version of specified network (if avail)')
137 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
138 | help='Initialize model from this checkpoint (default: none)')
139 | parser.add_argument('--pretrain-path', default='', type=str,
140 | help='Load pretrain file path')
141 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
142 | help='Resume full model and optimizer state from checkpoint (default: none)')
143 | parser.add_argument('--eval_checkpoint', default='', type=str, metavar='PATH',
144 | help='path to eval checkpoint (default: none)')
145 | parser.add_argument('--use-full-head', action='store_true', default=False,
146 | help='use full model param')
147 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
148 | help='prevent resume of optimizer state when resuming model')
149 | parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
150 | help='number of label classes (default: 1000)')
151 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
152 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
153 | parser.add_argument('--img-size', type=int, default=224, metavar='N',
154 |                     help='Input image size (default: 224)')
155 | parser.add_argument('--crop-pct', default=None, type=float,
156 | metavar='N', help='Input image center crop percent (for validation only)')
157 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
158 | help='Override mean pixel value of dataset')
159 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
160 |                     help='Override std deviation of dataset')
161 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
162 | help='Image resize interpolation type (overrides model)')
163 | parser.add_argument('-b', '--batch-size', type=int, default=64, metavar='N',
164 | help='input batch size for training (default: 64)')
165 | parser.add_argument('-vb', '--validation-batch-size-multiplier', type=int, default=1, metavar='N',
166 | help='ratio of validation batch size to training batch size (default: 1)')
167 |
168 | # Optimizer parameters
169 | parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
170 |                     help='Optimizer (default: "adamw")')
171 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
172 | help='Optimizer Epsilon (default: None, use opt default)')
173 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
174 | help='Optimizer Betas (default: None, use opt default)')
175 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
176 | help='Optimizer momentum (default: 0.9)')
177 | parser.add_argument('--weight-decay', type=float, default=0.05,
178 |                     help='weight decay (default: 0.05)')
179 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
180 | help='Clip gradient norm (default: None, no clipping)')
181 |
182 | # Learning rate schedule parameters
183 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
184 |                     help='LR scheduler (default: "cosine")')
185 | parser.add_argument('--lr', type=float, default=5e-4, metavar='LR',
186 |                     help='learning rate (default: 5e-4)')
187 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
188 | help='learning rate noise on/off epoch percentages')
189 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
190 | help='learning rate noise limit percent (default: 0.67)')
191 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
192 | help='learning rate noise std-dev (default: 1.0)')
193 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
194 | help='learning rate cycle len multiplier (default: 1.0)')
195 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
196 | help='learning rate cycle limit')
197 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
198 |                     help='warmup learning rate (default: 1e-6)')
199 | parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
200 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
201 | parser.add_argument('--epochs', type=int, default=300, metavar='N',
202 |                     help='number of epochs to train (default: 300)')
203 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
204 | help='manual epoch number (useful on restarts)')
205 | parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
206 | help='epoch interval to decay LR')
207 | parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
208 | help='epochs to warmup LR, if scheduler supports')
209 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
210 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
211 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
212 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
213 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
214 | help='LR decay rate (default: 0.1)')
215 | parser.add_argument('--ada-lr-scaling', action='store_true', default=False,
216 | help='rescale the lr of ada subnetworks')
217 | parser.add_argument('--ada-token-lr-scale', type=float, default=1.0, help='')
218 | parser.add_argument('--ada-layer-lr-scale', type=float, default=1.0, help='')
219 | parser.add_argument('--ada-head-lr-scale', type=float, default=1.0, help='')
220 |
221 | # Augmentation & regularization parameters
222 | parser.add_argument('--no-aug', action='store_true', default=False,
223 | help='Disable all training augmentation, override other train aug args')
224 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
225 | help='Random resize scale (default: 0.08 1.0)')
226 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
227 | help='Random resize aspect ratio (default: 0.75 1.33)')
228 | parser.add_argument('--hflip', type=float, default=0.5,
229 | help='Horizontal flip training aug probability')
230 | parser.add_argument('--vflip', type=float, default=0.,
231 | help='Vertical flip training aug probability')
232 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
233 | help='Color jitter factor (default: 0.4)')
234 | parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
235 |                     help='Use AutoAugment policy. "v0" or "original". (default: "rand-m9-mstd0.5-inc1")')
236 | parser.add_argument('--aug-splits', type=int, default=0,
237 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
238 | parser.add_argument('--jsd', action='store_true', default=False,
239 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
240 | parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
241 | help='Random erase prob (default: 0.25)')
242 | parser.add_argument('--remode', type=str, default='pixel',
243 |                     help='Random erase mode (default: "pixel")')
244 | parser.add_argument('--recount', type=int, default=1,
245 | help='Random erase count (default: 1)')
246 | parser.add_argument('--resplit', action='store_true', default=False,
247 | help='Do not random erase first (clean) augmentation split')
248 | parser.add_argument('--mixup', type=float, default=0.8,
249 |                     help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
250 | parser.add_argument('--cutmix', type=float, default=1.0,
251 |                     help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
252 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
253 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
254 | parser.add_argument('--mixup-prob', type=float, default=1.0,
255 | help='Probability of performing mixup or cutmix when either/both is enabled')
256 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
257 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
258 | parser.add_argument('--mixup-mode', type=str, default='batch',
259 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
260 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
261 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
262 | parser.add_argument('--smoothing', type=float, default=0.1,
263 | help='Label smoothing (default: 0.1)')
264 | parser.add_argument('--train-interpolation', type=str, default='random',
265 |                     help='Training interpolation (random, bilinear, bicubic; default: "random")')
266 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
267 | help='Dropout rate (default: 0.0)')
268 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
269 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
270 | parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
271 |                     help='Drop path rate (default: 0.1)')
272 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
273 | help='Drop block rate (default: None)')
274 |
275 | # Batch norm parameters (only works with gen_efficientnet based models currently)
276 | parser.add_argument('--bn-tf', action='store_true', default=False,
277 | help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
278 | parser.add_argument('--bn-momentum', type=float, default=None,
279 | help='BatchNorm momentum override (if not None)')
280 | parser.add_argument('--bn-eps', type=float, default=None,
281 | help='BatchNorm epsilon override (if not None)')
282 | parser.add_argument('--sync-bn', action='store_true',
283 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
284 | parser.add_argument('--dist-bn', type=str, default='',
285 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
286 | parser.add_argument('--split-bn', action='store_true',
287 | help='Enable separate BN layers per augmentation split.')
288 |
289 | # Model Exponential Moving Average
290 | parser.add_argument('--model-ema', action='store_true', default=True,
291 | help='Enable tracking moving average of model weights')
292 | parser.add_argument('--no-model-ema', action='store_false', dest='model_ema',
293 |                     help='Disable tracking moving average of model weights')
294 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
295 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
296 | parser.add_argument('--model-ema-decay', type=float, default=0.99996,
297 |                     help='decay factor for model weights moving average (default: 0.99996)')
298 |
299 | # Misc
300 | parser.add_argument('--seed', type=int, default=42, metavar='S',
301 | help='random seed (default: 42)')
302 | parser.add_argument('--log-interval', type=int, default=50, metavar='N',
303 | help='how many batches to wait before logging training status')
304 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
305 | help='how many batches to wait before writing recovery checkpoint')
306 | parser.add_argument('-j', '--workers', type=int, default=8, metavar='N',
307 |                     help='how many training processes to use (default: 8)')
308 | parser.add_argument('--num-gpu', type=int, default=1,
309 | help='Number of GPUS to use')
310 | parser.add_argument('--save-images', action='store_true', default=False,
311 |                     help='save images of input batches every log interval for debugging')
312 | parser.add_argument('--amp', action='store_true', default=False,
313 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
314 | parser.add_argument('--apex-amp', action='store_true', default=False,
315 | help='Use NVIDIA Apex AMP mixed precision')
316 | parser.add_argument('--native-amp', action='store_true', default=False,
317 | help='Use Native Torch AMP mixed precision')
318 | parser.add_argument('--channels-last', action='store_true', default=False,
319 | help='Use channels_last memory layout')
320 | parser.add_argument('--pin-mem', action='store_true', default=False,
321 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
322 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
323 | help='disable fast prefetcher')
324 | parser.add_argument('--output', default='', type=str, metavar='PATH',
325 | help='path to output folder (default: none, current dir)')
326 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
327 |                     help='Best metric (default: "top1")')
328 | parser.add_argument('--tta', type=int, default=0, metavar='N',
329 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
330 | parser.add_argument("--local_rank", default=0, type=int)
331 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
332 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
333 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
334 |
335 | # Random Baseline settings
336 | parser.add_argument('--random-policy', action='store_true', default=False, dest='random_policy')
337 | parser.add_argument('--random-head', action='store_true', default=False, dest='random_head')
338 | parser.add_argument('--random-layer', action='store_true', default=False, dest='random_layer')
339 | parser.add_argument('--random-token', action='store_true', default=False, dest='random_token')
340 | parser.add_argument('--eval_random_baseline', action='store_true', default=False,
341 | help='if True, evaluate random baselines given certain computational budget')
342 | parser.add_argument('--eval_random_layer', action='store_true', default=False,
343 | help='if True, test randomly generated policies for layer selection')
344 | parser.add_argument('--eval_random_head', action='store_true', default=False,
345 | help='if True, test randomly generated policies for head selection')
346 | parser.add_argument('--eval_random_token', action='store_true', default=False,
347 | help='if True, test randomly generated policies for token selection')
348 | parser.add_argument('--eval_random_layer_ratio', type=float, default=1.0,
349 | help='ratio of kept layers in random policies')
350 | parser.add_argument('--eval_random_head_ratio', type=float, default=1.0,
351 |                     help='ratio of kept heads in random policies')
352 | parser.add_argument('--eval_random_token_ratio', type=float, default=1.0,
353 |                     help='ratio of kept tokens in random policies')
354 | parser.add_argument('--dev', action='store_true', default=False,
355 | help='skip some time-consuming steps when developing, e.g. loading full training set')
356 | parser.add_argument('--print-head-option', action='store_true')
357 | try:
358 | from apex import amp
359 | from apex.parallel import DistributedDataParallel as ApexDDP
360 | from apex.parallel import convert_syncbn_model
361 |
362 | has_apex = True
363 | except ImportError:
364 | has_apex = False
365 |
366 | has_native_amp = False
367 | try:
368 | if getattr(torch.cuda.amp, 'autocast') is not None:
369 | has_native_amp = True
370 | except AttributeError:
371 | pass
372 |
373 | def _parse_args():
374 | # Do we have a config file to parse?
375 | args_config, remaining = config_parser.parse_known_args()
376 | if args_config.config:
377 | with open(args_config.config, 'r') as f:
378 | cfg = yaml.safe_load(f)
379 | parser.set_defaults(**cfg)
380 |
381 | # The main arg parser parses the rest of the args, the usual
382 | # defaults will have been overridden if config file specified.
383 | args = parser.parse_args(remaining)
384 | if args.ada_block :
385 | args.ada_layer = True
386 | if args.ada_head_attn :
387 | args.ada_head = True
388 | if args.ada_token_with_mlp :
389 | args.ada_token = True
390 | if args.ada_token_nonstep:
391 | args.ada_token = True
392 | if args.ada_head_v2 :
393 | args.ada_head = True
394 |
395 | # Cache the args as a text string to save them in the output dir later
396 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
397 | return args, args_text
398 |
399 |
400 | def main():
401 | setup_default_logging()
402 | args, args_text = _parse_args()
403 |
404 | args.prefetcher = not args.no_prefetcher
405 | args.distributed = False
406 | if 'WORLD_SIZE' in os.environ:
407 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
408 | if args.distributed and args.num_gpu > 1:
409 | _logger.warning(
410 |             'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
411 | args.num_gpu = 1
412 |
413 | args.device = 'cuda:0'
414 | args.world_size = 1
415 | args.rank = 0 # global rank
416 | if args.distributed:
417 | args.num_gpu = 1
418 | args.device = 'cuda:%d' % args.local_rank
419 | torch.cuda.set_device(args.local_rank)
420 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
421 | args.world_size = torch.distributed.get_world_size()
422 | args.rank = torch.distributed.get_rank()
423 | assert args.rank >= 0
424 |
425 | if args.distributed:
426 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
427 | % (args.rank, args.world_size))
428 | else:
429 | _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
430 |
431 | torch.manual_seed(args.seed + args.rank)
432 |
433 | model = create_model(
434 | args.model,
435 | pretrained=args.pretrained,
436 | num_classes=args.num_classes,
437 | drop_rate=args.drop,
438 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
439 | drop_path_rate=args.drop_path,
440 | drop_block_rate=args.drop_block,
441 | global_pool=args.gp,
442 | bn_tf=args.bn_tf,
443 | bn_momentum=args.bn_momentum,
444 | bn_eps=args.bn_eps,
445 | checkpoint_path=args.initial_checkpoint,
446 | img_size=args.img_size,
447 | keep_layers=args.keep_layers,
448 | head_select_bias=args.head_select_bias,
449 | layer_select_bias=args.layer_select_bias,
450 | norm_policy = args.norm_policy,
451 | ada_layer = args.ada_layer,
452 | ada_block = args.ada_block,
453 | ada_head = args.ada_head,
454 | ada_head_v2 = args.ada_head_v2,
455 | dyna_data = args.dyna_data,
456 | ada_token = args.ada_token,
457 | ada_token_nonstep = args.ada_token_nonstep,
458 | ada_token_with_mlp = args.ada_token_with_mlp,
459 | ada_head_attn = args.ada_head_attn,
460 | head_slowfast = args.head_slowfast,
461 | ada_token_start_layer = args.ada_token_start_layer,
462 | ada_token_pre_softmax = args.ada_token_pre_softmax,
463 | ada_token_detach_attn = args.ada_token_detach_attn,
464 | ada_token_detach_attn_at_mlp = args.ada_token_detach_attn_at_mlp,
465 | layer_select_tau = args.layer_select_tau,
466 | head_select_tau = args.head_select_tau,
467 | token_select_tau = args.token_select_tau,
468 | )
469 |
470 | def set_model_attr(model, name, value):
471 | if hasattr(model, name) :
472 | setattr(model, name, value)
473 | if args.random_policy :
474 | model.apply(lambda x : set_model_attr(x, 'random_policy', True))
475 | model.apply(lambda x : set_model_attr(x, 'random_layer_ratio', args.layer_target_ratio))
476 | model.apply(lambda x : set_model_attr(x, 'random_head_ratio', args.head_target_ratio))
477 | model.apply(lambda x : set_model_attr(x, 'random_token_ratio', args.token_target_ratio))
478 |
479 | if args.random_head :
480 | model.apply(lambda x : set_model_attr(x, 'random_head', True))
481 | model.apply(lambda x : set_model_attr(x, 'random_head_ratio', args.head_target_ratio))
482 | if args.random_layer :
483 | model.apply(lambda x : set_model_attr(x, 'random_layer', True))
484 | model.apply(lambda x : set_model_attr(x, 'random_layer_ratio', args.layer_target_ratio))
485 | if args.random_token :
486 | model.apply(lambda x : set_model_attr(x, 'random_token', True))
487 | model.apply(lambda x : set_model_attr(x, 'random_token_ratio', args.token_target_ratio))
488 |
489 | if args.pretrain_path :
490 | _logger.info('load pretrain model {}'.format(args.pretrain_path))
491 | # model.load_state_dict(torch.load(args.pretrain_path, map_location='cpu'), strict=False)
492 | ada_load_state_dict(args.pretrain_path, model, use_qkv=False, strict=False)
493 |
494 | if args.local_rank == 0:
495 | _logger.info('Model %s created, param count: %d' %
496 | (args.model, sum([m.numel() for m in model.parameters()])))
497 |
498 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
499 |
500 | num_aug_splits = 0
501 | if args.aug_splits > 0:
502 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
503 | num_aug_splits = args.aug_splits
504 |
505 | if args.split_bn:
506 | assert num_aug_splits > 1 or args.resplit
507 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
508 |
509 | use_amp = None
510 | if args.amp:
511 | # for backwards compat, `--amp` arg tries apex before native amp
512 | if has_apex:
513 | args.apex_amp = True
514 | elif has_native_amp:
515 | args.native_amp = True
516 | if args.apex_amp and has_apex:
517 | use_amp = 'apex'
518 | elif args.native_amp and has_native_amp:
519 | use_amp = 'native'
520 | elif args.apex_amp or args.native_amp:
521 |         _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
522 |                         "Install NVIDIA Apex or upgrade to PyTorch 1.6")
523 |
524 | if args.num_gpu > 1:
525 | if use_amp == 'apex':
526 | _logger.warning(
527 | 'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
528 | use_amp = None
529 | model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
530 | assert not args.channels_last, "Channels last not supported with DP, use DDP."
531 | else:
532 | model.cuda()
533 | if args.channels_last:
534 | model = model.to(memory_format=torch.channels_last)
535 |
536 | if not args.ada_lr_scaling:
537 | optimizer = create_optimizer(args, model)
538 | else:
539 | ada_params_names = ['token_select', 'head_select', 'layer_select']
540 | no_wd_params_names = list(model.no_weight_decay())
541 | model_params = [kv for kv in model.named_parameters() if kv[1].requires_grad]
542 | # ada_params_name_check = [kv[0] for kv in model.named_parameters() if any(sub in kv[0] for sub in ada_params_names)] # just for sanity check
543 | ada_params = [kv for kv in model_params if any(sub in kv[0] for sub in ada_params_names)]
544 | base_params = [kv for kv in model_params if not any(sub in kv[0] for sub in ada_params_names)]
545 |
546 | base_params_wd, base_params_no_wd = [], []
547 | ada_params_token_wd, ada_params_token_no_wd, ada_params_head_wd, ada_params_head_no_wd, ada_params_layer_wd, ada_params_layer_no_wd = [], [], [], [], [], []
548 |
549 | for name, param in base_params:
550 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names:
551 | base_params_no_wd.append(param)
552 | else:
553 | base_params_wd.append(param)
554 | for name, param in ada_params:
555 | if 'token_select' in name:
556 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names:
557 | ada_params_token_no_wd.append(param)
558 | else:
559 | ada_params_token_wd.append(param)
560 | elif 'head_select' in name:
561 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names:
562 | ada_params_head_no_wd.append(param)
563 | else:
564 | ada_params_head_wd.append(param)
565 | elif 'layer_select' in name:
566 | if len(param.shape) == 1 or name.endswith(".bias") or name in no_wd_params_names:
567 | ada_params_layer_no_wd.append(param)
568 | else:
569 | ada_params_layer_wd.append(param)
570 |
571 | all_params = [
572 | {'params': ada_params_token_wd, 'lr': args.lr * args.ada_token_lr_scale, 'weight_decay': args.weight_decay},
573 | {'params': ada_params_token_no_wd, 'lr': args.lr * args.ada_token_lr_scale, 'weight_decay': 0.},
574 | {'params': ada_params_head_wd, 'lr': args.lr * args.ada_head_lr_scale, 'weight_decay': args.weight_decay},
575 | {'params': ada_params_head_no_wd, 'lr': args.lr * args.ada_head_lr_scale, 'weight_decay': 0.},
576 | {'params': ada_params_layer_wd, 'lr': args.lr * args.ada_layer_lr_scale, 'weight_decay': args.weight_decay},
577 | {'params': ada_params_layer_no_wd, 'lr': args.lr * args.ada_layer_lr_scale, 'weight_decay': 0.},
578 | {'params': base_params_wd, 'lr': args.lr, 'weight_decay': args.weight_decay},
579 | {'params': base_params_no_wd, 'lr': args.lr, 'weight_decay': 0.},
580 | ]
581 | optimizer = torch.optim.AdamW(all_params, lr=args.lr, weight_decay=args.weight_decay)
582 |
583 | amp_autocast = suppress # do nothing
584 | loss_scaler = None
585 | if use_amp == 'apex':
586 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
587 | loss_scaler = ApexScaler()
588 | if args.local_rank == 0:
589 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
590 | elif use_amp == 'native':
591 | amp_autocast = torch.cuda.amp.autocast
592 | loss_scaler = NativeScaler()
593 | if args.local_rank == 0:
594 | _logger.info('Using native Torch AMP. Training in mixed precision.')
595 | else:
596 | if args.local_rank == 0:
597 | _logger.info('AMP not enabled. Training in float32.')
598 |
599 | # optionally resume from a checkpoint
600 | resume_epoch = None
601 | if args.resume:
602 | resume_epoch = resume_checkpoint(
603 | model, args.resume,
604 | optimizer=None if args.no_resume_opt else optimizer,
605 | loss_scaler=None if args.no_resume_opt else loss_scaler,
606 | log_info=args.local_rank == 0)
607 |
608 | model_ema = None
609 | if args.model_ema:
610 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
611 | model_ema = ModelEma(
612 | model,
613 | decay=args.model_ema_decay,
614 | device='cpu' if args.model_ema_force_cpu else '',
615 | resume=args.resume)
616 |
617 | if args.distributed:
618 | if args.sync_bn:
619 | assert not args.split_bn
620 | try:
621 | if has_apex and use_amp != 'native':
622 | # Apex SyncBN preferred unless native amp is activated
623 | model = convert_syncbn_model(model)
624 | else:
625 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
626 | if args.local_rank == 0:
627 | _logger.info(
628 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
629 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
630 | except Exception as e:
631 | _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
632 | if has_apex and use_amp != 'native':
633 | # Apex DDP preferred unless native amp is activated
634 | if args.local_rank == 0:
635 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
636 | model = ApexDDP(model, delay_allreduce=True)
637 | else:
638 | if args.local_rank == 0:
639 | _logger.info("Using native Torch DistributedDataParallel.")
640 | model = NativeDDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
641 | # NOTE: EMA model does not need to be wrapped by DDP
642 |
643 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
644 | start_epoch = 0
645 | if args.start_epoch is not None:
646 | # a specified start_epoch will always override the resume epoch
647 | start_epoch = args.start_epoch
648 | elif resume_epoch is not None:
649 | start_epoch = resume_epoch
650 | if lr_scheduler is not None and start_epoch > 0:
651 | lr_scheduler.step(start_epoch)
652 |
653 | if args.local_rank == 0:
654 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
655 |
656 | train_dir = os.path.join(args.data, 'train' if not args.dev else 'val')
657 | if not os.path.exists(train_dir):
658 | _logger.error('Training folder does not exist at: {}'.format(train_dir))
659 | exit(1)
660 | dataset_train = Dataset(train_dir)
661 |
662 | collate_fn = None
663 | mixup_fn = None
664 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
665 |     print('mixup active', mixup_active)
666 | if mixup_active:
667 | mixup_args = dict(
668 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
669 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
670 | label_smoothing=args.smoothing, num_classes=args.num_classes)
671 | if args.prefetcher:
672 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
673 | collate_fn = FastCollateMixup(**mixup_args)
674 | else:
675 | mixup_fn = Mixup(**mixup_args)
676 |
677 | if num_aug_splits > 1:
678 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
679 |
680 | train_interpolation = args.train_interpolation
681 | if args.no_aug or not train_interpolation:
682 | train_interpolation = data_config['interpolation']
683 | loader_train = create_loader(
684 | dataset_train,
685 | input_size=data_config['input_size'],
686 | batch_size=args.batch_size,
687 | is_training=True,
688 | use_prefetcher=args.prefetcher,
689 | no_aug=args.no_aug,
690 | re_prob=args.reprob,
691 | re_mode=args.remode,
692 | re_count=args.recount,
693 | re_split=args.resplit,
694 | scale=args.scale,
695 | ratio=args.ratio,
696 | hflip=args.hflip,
697 | vflip=args.vflip,
698 | color_jitter=args.color_jitter,
699 | auto_augment=args.aa,
700 | num_aug_splits=num_aug_splits,
701 | interpolation=train_interpolation,
702 | mean=data_config['mean'],
703 | std=data_config['std'],
704 | num_workers=args.workers,
705 | distributed=args.distributed,
706 | collate_fn=collate_fn,
707 | pin_memory=args.pin_mem,
708 | use_multi_epochs_loader=args.use_multi_epochs_loader
709 | )
710 |
711 | eval_dir = os.path.join(args.data, 'val')
712 | if not os.path.isdir(eval_dir):
713 | eval_dir = os.path.join(args.data, 'validation')
714 | if not os.path.isdir(eval_dir):
715 | _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
716 | exit(1)
717 | dataset_eval = Dataset(eval_dir)
718 |
719 | loader_eval = create_loader(
720 | dataset_eval,
721 | input_size=data_config['input_size'],
722 | batch_size=args.validation_batch_size_multiplier * args.batch_size,
723 | is_training=False,
724 | use_prefetcher=args.prefetcher,
725 | interpolation=data_config['interpolation'],
726 | mean=data_config['mean'],
727 | std=data_config['std'],
728 | num_workers=args.workers,
729 | distributed=args.distributed,
730 | crop_pct=data_config['crop_pct'],
731 | pin_memory=args.pin_mem,
732 | )
733 |
734 | if args.jsd:
735 | assert num_aug_splits > 1 # JSD only valid with aug splits set
736 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
737 | elif mixup_active:
738 | # smoothing is handled with mixup target transform
739 | train_loss_fn = SoftTargetCrossEntropy().cuda()
740 | elif args.smoothing:
741 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
742 | else:
743 | train_loss_fn = nn.CrossEntropyLoss().cuda()
744 | # train_loss_fn = AdaHeadLoss(train_loss_fn, target_ratio=args.head_target_ratio,head_loss_ratio=args.head_ratio)
745 | train_loss_fn = AdaLoss(train_loss_fn,
746 | head_target_ratio=args.head_target_ratio, head_loss_ratio=args.head_ratio,
747 | layer_target_ratio=args.layer_target_ratio, layer_loss_ratio=args.layer_ratio,
748 | head_diverse_ratio=args.head_diverse_ratio, layer_diverse_ratio=args.layer_diverse_ratio,
749 | head_entropy_weight=args.head_entropy_weight, layer_entropy_weight=args.layer_entropy_weight,
750 | head_minimal_weight=args.head_minimal_weight, head_minimal=args.head_minimal,
751 | layer_minimal_weight=args.layer_minimal_weight, layer_minimal=args.layer_minimal,
752 | token_target_ratio=args.token_target_ratio, token_loss_ratio=args.token_ratio, token_minimal=args.token_minimal, token_minimal_weight=args.token_minimal_weight)
753 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
754 |
755 | eval_metric = args.eval_metric
756 | best_metric = None
757 | best_epoch = None
758 |
759 | flops_dict = None
760 | if args.flops_dict :
761 | flops_dict = torch.load(args.flops_dict)
762 | if args.eval_random_baseline:
763 | assert args.eval_checkpoint, "Please provide path to the checkpoint when evaluating baselines."
764 | ada_load_state_dict(args.eval_checkpoint, model, use_qkv=False, strict=False)
765 | # set random policy configuration
766 | model.random_layer, model.random_head, model.random_token = args.eval_random_layer, args.eval_random_head, args.eval_random_token
767 | model.random_layer_ratio, model.random_head_ratio, model.random_token_ratio = \
768 | args.eval_random_layer_ratio, args.eval_random_head_ratio, args.eval_random_token_ratio
769 | val_metrics = validate(model, loader_eval, validate_loss_fn, args, print_head_option=args.print_head_option, flops_dict=flops_dict)
770 | print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
771 | return
772 |
773 | if args.eval_checkpoint: # evaluate the model
774 | ada_load_state_dict(args.eval_checkpoint, model, use_qkv=False, strict=False)
775 | if args.use_full_head :
776 | model.apply(lambda m: setattr(m, 'use_full_linear', True))
777 | val_metrics = validate(model, loader_eval, validate_loss_fn, args, print_head_option=args.print_head_option, flops_dict=flops_dict)
778 | print(f"Top-1 accuracy of the model is: {val_metrics['top1']:.1f}%")
779 | if 'gflops' in val_metrics :
780 | print(f"avg flops is: {val_metrics['gflops']:.1f} GFLOPS")
781 | return
782 |
783 | saver = None
784 | output_dir = ''
785 | if args.local_rank == 0:
786 | output_base = args.output if args.output else './output'
787 | exp_name = '-'.join([
788 | datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
789 | args.model,
790 | ])
791 | output_dir = get_outdir(output_base, 'train', exp_name)
792 | decreasing = True if eval_metric == 'loss' else False
793 | saver = MyCheckpointSaver(
794 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
795 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
796 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
797 | f.write(args_text)
798 |
799 | try: # train the model
800 | for epoch in range(start_epoch, num_epochs):
801 | if args.distributed:
802 | loader_train.sampler.set_epoch(epoch)
803 |
804 | train_metrics = train_epoch(
805 | epoch, model, loader_train, optimizer, train_loss_fn, args,
806 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
807 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, total_epochs=num_epochs)
808 |
809 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
810 | if args.local_rank == 0:
811 | _logger.info("Distributing BatchNorm running means and vars")
812 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
813 |
814 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, flops_dict=flops_dict)
815 |
816 | if model_ema is not None and not args.model_ema_force_cpu:
817 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
818 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
819 | ema_eval_metrics = validate(
820 | model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)', flops_dict=flops_dict)
821 | eval_metrics = ema_eval_metrics
822 |
823 | if lr_scheduler is not None:
824 | # step LR for next epoch
825 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
826 |
827 | update_summary(
828 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
829 | write_header=best_metric is None)
830 |
831 | if saver is not None:
832 | # save proper checkpoint with eval metric
833 | save_metric = eval_metrics[eval_metric]
834 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
835 |
836 | except KeyboardInterrupt:
837 | pass
838 | if best_metric is not None:
839 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
840 |
841 | def freeze_bn(model):
842 | for module in model.modules():
843 | if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
844 | module.eval()
845 |
846 | def train_epoch(
847 | epoch, model, loader, optimizer, loss_fn, args,
848 | lr_scheduler=None, saver=None, output_dir='', amp_autocast=suppress,
849 | loss_scaler=None, model_ema=None, mixup_fn=None, total_epochs=0):
850 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
851 | if args.prefetcher and loader.mixup_enabled:
852 | loader.mixup_enabled = False
853 | elif mixup_fn is not None:
854 | mixup_fn.mixup_enabled = False
855 |
856 | base_model = model.module if hasattr(model, 'module') else model
857 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
858 | batch_time_m = AverageMeter()
859 | data_time_m = AverageMeter()
860 | losses_m = AverageMeter()
861 | top1_m = AverageMeter()
862 | top5_m = AverageMeter()
863 | meta_loss_m = {}
864 |
865 | model.train()
866 | if args.freeze_bn :
867 | freeze_bn(model)
868 |
869 | end = time.time()
870 | last_idx = len(loader) - 1
871 |
872 | def update_meta_loss_m(meta_loss) :
873 | for k, v in meta_loss.items():
874 | if k not in meta_loss_m :
875 | meta_loss_m[k] = AverageMeter()
876 | meta_loss_m[k].update(v.item(), input.size(0))
877 |
878 | num_updates = epoch * len(loader)
879 | total_num_updates = total_epochs * len(loader)
880 | for batch_idx, (input, target) in enumerate(loader):
881 | last_batch = batch_idx == last_idx
882 | data_time_m.update(time.time() - end)
883 | if not args.prefetcher:
884 | input, target = input.cuda(), target.cuda()
885 | if mixup_fn is not None:
886 | input, target = mixup_fn(input, target)
887 | if args.channels_last:
888 | input = input.contiguous(memory_format=torch.channels_last)
889 |
890 | with amp_autocast():
891 | outputs = model(input, training=True, ret_attn_list=False)
892 | loss, meta_loss = loss_fn(outputs, target)
893 |
894 | if not args.distributed:
895 | losses_m.update(loss.item(), input.size(0))
896 | update_meta_loss_m(meta_loss)
897 |
898 | optimizer.zero_grad()
899 | if loss_scaler is not None:
900 | loss_scaler(
901 | loss, optimizer, clip_grad=args.clip_grad, parameters=model.parameters(), create_graph=second_order)
902 | else:
903 | loss.backward(create_graph=second_order)
904 | if args.clip_grad is not None:
905 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
906 | optimizer.step()
907 |
908 | torch.cuda.synchronize()
909 | if model_ema is not None:
910 | model_ema.update(model)
911 | num_updates += 1
912 |
913 | batch_time_m.update(time.time() - end)
914 | if last_batch or batch_idx % args.log_interval == 0:
915 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
916 | lr = sum(lrl) / len(lrl)
917 |
918 | if args.distributed:
919 | reduced_loss = reduce_tensor(loss.data, args.world_size)
920 | losses_m.update(reduced_loss.item(), input.size(0))
921 | for k, v in meta_loss.items():
922 | meta_loss[k] = reduce_tensor(v.data, args.world_size)
923 | update_meta_loss_m(meta_loss)
924 |
925 | eta_seconds = batch_time_m.avg * (total_num_updates - num_updates)
926 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
927 |
928 | if args.local_rank == 0:
929 | _logger.info(
930 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
931 | 'Current Time: {} '
932 | 'ETA : {} '
933 | 'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
934 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
935 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
936 | 'LR: {lr:.3e} '
937 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
938 | epoch,
939 | batch_idx, len(loader),
940 | 100. * batch_idx / last_idx,
941 | strftime("%Y-%m-%d %H:%M:%S", gmtime()),
942 | eta_string,
943 | loss=losses_m,
944 | batch_time=batch_time_m,
945 | rate=input.size(0) * args.world_size / batch_time_m.val,
946 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
947 | lr=lr,
948 | data_time=data_time_m))
949 |
950 | if args.save_images and output_dir:
951 | torchvision.utils.save_image(
952 | input,
953 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
954 | padding=0,
955 | normalize=True)
956 |
957 | if saver is not None and args.recovery_interval and (
958 | last_batch or (batch_idx + 1) % args.recovery_interval == 0):
959 | saver.save_recovery(epoch, batch_idx=batch_idx)
960 |
961 | if lr_scheduler is not None:
962 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
963 |
964 | end = time.time()
965 | # end for
966 |
967 | if hasattr(optimizer, 'sync_lookahead'):
968 | optimizer.sync_lookahead()
969 |
970 | return OrderedDict(loss= losses_m.avg, **{x : meta_loss_m[x].avg for x in meta_loss})
971 |
972 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='', ret_head_option=False, print_head_option=False, flops_dict=None):
973 | batch_time_m = AverageMeter()
974 | losses_m = AverageMeter()
975 | top1_m = AverageMeter()
976 | top5_m = AverageMeter()
977 | head_m = AverageMeter()
978 | layer_m = AverageMeter()
979 | token_m = AverageMeter()
980 | flops_m = AverageMeter()
981 |
982 | analyse_m = [[] for _ in range(1000)]
983 |
984 | head_option = None
985 | layer_option = None
986 | token_option = None
987 |
988 | model.eval()
989 |
990 | end = time.time()
991 | last_idx = len(loader) - 1
992 | with torch.no_grad():
993 | for batch_idx, (input, target) in enumerate(loader):
994 | last_batch = batch_idx == last_idx
995 | if not args.prefetcher:
996 | input = input.cuda()
997 | target = target.cuda()
998 | if args.channels_last:
999 | input = input.contiguous(memory_format=torch.channels_last)
1000 |
1001 | with amp_autocast():
1002 | output = model(input)
1003 | if isinstance(output, (tuple, list)):
1004 | output, head_select, layer_select, token_select = output[:4]
1005 | # output = output[0]
1006 | # head_select = output[1]
1007 | else :
1008 | head_select = None
1009 | layer_select = None
1010 | token_select = None
1011 |
1012 | # augmentation reduction
1013 | reduce_factor = args.tta
1014 | if reduce_factor > 1:
1015 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
1016 | target = target[0:target.size(0):reduce_factor]
1017 |
1018 | loss = loss_fn(output, target)
1019 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
1020 |
1021 | if args.distributed:
1022 | reduced_loss = reduce_tensor(loss.data, args.world_size)
1023 | acc1 = reduce_tensor(acc1, args.world_size)
1024 | acc5 = reduce_tensor(acc5, args.world_size)
1025 | if head_select is not None :
1026 | head_select = reduce_tensor(head_select, args.world_size)
1027 | if layer_select is not None :
1028 | layer_select = reduce_tensor(layer_select, args.world_size)
1029 | if token_select is not None :
1030 | token_select = reduce_tensor(token_select, args.world_size)
1031 | else:
1032 | reduced_loss = loss.data
1033 |
1034 | torch.cuda.synchronize()
1035 |
1036 | losses_m.update(reduced_loss.item(), input.size(0))
1037 | top1_m.update(acc1.item(), output.size(0))
1038 | top5_m.update(acc5.item(), output.size(0))
1039 | if head_select is not None :
1040 | head_m.update(head_select.mean().item(), output.size(0))
1041 | if head_option is None :
1042 | head_option = AverageMeter()
1043 | head_option.update(head_select.mean(0).cpu(), output.size(0))
1044 | if layer_select is not None :
1045 | layer_m.update(layer_select.mean().item(), output.size(0))
1046 | if layer_option is None :
1047 | layer_option = AverageMeter()
1048 | layer_option.update(layer_select.mean(0).cpu(), output.size(0))
1049 | if token_select is not None :
1050 | token_m.update(token_select.mean().item(), output.size(0))
1051 | if token_option is None :
1052 | token_option = AverageMeter()
1053 | token_option.update(token_select.mean((0,-1)), output.size(0))
1054 |
1055 | if flops_dict is not None :
1056 | bs = output.size(0)
1057 | from block_flops_dict import batch_select_flops
1058 | if 'deit' in args.model:
1059 | b_flops = batch_select_flops(bs, flops_dict, head_select, layer_select, token_select, block_num=12,base_flops=0.06)
1060 | else :
1061 | b_flops = batch_select_flops(bs, flops_dict, head_select, layer_select, token_select, block_num=19, base_flops=0.33)
1062 | flops_m.update(b_flops.mean(), bs)
1063 | def deal_none(x):
1064 | return x if x is not None else [None] * bs
1065 | head_select = deal_none(head_select)
1066 | layer_select = deal_none(layer_select)
1067 | token_select = deal_none(token_select)
1068 | for c, b, ht, lt, tt in zip(target, b_flops, head_select, layer_select, token_select):
1069 | lt = lt.cpu() if lt is not None else None
1070 | ht = ht.cpu() if ht is not None else None
1071 | tt = tt.cpu() if tt is not None else None
1072 | analyse_m[c].append(dict(flops=b.cpu(), layer_select=lt, head_select=ht, token_select=tt))
1073 |
1074 | batch_time_m.update(time.time() - end)
1075 | end = time.time()
1076 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
1077 | log_name = 'Test' + log_suffix
1078 | _logger.info(
1079 | '{0}: [{1:>4d}/{2}] '
1080 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
1081 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
1082 | 'Head avg :{head_select.val:>7.4f} ({head_select.avg:>6.4f}) '
1083 | 'Layer avg :{layer_m.val:>7.4f} ({layer_m.avg:>6.4f}) '
1084 | 'Token avg :{token_m.val:>7.4f} ({token_m.avg:>6.4f}) '
1085 | 'Flops avg :{flops_m.val:>7.4f} ({flops_m.avg:>6.4f}) '
1086 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
1087 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
1088 | log_name, batch_idx, last_idx, batch_time=batch_time_m,
1089 | loss=losses_m, top1=top1_m, top5=top5_m,
1090 | head_select=head_m, layer_m=layer_m, token_m=token_m, flops_m=flops_m))
1091 |
1092 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg), ('head', head_m.avg), ('layer', layer_m.avg)])
1093 | if flops_dict is not None :
1094 | metrics.update(gflops=flops_m.avg)
1095 | if head_option is not None :
1096 | if ret_head_option :
1097 | metrics.update(head_option=head_option.avg)
1098 | if print_head_option :
1099 | for i, val in enumerate(head_option.avg) :
1100 | if args.rank == 0 :
1101 | print('{}: {}'.format(i+args.keep_layers, ','.join(['{:.2f}'.format(float(x)) for x in val])))
1102 |
1103 | if layer_option is not None :
1104 | if ret_head_option :
1105 | metrics.update(layer_option=layer_option.avg)
1106 | if print_head_option :
1107 | for i, val in enumerate(layer_option.avg) :
1108 | if args.rank == 0 :
1109 | print('{}: {}'.format(i+args.keep_layers, ','.join(['{:.2f}'.format(float(x)) for x in val])))
1110 |
1111 | if token_option is not None :
1112 | if ret_head_option :
1113 | metrics.update(token_option=token_option.avg)
1114 | if print_head_option :
1115 | if args.rank == 0 :
1116 | print(' '.join(['{:.2f}'.format(float(x)) for x in token_option.avg]))
1117 |
1118 |     if flops_dict is not None :  # analyse_m is only populated when per-sample flops are measured
1119 | torch.save(analyse_m, '{}_analyse.{}.pth'.format(args.model, top1_m.avg))
1120 | return metrics
1121 |
1122 |
1123 | if __name__ == '__main__':
1124 | main()
1125 |
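1126 | # Post-hoc analysis sketch added for clarity (not part of the original script).
1127 | # validate() above saves '<model>_analyse.<top1>.pth', a per-class list of dicts
1128 | # with keys 'flops', 'head_select', 'layer_select' and 'token_select'; the file
1129 | # name below is only an example:
1130 | #
1131 | #     analyse = torch.load('ada_step_deit_small_patch16_224_analyse.77.3.pth')
1132 | #     per_class = [torch.stack([d['flops'] for d in recs]).mean()
1133 | #                  for recs in analyse if len(recs)]
1134 | #     print('mean GFLOPs across classes:', sum(per_class) / len(per_class))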
--------------------------------------------------------------------------------
/adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MengLcool/AdaViT/9857f92c8428045e01a8d988eab5060b1ac4a18b/adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth
--------------------------------------------------------------------------------
/adavit_ckpt/t2t-19-h-l-tmlp_flops_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MengLcool/AdaViT/9857f92c8428045e01a8d988eab5060b1ac4a18b/adavit_ckpt/t2t-19-h-l-tmlp_flops_dict.pth
--------------------------------------------------------------------------------
/assets/adavit_approach.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MengLcool/AdaViT/9857f92c8428045e01a8d988eab5060b1ac4a18b/assets/adavit_approach.png
--------------------------------------------------------------------------------
/block_flops_dict.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import fvcore
3 | from fvcore.nn import FlopCountAnalysis
4 | import argparse
5 | import models
6 | from models.ada_transformer_block import StepAdaBlock
7 | from timm.models import create_model
8 |
9 | parser = argparse.ArgumentParser(description='T2T-ViT Training and Evaluating')
10 | parser.add_argument('--model', default='t2t-19', type=str, metavar='MODEL',
11 |                     help='Name of model to train (default: "t2t-19")')
12 | parser.add_argument('--ada-layer', action='store_true', default=False, dest='ada_layer',
13 |                     help='adaptive layer (attn/mlp sub-block) selection')
14 | parser.add_argument('--ada-head', action='store_true', default=False, dest='ada_head',
15 |                     help='adaptive attention-head selection')
16 |
17 | parser.add_argument('--ada-token', action='store_true', default=False, dest='ada_token',
18 | help='ada-token on self-attn')
19 | parser.add_argument('--ada-token-with-mlp', action='store_true', default=False, dest='ada_token_with_mlp',
20 | help='ada-token on both self-attn and ffn')
21 | parser.add_argument('--keep-layers', type=int, default=1,
22 | help='use layers to make selection decision')
23 |
24 | def _parse_args():
25 |
26 | args = parser.parse_args()
27 | if args.ada_token_with_mlp :
28 | args.ada_token = True
29 |
30 | return args
31 |
32 | def get_flops_dict(dim, num_heads, mlp_ratio, debug=False, **kwargs):
33 | block = StepAdaBlock(dim, num_heads, mlp_ratio, **kwargs)
34 | inputs = torch.rand((1,197,dim))
35 |
36 | num_tokens = 197
37 | num_layers = 2
38 |
39 | block.apply(lambda x : setattr(x, 'count_flops', True))
40 | flops_dict = torch.zeros(num_heads+1, num_tokens+1, 2, 2)
41 | for t in range(1, num_tokens+1) :
42 | for h in range(num_heads+1) :
43 | block.apply(lambda x : setattr(x, 'h_ratio', h/num_heads))
44 | block.apply(lambda x : setattr(x, 't_ratio', (t)/197))
45 |
46 | if debug :
47 | block.apply(lambda x : setattr(x, 'h_ratio', 5/7))
48 | block.apply(lambda x : setattr(x, 't_ratio', (197)/197))
49 | block.apply(lambda x : setattr(x, 'l_ratio', [1,1]))
50 | flops = FlopCountAnalysis(block, inputs).total() / (1000**3)
51 | print('flops', flops)
52 | exit()
53 |
54 |             def fill_dict(l_select):
55 |                 block.apply(lambda x : setattr(x, 'l_ratio', l_select))
56 | 
57 |                 # measure this (head, token, layer) configuration in GFLOPs and store it;
58 |                 # without this assignment the returned dict would stay all zeros
59 |                 flops = FlopCountAnalysis(block, inputs).total() / (1000**3)
60 |                 flops_dict[h, t, l_select[0], l_select[1]] = flops
61 | 
62 |             fill_dict([0,0])
63 |             fill_dict([0,1])
64 |             fill_dict([1,0])
65 |             fill_dict([1,1])
66 | return flops_dict
67 |
68 | def select_flops(flops_dict, head_select, layer_select, token_select, block_num, base_flops=0.33):
69 | '''
70 | head_select : None or tensor (ada_block_h, h_num)
71 | layer_select : None or tensor (ada_block_l, 2)
72 | token_select : None or tensor (ada_block_t, 197)
73 | '''
74 |
75 | h, t, _, _ = flops_dict.shape
76 | h = h-1
77 | t = t-1
78 | if head_select is None :
79 | head_select = [h] * block_num
80 | else :
81 | ada_h = head_select.shape[0]
82 | head_select = [h] * (block_num - ada_h) + head_select.sum(-1).int().tolist()
83 | if layer_select is None :
84 | layer_select = [[1,1]] * block_num
85 | else :
86 | ada_l = layer_select.shape[0]
87 | layer_select = [[1,1]] * (block_num-ada_l) + layer_select.int().tolist()
88 | if token_select is None :
89 | token_select = [t] * block_num
90 | else :
91 | ada_t = token_select.shape[0]
92 | token_select = [t] * (block_num - ada_t) + token_select.sum(-1).int().tolist()
93 |
94 | flops = base_flops
95 | for h, t, l in zip(head_select, token_select, layer_select) :
96 | flops += flops_dict[h, t, l[0], l[1]]
97 | return flops
98 |
99 | def batch_select_flops(bs, flops_dict, head_select, layer_select, token_select, block_num=19, base_flops=0.33):
100 | def batch_select(x):
101 | return x if x is not None else [None] * bs
102 | head_select = batch_select(head_select)
103 | layer_select = batch_select(layer_select)
104 | token_select = batch_select(token_select)
105 |
106 | batch_flops = []
107 | for h, l, t in zip(head_select, layer_select, token_select):
108 | batch_flops.append(select_flops(flops_dict, h, l, t, block_num, base_flops))
109 |
110 | return torch.tensor(batch_flops)
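111 | 
112 | 
113 | # Illustrative usage sketch added for clarity (not part of the original script).
114 | # It assumes adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth stores the
115 | # (num_heads+1, num_tokens+1, 2, 2) tensor produced by get_flops_dict for a
116 | # deit-small block, that head/layer policies cover 11 of the 12 blocks
117 | # (keep_layers=1), and that the token policy is produced in every block.
118 | if __name__ == '__main__':
119 |     flops_dict = torch.load('adavit_ckpt/deit-s-h-l-tmlp_flops_dict.pth')
120 |     head_sel = torch.ones(1, 11, 6)     # keep all 6 heads in every adaptive block
121 |     layer_sel = torch.ones(1, 11, 2)    # keep both the attn and mlp sub-layers
122 |     token_sel = torch.ones(1, 12, 197)  # keep all 197 tokens (incl. cls)
123 |     gflops = batch_select_flops(1, flops_dict, head_sel, layer_sel, token_sel,
124 |                                 block_num=12, base_flops=0.06)
125 |     print('estimated GFLOPs per image:', gflops.item())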
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ada_t2t_vit import *
2 | from .deit import *
--------------------------------------------------------------------------------
/models/ada_t2t_vit.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from timm.models.helpers import load_pretrained
7 | from timm.models.registry import register_model
8 | from timm.models.layers import trunc_normal_
9 | from timm.models.vision_transformer import PatchEmbed
10 | import numpy as np
11 | from .token_transformer import Token_transformer
12 | from .token_performer import Token_performer
13 | from .transformer_block import Block, get_sinusoid_encoding
14 | from .ada_transformer_block import StepAdaBlock, get_random_policy
15 |
16 |
17 | def _cfg(url='', **kwargs):
18 | return {
19 | 'url': url,
20 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
21 | 'crop_pct': .9, 'interpolation': 'bicubic',
22 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
23 | 'classifier': 'head',
24 | **kwargs
25 | }
26 |
27 | default_cfgs = {
28 | 'T2t_vit_7': _cfg(),
29 | 'T2t_vit_10': _cfg(),
30 | 'T2t_vit_12': _cfg(),
31 | 'T2t_vit_14': _cfg(),
32 | 'T2t_vit_19': _cfg(),
33 | 'T2t_vit_24': _cfg(),
34 | 'T2t_vit_t_14': _cfg(),
35 | 'T2t_vit_t_19': _cfg(),
36 | 'T2t_vit_t_24': _cfg(),
37 | 'T2t_vit_14_resnext': _cfg(),
38 | 'T2t_vit_14_wide': _cfg(),
39 | }
40 |
41 |
42 | class T2T_module(nn.Module):
43 | """
44 | Tokens-to-Token encoding module
45 | """
46 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64):
47 | super().__init__()
48 |
49 | if tokens_type == 'transformer':
50 | print('adopt transformer encoder for tokens-to-token')
51 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
52 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
53 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
54 |
55 | self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
56 | self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0)
57 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
58 |
59 | elif tokens_type == 'performer':
60 | print('adopt performer encoder for tokens-to-token')
61 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
62 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
63 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
64 |
65 | #self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5)
66 | #self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5)
67 | self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5)
68 | self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5)
69 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim)
70 |
71 |         elif tokens_type == 'convolution': # just for comparison with convolution, not our model
72 | # for this tokens type, you need change forward as three convolution operation
73 | print('adopt convolution layers for tokens-to-token')
74 | self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution
75 | self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution
76 | self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution
77 |
78 |         self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2))  # there are 3 soft splits with strides 4, 2, 2 respectively
79 |
80 | def forward(self, x):
81 | # step0: soft split
82 | x = self.soft_split0(x).transpose(1, 2)
83 |
84 | # iteration1: re-structurization/reconstruction
85 | x = self.attention1(x)
86 | B, new_HW, C = x.shape
87 | x = x.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
88 | # iteration1: soft split
89 | x = self.soft_split1(x).transpose(1, 2)
90 |
91 | # iteration2: re-structurization/reconstruction
92 | x = self.attention2(x)
93 | B, new_HW, C = x.shape
94 | x = x.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
95 | # iteration2: soft split
96 | x = self.soft_split2(x).transpose(1, 2)
97 |
98 | # final tokens
99 | x = self.project(x)
100 |
101 | return x
102 |
103 |
104 | class AdaStepT2T_ViT(nn.Module):
105 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12,
106 | norm_policy=False, use_t2t = True, patch_size=16,
107 | ada_head = True, ada_head_v2=False, dyna_data=False, ada_head_attn=False, head_slowfast=False,
108 | ada_layer=False,
109 | ada_block=False, head_select_tau=5., layer_select_tau=5., token_select_tau=5.,
110 | ada_token = False, ada_token_with_mlp=False, ada_token_start_layer=0, ada_token_pre_softmax=True, ada_token_detach_attn=True, ada_token_detach_attn_at_mlp=False,
111 | keep_layers = 1,
112 | head_select_bias=True, layer_select_bias=True,
113 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
114 | drop_path_rate=0., norm_layer=nn.LayerNorm, token_dim=64, **kwargs):
115 | super().__init__()
116 | self.num_classes = num_classes
117 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
118 |
119 | self.use_t2t = use_t2t
120 | if use_t2t :
121 | self.tokens_to_token = T2T_module(
122 | img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim, token_dim=token_dim)
123 | num_patches = self.tokens_to_token.num_patches
124 | else :
125 | self.patch_embed = PatchEmbed(
126 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
127 | num_patches = self.patch_embed.num_patches
128 |
129 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
130 | self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False)
131 | self.pos_drop = nn.Dropout(p=drop_rate)
132 |
133 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
134 | assert keep_layers >=1, 'keep layers must >=1'
135 | self.keep_layers = keep_layers
136 | self.head_dim = embed_dim // num_heads
137 | print('ada head, ada layer, ada token', ada_head, ada_layer, ada_token)
138 |
139 | self.blocks = nn.ModuleList([
140 | StepAdaBlock(
141 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
142 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
143 | is_token_select=ada_token and i >= ada_token_start_layer, ada_token_with_mlp=ada_token_with_mlp,
144 | ada_token_pre_softmax = ada_token_pre_softmax, ada_token_detach_attn=ada_token_detach_attn, dyna_data=dyna_data, ada_head_v2=ada_head_v2, ada_token_detach_attn_at_mlp=ada_token_detach_attn_at_mlp,
145 | ada_head= ada_head and i>=keep_layers,
146 | ada_layer= ada_layer and i >=keep_layers,
147 | norm_policy=norm_policy,
148 | only_head_attn=ada_head_attn, head_slowfast=head_slowfast)
149 | for i in range(depth)])
150 | self.norm = norm_layer(embed_dim)
151 |
152 | # Classifier head
153 | self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
154 |
155 | trunc_normal_(self.cls_token, std=.02)
156 | self.apply(self._init_weights)
157 |
158 | def _init_weights(self, m):
159 | if isinstance(m, nn.Linear):
160 | trunc_normal_(m.weight, std=.02)
161 | if isinstance(m, nn.Linear) and m.bias is not None:
162 | nn.init.constant_(m.bias, 0)
163 | elif isinstance(m, nn.LayerNorm):
164 | nn.init.constant_(m.bias, 0)
165 | nn.init.constant_(m.weight, 1.0)
166 |
167 | @torch.jit.ignore
168 | def no_weight_decay(self):
169 | return {'cls_token'}
170 |
171 | def get_classifier(self):
172 | return self.head
173 |
174 | def reset_classifier(self, num_classes, global_pool=''):
175 | self.num_classes = num_classes
176 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
177 |
178 | def forward_features(self, x):
179 | B = x.shape[0]
180 | if self.use_t2t :
181 | x = self.tokens_to_token(x)
182 | else :
183 | x = self.patch_embed(x)
184 |
185 | cls_tokens = self.cls_token.expand(B, -1, -1)
186 | x = torch.cat((cls_tokens, x), dim=1)
187 | x = x + self.pos_embed
188 | x = self.pos_drop(x)
189 |
190 | attn_list = []
191 | hidden_list = []
192 | token_select_list = []
193 | head_select_list = []
194 | layer_select_list = []
195 | head_select_logits_list = []
196 | layer_select_logits_list = []
197 |
198 | def filter_append(target_list, element):
199 | if element is not None :
200 | target_list.append(element)
201 |
202 | for blk in self.blocks :
203 | x, attn, this_head_select, this_layer_select, this_token_select, this_head_select_logits, this_layer_select_logits = blk(x)
204 | attn_list.append(attn)
205 | hidden_list.append(x)
206 | filter_append(head_select_list, this_head_select)
207 | filter_append(layer_select_list, this_layer_select)
208 | filter_append(token_select_list, this_token_select)
209 | filter_append(head_select_logits_list, this_head_select_logits)
210 | filter_append(layer_select_logits_list, this_layer_select_logits)
211 |
212 | def convert_list_to_tensor(list_convert):
213 | if len(list_convert) :
214 | result = torch.stack(list_convert, dim=1)
215 | else :
216 | result = None
217 | return result
218 |
219 | head_select = convert_list_to_tensor(head_select_list)
220 | if head_select is not None :
221 | head_select = head_select.squeeze(-1)
222 | layer_select = convert_list_to_tensor(layer_select_list)
223 | token_select = convert_list_to_tensor(token_select_list)
224 | head_select_logits = convert_list_to_tensor(head_select_logits_list)
225 | layer_select_logits = convert_list_to_tensor(layer_select_logits_list)
226 |
227 | x = self.norm(x)
228 |
229 | return x[:, 0], head_select, layer_select, token_select, attn_list, hidden_list, dict(head_select_logits=head_select_logits, layer_select_logits=layer_select_logits)
230 |
231 | def zero_classification_grad(self):
232 | for blk in self.blocks[self.keep_layers:] :
233 | blk.zero_grad()
234 | self.norm.zero_grad()
235 | self.head.zero_grad()
236 |
237 | def forward(self, x, training=False, ret_attn_list=False):
238 |         x, head_select, layer_select, token_select, attn_list, hidden_list, select_logits = self.forward_features(x)
239 |         x = self.head(x)
240 |         if ret_attn_list :
241 |             return x, head_select, layer_select, token_select, attn_list, hidden_list, select_logits
242 |         return x, head_select, layer_select, token_select, select_logits
243 |
244 |
245 | class T2T_GroupNorm(nn.GroupNorm):
246 | def __init__(self, num_groups: int, num_channels: int, **kwargs) -> None:
247 | super().__init__(num_groups, num_channels, **kwargs)
248 |
249 | def forward(self, input: torch.Tensor) -> torch.Tensor:
250 | if input.dim() > 2 :
251 | input = input.transpose(1,-1)
252 | result = super().forward(input)
253 | if input.dim() > 2 :
254 | result = result.transpose(1,-1)
255 | return result
256 |
257 |
258 | def t2t_group_norm(dim, num_groups):
259 | return T2T_GroupNorm(num_groups, dim)
260 |
261 | @register_model
262 | def ada_step_t2t_vit_14_lnorm(pretrained=False, **kwargs): # adopt performer for tokens to token
263 | if pretrained:
264 | kwargs.setdefault('qk_scale', 384 ** -0.5)
265 | model = AdaStepT2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs)
266 | model.default_cfg = default_cfgs['T2t_vit_14']
267 | if pretrained:
268 | load_pretrained(
269 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
270 | return model
271 |
272 | @register_model
273 | def ada_step_t2t_vit_19_lnorm(pretrained=False, **kwargs): # adopt performer for tokens to token
274 | if pretrained:
275 | kwargs.setdefault('qk_scale', 448 ** -0.5)
276 | model = AdaStepT2T_ViT(tokens_type='performer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs)
277 | model.default_cfg = default_cfgs['T2t_vit_19']
278 | if pretrained:
279 | load_pretrained(
280 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
281 | return model
282 |
283 | @register_model
284 | def ada_step_deit_tiny_patch16_224(pretrained=False, **kwargs):
285 | model = AdaStepT2T_ViT(
286 | use_t2t=False,
287 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
288 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
289 | model.default_cfg = _cfg()
290 | return model
291 |
292 |
293 | @register_model
294 | def ada_step_deit_small_patch16_224(pretrained=False, **kwargs):
295 | model = AdaStepT2T_ViT(
296 | use_t2t=False,
297 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
298 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
299 | model.default_cfg = _cfg()
300 | return model
301 |
302 | @register_model
303 | def ada_step_deit_base_patch16_224(pretrained=False, **kwargs): # adopt performer for tokens to token
304 | model = AdaStepT2T_ViT(
305 | use_t2t=False,
306 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
307 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
308 | model.default_cfg = _cfg()
309 | return model
310 |
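311 | 
312 | # Illustrative sketch added for clarity (not part of the original file): building
313 | # one of the adaptive models registered above through timm and inspecting its
314 | # selection outputs. The flag values are assumptions, not a training recipe.
315 | #
316 | #     import torch
317 | #     from timm.models import create_model
318 | #     import models  # registers the ada_step_* entries defined in this file
319 | #
320 | #     model = create_model('ada_step_deit_small_patch16_224',
321 | #                          ada_head=True, ada_layer=True, ada_token=True,
322 | #                          keep_layers=1).eval()
323 | #     with torch.no_grad():
324 | #         logits, head_sel, layer_sel, token_sel, _ = model(torch.randn(2, 3, 224, 224))
325 | #     # head_sel: (2, 11, 6), layer_sel: (2, 11, 2), token_sel: (2, 12, 197)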
--------------------------------------------------------------------------------
/models/ada_transformer_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from timm.models.layers import DropPath
5 | from torch.nn.modules.normalization import LayerNorm
6 |
7 | def _gumbel_sigmoid(
8 | logits, tau=1, hard=False, eps=1e-10, training = True, threshold = 0.5
9 | ):
10 | if training :
11 |         # ~ Gumbel(0, 1)
12 | gumbels1 = (
13 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
14 | .exponential_()
15 | .log()
16 | )
17 | gumbels2 = (
18 | -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
19 | .exponential_()
20 | .log()
21 | )
22 |         # Difference of two Gumbels because we apply a sigmoid
23 | gumbels1 = (logits + gumbels1 - gumbels2) / tau
24 | y_soft = gumbels1.sigmoid()
25 | else :
26 | y_soft = logits.sigmoid()
27 |
28 | if hard:
29 | # Straight through.
30 | y_hard = torch.zeros_like(
31 | logits, memory_format=torch.legacy_contiguous_format
32 | ).masked_fill(y_soft > threshold, 1.0)
33 | ret = y_hard - y_soft.detach() + y_soft
34 | else:
35 | ret = y_soft
36 | return ret
37 |
38 | def get_random_policy(policy, ratio):
39 |     random_p = torch.empty_like(policy).fill_(ratio).bernoulli() + policy * 0.0  # adding policy * 0.0 keeps the policy in the autograd graph so DDP does not report unused parameters
40 | return random_p
41 |
42 | class SimpleTokenSelect(nn.Module):
43 | def __init__(self, dim_in, tau=5, is_hard=True, threshold=0.5, bias=True, pre_softmax=True, mask_filled_value=float('-inf'), ada_token_nonstep=False, ada_token_detach_attn=True):
44 | super().__init__()
45 | self.count_flops = False
46 | self.ada_token_nonstep = ada_token_nonstep # if using nonstep, no mlp_head is needed in each of these layers
47 | if not ada_token_nonstep:
48 | self.mlp_head = nn.Linear(dim_in, 1, bias=bias)
49 | self.norm = nn.Identity()
50 | self.is_hard = is_hard
51 | self.tau = tau
52 | self.threshold = threshold
53 | self.add_noise = True
54 | self.pre_softmax = pre_softmax
55 | self.mask_filled_value = mask_filled_value
56 | self.ada_token_detach_attn = ada_token_detach_attn
57 | self.random_policy = False
58 | self.random_token = False
59 | self.random_token_ratio = 1.
60 |
61 | def set_tau(self, tau):
62 | self.tau = tau
63 |
64 | def forward(self, x, attn, attn_pre_softmax, token_select=None):
65 | b, l = x.shape[:2]
66 |
67 | if not self.ada_token_nonstep:
68 | # generate token policy step by step in each layer, including the first (couple of) blocks
69 | logits = self.mlp_head(self.norm(x[:,1:]))
70 | token_select = _gumbel_sigmoid(logits, self.tau, self.is_hard, threshold=self.threshold, training=self.training)
71 | if self.random_policy or self.random_token:
72 | token_select = get_random_policy(token_select, self.random_token_ratio)
73 | token_select = torch.cat([token_select.new_ones(b,1,1), token_select], dim=1)
74 | # token_select = token_select.unsqueeze(-1) #(b,l,1)
75 | token_select = token_select.transpose(1,2) #(b,1,l)
76 | else:
77 | if token_select is None:
78 |                 # in the non-step setting, a missing token_select means this layer is one of
79 |                 # the first transformer block(s) that run before the head/layer policy is generated,
80 |                 # so no tokens are dropped here, for consistency
81 | token_select = torch.ones((b, 1, l), device=x.device)
82 | else:
83 | token_select = token_select[:, None, :]
84 |
85 | if self.count_flops :
86 | return attn, token_select.squeeze(1)
87 |
88 | attn_policy = torch.bmm(token_select.transpose(-1,-2), token_select) #(b,l,l)
89 | attn_policy = attn_policy.unsqueeze(1) #(b,1,l,l)
90 | if self.ada_token_detach_attn :
91 | attn_policy = attn_policy.detach()
92 |
93 |         # always use the pre-softmax path at inference, whether training used the pre- or post-softmax variant
94 | if self.pre_softmax or not self.training :
95 | eye_mat = attn.new_zeros((l,l))
96 | eye_mat = eye_mat.fill_diagonal_(1) #(1,1,l,l)
97 | attn = attn_pre_softmax * attn_policy + attn_pre_softmax.new_zeros(attn_pre_softmax.shape).masked_fill_((1 - attn_policy - eye_mat)>0, self.mask_filled_value)
98 | attn = attn.softmax(-1)
99 | assert not torch.isnan(attn).any(), 'token select pre softmax nan !'
100 | else :
101 | attn = nn.functional.normalize(attn * attn_policy, 1, -1)
102 |
103 | return attn, token_select.squeeze(1)
104 |
105 |
106 | class BlockHeadSelect(nn.Module):
107 | def __init__(self, dim_in, num_heads, tau=5, is_hard=True, threshold=0.5, bias=True):
108 | super().__init__()
109 | self.mlp_head = nn.Linear(dim_in, num_heads, bias=bias)
110 |         # self.norm = LayerNorm(dim_in)
111 | self.is_hard = is_hard
112 | self.tau = tau
113 | self.threshold = threshold
114 | self.add_noise = True
115 | self.head_dim = dim_in // num_heads
116 | self.random_policy = False
117 | self.random_head = False
118 | self.random_head_ratio = 1.
119 |
120 | def set_tau(self, tau):
121 | self.tau = tau
122 |
123 | def forward(self, x):
124 | '''
125 | ret : tensor(b, dim, 1)
126 | '''
127 | bsize = x.shape[0]
128 | logits = self.mlp_head(x)
129 | sample = _gumbel_sigmoid(logits, self.tau, self.is_hard, threshold=self.threshold, training=self.training)
130 | if self.random_policy or self.random_head:
131 | sample = get_random_policy(sample, self.random_head_ratio)
132 | sample = sample.unsqueeze(-1) #(b,h,1)
133 |
134 | width_select = sample.expand(-1,-1,self.head_dim)
135 | width_select = width_select.reshape(bsize, -1, 1)
136 |
137 | return sample, width_select, logits
138 |
139 |
140 | class BlockLayerSelect(nn.Module):
141 | def __init__(self, dim_in, num_sub_layer, tau=5, is_hard=True, threshold=0.5, bias=True):
142 | super().__init__()
143 | self.mlp_head = nn.Linear(dim_in, num_sub_layer, bias=bias)
144 | # self.norm = LayerNorm(dim_in)
145 | self.is_hard = is_hard
146 | self.tau = tau
147 | self.threshold = threshold
148 | self.add_noise = True
149 | self.random_policy = False
150 | self.random_layer = False
151 | self.random_layer_ratio = 1.
152 |
153 | def set_tau(self, tau):
154 | self.tau = tau
155 |
156 | def forward(self, x):
157 | logits = self.mlp_head(x)
158 | sample = _gumbel_sigmoid(logits, self.tau, self.is_hard, threshold=self.threshold, training=self.training)
159 | if self.random_policy or self.random_layer:
160 | sample = get_random_policy(sample, self.random_layer_ratio)
161 |         # sample: (b, num_sub_layer)
162 |
163 | return sample, logits
164 |
165 |
166 | class DynaLinear(nn.Linear):
167 | def __init__(self, in_features, out_features, num_heads=6, bias=True, dyna_dim=[True, True], dyna_data=False):
168 | super(DynaLinear, self).__init__(
169 | in_features, out_features, bias=bias)
170 | self.in_features_max = in_features
171 | self.out_features_max = out_features
172 |
173 | self.num_heads = num_heads
174 | self.width_mult = 1.
175 | self.dyna_dim = dyna_dim
176 | self.in_features = in_features
177 | self.out_features = out_features
178 | self.use_full_linear = False
179 |         self.dyna_data = dyna_data  # if dyna_data is False, the weights (rather than the data) are masked dynamically
180 | self.count_flops = False
181 |
182 | def forward(self, input, width_select=None, width_specify=None):
183 | """
184 | input : tensor (B,L,C)
185 | width_select : tensor(B,1,dims) or (B,dims,1)
186 | """
187 | if self.use_full_linear :
188 | return super().forward(input)
189 |
190 | if self.count_flops :
191 | if width_select is not None :
192 | assert width_select.shape[0] == 1
193 | width_specify = int(width_select.sum().item())
194 | width_select = None
195 |
196 | if self.dyna_data and width_select is not None :
197 | # only support input shape of (b,l,c)
198 | assert input.dim() == 3
199 | assert width_select.dim() == 3
200 | assert width_select.shape[1] == 1 or width_select.shape[2] == 1
201 | # if output is static, then input is dynamic
202 | if width_select.shape[1] == 1 :
203 | input_mask = width_select
204 | else :
205 | input_mask = 1
206 | if width_select.shape[2] == 1 :
207 | output_mask = width_select[...,0].unsqueeze(1) #(b,1,c)
208 | else :
209 | output_mask = 1
210 | input = input * input_mask
211 | result = super().forward(input) * output_mask
212 | return result
213 |
214 | if width_select is not None:
215 | weight = self.weight * width_select
216 | b, n, c = input.shape
217 | input = input.transpose(1,2).reshape(1,-1,n)
218 | weight = weight.view(-1,c,1)
219 | if self.bias is None :
220 | bias = self.bias
221 | elif width_select.shape[-1] == 1:
222 | bias = self.bias * width_select.squeeze()
223 | bias = bias.view(-1)
224 | else :
225 | bias = self.bias.unsqueeze(0).expand(b,-1)
226 | bias = bias.reshape(-1)
227 | result = nn.functional.conv1d(input, weight, bias, groups=b)
228 | result = result.view(b,-1,n).transpose(1,2) #(b,n,c)
229 | return result
230 |
231 | if width_specify is not None :
232 | if self.dyna_dim[0] :
233 | self.in_features = width_specify
234 | if self.dyna_dim[1] :
235 | self.out_features = width_specify
236 | weight = self.weight[:self.out_features, :self.in_features]
237 | if self.bias is not None:
238 | bias = self.bias[:self.out_features]
239 | else:
240 | bias = self.bias
241 |
242 | return nn.functional.linear(input, weight, bias)
243 |
244 |
245 | class AdaAttention(nn.Module):
246 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., ada_token=False, ada_token_nonstep=False, ada_token_pre_softmax=True, ada_token_detach_attn=True, dyna_data=False, ada_token_threshold=0.6):
247 | super().__init__()
248 | self.count_flops = False
249 | self.t_ratio = 1
250 |
251 | self.num_heads = num_heads
252 | # head_dim = dim // num_heads
253 | self.head_dim = dim // num_heads
254 | self.scale = qk_scale or self.head_dim ** -0.5
255 |
256 | # self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
257 | self.query = DynaLinear(dim, dim, bias=qkv_bias, dyna_dim=[False, True], dyna_data=dyna_data)
258 | self.key = DynaLinear(dim, dim, bias=qkv_bias, dyna_dim=[False, True], dyna_data=dyna_data)
259 | self.value = DynaLinear(dim, dim, bias=qkv_bias, dyna_dim=[False, True], dyna_data=dyna_data)
260 |
261 | self.attn_drop = nn.Dropout(attn_drop)
262 | self.proj = DynaLinear(dim, dim, dyna_dim=[True, True], dyna_data=dyna_data)
263 | self.proj_drop = nn.Dropout(proj_drop)
264 |
265 | if ada_token :
266 | self.token_select = SimpleTokenSelect(dim, pre_softmax=ada_token_pre_softmax, ada_token_detach_attn=ada_token_detach_attn, ada_token_nonstep=ada_token_nonstep, threshold=ada_token_threshold)
267 | else :
268 | self.token_select = None
269 |
270 | def forward_count_flops(self, x, width_select=None, head_select=None, token_select=None, only_head_attn=False):
271 | width_specify=None
272 | B, N, C = x.shape
273 | width_select_qk = width_select
274 | if only_head_attn :
275 | assert head_select is not None
276 | width_select = None
277 |
278 | if self.token_select is not None :
279 | token_active = int(x.shape[1] * self.t_ratio)
280 | else :
281 | token_active = x.shape[1]
282 |
283 | q = self.query(x[:,:token_active], width_select=width_select_qk, width_specify=width_specify).reshape(B, token_active, -1, C//self.num_heads).permute(0,2,1,3)
284 | k = self.key(x[:,:token_active], width_select=width_select_qk, width_specify=width_specify).reshape(B, token_active, -1, C//self.num_heads).permute(0,2,1,3)
285 | v = self.value(x[:,:token_active], width_select=width_select, width_specify=width_specify).reshape(B, token_active, -1, C//self.num_heads).permute(0,2,1,3)
286 |
287 | attn = (q @ k.transpose(-2, -1)) * self.scale
288 | attn_pre_softmax = attn
289 | attn = attn.softmax(dim=-1)
290 |
291 | attn_origin = attn
292 |
293 | if self.token_select is not None :
294 | attn, token_select = self.token_select(x, attn, attn_pre_softmax, token_select=token_select)
295 |
296 | attn = self.attn_drop(attn)
297 | if only_head_attn :
298 | v[:,:attn.shape[1]] = (attn @ v[:,:attn.shape[1]])
299 | x = v.transpose(1, 2).reshape(B, token_active, -1)
300 | else :
301 | x = (attn @ v).transpose(1, 2).reshape(B, token_active, -1)
302 | if width_select is not None :
303 | width_select = width_select.transpose(-1,-2)
304 | x = self.proj(x, width_select, width_specify=width_specify)
305 |
306 | x[:, :token_active] = self.proj_drop(x[:, :token_active])
307 |
308 | return x, attn_origin, token_select
309 |
310 | def forward(self, x, mask=None, value_mask_fill=-1e4, head_mask=None, width_select=None, head_select=None, token_select=None, width_specify=None, token_keep_ratio=None, only_head_attn=False,
311 | random_token_select_ratio=1.0):
312 |
313 | if self.count_flops :
314 | return self.forward_count_flops(x, width_select, head_select, token_select, only_head_attn)
315 | B, N, C = x.shape
316 | if only_head_attn :
317 | assert head_select is not None
318 | width_select = None
319 |
320 | q = self.query(x, width_select=width_select, width_specify=width_specify).reshape(B, N, -1, C//self.num_heads).permute(0,2,1,3)
321 | k = self.key(x, width_select=width_select, width_specify=width_specify).reshape(B, N, -1, C//self.num_heads).permute(0,2,1,3)
322 | v = self.value(x, width_select=width_select, width_specify=width_specify).reshape(B, N, -1, C//self.num_heads).permute(0,2,1,3)
323 |
324 | attn = (q @ k.transpose(-2, -1)) * self.scale
325 | if mask is not None :
326 | mask = mask.view(B, 1, N, 1).expand_as(attn)
327 | attn[~mask] = value_mask_fill
328 | attn_pre_softmax = attn
329 | attn = attn.softmax(dim=-1)
330 | if only_head_attn :
331 | head_select = head_select.view(*head_select.shape, *([1]*(4-head_select.dim())))
332 | eye_mat = attn.new_zeros(attn.shape[-2:])
333 |             eye_mat.fill_diagonal_(1)  # (l,l); broadcasts against attn of shape (b,h,l,l)
334 | attn = attn * head_select + eye_mat * (1 - head_select)
335 |
336 | attn_origin = attn
337 | if head_mask is not None :
338 | attn = attn * head_mask
339 |
340 | if self.token_select is not None :
341 | attn, token_select = self.token_select(x, attn, attn_pre_softmax, token_select=token_select)
342 | if only_head_attn and not self.training :
343 | head_select = head_select.view(*head_select.shape, *([1]*(4-head_select.dim())))
344 | eye_mat = attn.new_zeros(attn.shape[-2:])
345 |                 eye_mat.fill_diagonal_(1)  # (l,l); broadcasts against attn of shape (b,h,l,l)
346 | attn = attn * head_select + eye_mat * (1 - head_select)
347 |
348 | if random_token_select_ratio != 1.0:
349 | # test random baseline with predefined token select ratio
350 | token_select = (torch.rand((B, N), device=x.device) < random_token_select_ratio).float()
351 | token_select[:, 0] = 1.0 # CLS token is always kept
352 | attn_policy = torch.bmm(token_select[:, None, :].transpose(-1,-2), token_select[:, None, :]) #(b,l,l)
353 | attn = nn.functional.normalize(attn * attn_policy.unsqueeze(1), 1, -1)
354 |
355 | attn = self.attn_drop(attn)
356 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
357 | if width_select is not None :
358 | width_select = width_select.transpose(-1,-2)
359 | x = self.proj(x, width_select, width_specify=width_specify)
360 |
361 | if token_select is not None :
362 | x = x * token_select.unsqueeze(-1)
363 | x = self.proj_drop(x)
364 |
365 | return x, attn_origin, token_select
366 |
367 |
368 | class AdaMlp(nn.Module):
369 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., dyna_data=False):
370 | super().__init__()
371 | out_features = out_features or in_features
372 | hidden_features = hidden_features or in_features
373 | self.act = act_layer()
374 | self.fc1 = DynaLinear(in_features, hidden_features, dyna_dim=[True, False], dyna_data=dyna_data)
375 | self.fc2 = DynaLinear(hidden_features, out_features, dyna_dim=[False, False], dyna_data=dyna_data)
376 | self.drop = nn.Dropout(drop)
377 |
378 | def forward(self, x, mask=None, width_select=None, width_specify=None):
379 | if mask is not None :
380 | assert mask.shape[:2] == x.shape[:2]
381 | if mask.dim() == 2:
382 | mask = mask.unsqueeze(-1)
383 | if mask.dtype != x.dtype :
384 | mask = mask.type_as(x)
385 | else :
386 | mask = x.new_ones(x.shape[:2]).unsqueeze(-1)
387 | x = self.fc1(x, width_select=width_select, width_specify=width_specify)
388 | x = x * mask
389 | x = self.act(x)
390 | x = self.drop(x)
391 | width_select = None
392 | x = self.fc2(x, width_select=width_select, width_specify=width_specify)
393 | x = x * mask
394 | x = self.drop(x)
395 | return x
396 |
397 |
398 | class StepAdaBlock(nn.Module):
399 |
400 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
401 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
402 | ada_head=False, ada_layer=False, is_token_select=False, ada_token_pre_softmax=True, ada_token_detach_attn=True,
403 | ada_token_with_mlp=False, ada_token_detach_attn_at_mlp=True,
404 | dyna_data=False, ada_head_v2=False, only_head_attn=False, head_slowfast=False,
405 | norm_policy=False):
406 | super().__init__()
407 | self.count_flops = False
408 | self.h_ratio, self.t_ratio = 1., 1.
409 | self.l_ratio = [1, 1]
410 | self.is_token_select = is_token_select
411 |
412 | self.norm_policy = None
413 | if norm_policy and (ada_head or ada_layer):
414 | self.norm_policy = norm_layer(dim)
415 | self.norm1 = norm_layer(dim)
416 | self.attn = AdaAttention(
417 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ada_token=is_token_select, ada_token_pre_softmax=ada_token_pre_softmax, ada_token_detach_attn=ada_token_detach_attn,
418 | dyna_data=dyna_data)
419 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
420 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
421 | self.norm2 = norm_layer(dim)
422 | mlp_hidden_dim = int(dim * mlp_ratio)
423 | self.mlp_ratio = mlp_ratio
424 | self.ada_head_v2 = ada_head_v2
425 | self.mlp = AdaMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, dyna_data=dyna_data)
426 | self.ada_token_with_mlp = ada_token_with_mlp
427 | self.ada_token_detach_attn_at_mlp = ada_token_detach_attn_at_mlp
428 | self.only_head_attn = only_head_attn and ada_head
429 | self.head_slowfast = head_slowfast
430 |
431 | self.head_select = None
432 | self.layer_select = None
433 | if ada_head :
434 | self.head_select = BlockHeadSelect(dim, num_heads)
435 | if ada_layer :
436 | self.layer_select = BlockLayerSelect(dim, 2)
437 |
438 | def forward_count_flops(self, x) :
439 | if self.norm_policy is not None :
440 | policy_token = self.norm_policy(x)[:,0]
441 | else :
442 | policy_token = x[:,0]
443 | if self.layer_select is not None :
444 | sub_layer_select, layer_logits = self.layer_select(policy_token)
445 | else :
446 | sub_layer_select, layer_logits = None, None
447 | if self.head_select is not None :
448 | head_select, width_select, head_logits = self.head_select(policy_token)
449 | else :
450 | head_select, width_select, head_logits = None, None, None
451 |
452 | def mask_policy(policy, ratio) :
453 | if policy is not None :
454 | policy = torch.zeros_like(policy)
455 | policy[:, :int(policy.shape[1] * ratio)] = 1
456 | return policy
457 | h_ratio, l_ratio, t_ratio = [1, 1, 1]
458 | if head_select is not None :
459 | h_ratio = self.h_ratio
460 | head_select = mask_policy(head_select, h_ratio)
461 | if width_select is not None :
462 | width_select = mask_policy(width_select, h_ratio)
463 | if sub_layer_select is not None :
464 | l_ratio = self.l_ratio
465 | sub_layer_select[:,0] = l_ratio[0]
466 | sub_layer_select[:,1] = l_ratio[1]
467 | if self.is_token_select :
468 | t_ratio = self.t_ratio
469 |
470 | if width_select is not None :
471 | # TODO
472 | if width_select.sum() == 0 :
473 | return x
474 | width_select_attn = width_select #(b,c,1)
475 | if self.only_head_attn :
476 | assert head_select is not None
477 | width_select_mlp = None
478 | elif self.ada_head_v2 :
479 | bs = width_select.shape[0]
480 | width_select_mlp = width_select.expand(-1,-1,int(self.mlp_ratio)).reshape(bs,-1,1)
481 | else :
482 | width_select_mlp = width_select.transpose(-1,-2)
483 | else :
484 | width_select_attn, width_select_mlp = [None] * 2
485 |
486 | if sub_layer_select is None or sub_layer_select[0,0] :
487 | attn_x, attn_origin, token_select = self.attn(self.norm1(x), width_select=width_select_attn,
488 | head_select=head_select, only_head_attn=self.only_head_attn)
489 | x[:,:int(x.shape[1] * t_ratio),:attn_x.shape[-1]] = x[:,:int(x.shape[1] * t_ratio),:attn_x.shape[-1]] + attn_x
490 | elif sub_layer_select[0,0] == 0:
491 | attn_x = 0
492 | x = x + attn_x
493 |
494 | x = self.norm2(x)
495 | if self.only_head_attn :
496 | mlp_x = x
497 | else :
498 | mlp_x = x
499 |
500 | if sub_layer_select is not None and sub_layer_select[0,1] == 0 :
501 | x = x + 0
502 | else :
503 | if self.ada_token_with_mlp :
504 | token_active = int(x.shape[1] * t_ratio)
505 | else :
506 | token_active = x.shape[1]
507 |
508 | mlp_x = self.mlp(mlp_x[:,:token_active], width_select=width_select_mlp)
509 | x[:,:token_active] = x[:,:token_active] + mlp_x
510 |
511 | return x
512 |
513 | def forward(self, x, mask=None, head_mask=None, width_specify=None, # only_head_attn=False, head_slowfast=False,
514 | random_token_select_ratio=1.0):
515 | """
516 | width_select : (b,c,1)
517 | """
518 | if self.count_flops :
519 | return self.forward_count_flops(x)
520 |
521 | if self.norm_policy is not None :
522 | policy_token = self.norm_policy(x)[:,0]
523 | else :
524 | policy_token = x[:,0]
525 | if self.layer_select is not None :
526 | sub_layer_select, layer_logits = self.layer_select(policy_token)
527 | else :
528 | sub_layer_select, layer_logits = None, None
529 | if self.head_select is not None :
530 | head_select, width_select, head_logits = self.head_select(policy_token)
531 | else :
532 | head_select, width_select, head_logits = None, None, None
533 |
534 | # start
535 | if self.only_head_attn :
536 | assert head_select is not None
537 | width_select = None
538 | if width_select is not None :
539 | # TODO
540 | width_select_attn = width_select #(b,c,1)
541 | if self.ada_head_v2 :
542 | bs = width_select.shape[0]
543 | width_select_mlp = width_select.expand(-1,-1,int(self.mlp_ratio)).reshape(bs,-1,1)
544 | else :
545 | width_select_mlp = width_select.transpose(-1,-2)
546 | else :
547 | width_select_attn, width_select_mlp = [None] * 2
548 |
549 | attn_x, attn_origin, token_select = self.attn(self.norm1(x), mask=mask, head_mask=head_mask, width_select=width_select_attn,
550 | width_specify=width_specify, head_select=head_select, only_head_attn=self.only_head_attn,
551 | random_token_select_ratio=random_token_select_ratio)
552 |
553 | if sub_layer_select is None :
554 | x = x + self.drop_path(attn_x)
555 | mlp_x = self.mlp(self.norm2(x), width_select=width_select_mlp, width_specify=width_specify)
556 | if self.ada_token_with_mlp and token_select is not None :
557 | t_select = token_select.unsqueeze(-1)
558 | if self.ada_token_detach_attn_at_mlp:
559 | t_select = t_select.detach()
560 | mlp_x = t_select * mlp_x
561 | x = x + self.drop_path(mlp_x)
562 | else :
563 | x = x + sub_layer_select[:,0][:,None,None] * attn_x
564 | mlp_x = self.mlp(self.norm2(x), width_select=width_select_mlp, width_specify=width_specify)
565 | if self.ada_token_with_mlp and token_select is not None :
566 | t_select = token_select.unsqueeze(-1)
567 | if self.ada_token_detach_attn_at_mlp:
568 | t_select = t_select.detach()
569 | mlp_x = t_select * mlp_x
570 | x = x + sub_layer_select[:,1][:,None,None] * mlp_x
571 | return x, attn_origin, head_select, sub_layer_select, token_select, head_logits, layer_logits
572 |
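573 | 
574 | # Minimal smoke-test sketch added for clarity (not part of the original file).
575 | # The sizes below (dim=384, 6 heads, 197 tokens) are assumptions matching deit-small.
576 | if __name__ == '__main__':
577 |     blk = StepAdaBlock(dim=384, num_heads=6, mlp_ratio=4.,
578 |                        ada_head=True, ada_layer=True, is_token_select=True)
579 |     blk.eval()
580 |     x = torch.randn(2, 197, 384)
581 |     with torch.no_grad():
582 |         out, attn, head_sel, layer_sel, token_sel, head_logits, layer_logits = blk(x)
583 |     # out: (2, 197, 384), head_sel: (2, 6, 1), layer_sel: (2, 2), token_sel: (2, 197)
584 |     print(out.shape, head_sel.shape, layer_sel.shape, token_sel.shape)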
--------------------------------------------------------------------------------
/models/deit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | import torch
4 | import torch.nn as nn
5 | from functools import partial
6 |
7 | from timm.models.vision_transformer import VisionTransformer, _cfg
8 | from timm.models.registry import register_model
9 | from timm.models.layers import trunc_normal_
10 |
11 |
12 | __all__ = [
13 | 'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
14 | 'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
15 | 'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
16 | 'deit_base_distilled_patch16_384',
17 | ]
18 |
19 |
20 | class DistilledVisionTransformer(VisionTransformer):
21 | def __init__(self, *args, **kwargs):
22 | super().__init__(*args, **kwargs)
23 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
24 | num_patches = self.patch_embed.num_patches
25 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
26 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
27 |
28 | trunc_normal_(self.dist_token, std=.02)
29 | trunc_normal_(self.pos_embed, std=.02)
30 | self.head_dist.apply(self._init_weights)
31 |
32 | def forward_features(self, x):
33 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
34 | # with slight modifications to add the dist_token
35 | B = x.shape[0]
36 | x = self.patch_embed(x)
37 |
38 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
39 | dist_token = self.dist_token.expand(B, -1, -1)
40 | x = torch.cat((cls_tokens, dist_token, x), dim=1)
41 |
42 | x = x + self.pos_embed
43 | x = self.pos_drop(x)
44 |
45 | for blk in self.blocks:
46 | x = blk(x)
47 |
48 | x = self.norm(x)
49 | return x[:, 0], x[:, 1]
50 |
51 | def forward(self, x):
52 | x, x_dist = self.forward_features(x)
53 | x = self.head(x)
54 | x_dist = self.head_dist(x_dist)
55 | if self.training:
56 | return x, x_dist
57 | else:
58 | # during inference, return the average of both classifier predictions
59 | return (x + x_dist) / 2
60 |
61 |
62 | @register_model
63 | def deit_tiny_patch16_224(pretrained=False, **kwargs):
64 | model = VisionTransformer(
65 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
66 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
67 | model.default_cfg = _cfg()
68 | if pretrained:
69 | checkpoint = torch.hub.load_state_dict_from_url(
70 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
71 | map_location="cpu", check_hash=True
72 | )
73 | model.load_state_dict(checkpoint["model"])
74 | return model
75 |
76 |
77 | @register_model
78 | def deit_small_patch16_224(pretrained=False, **kwargs):
79 | model = VisionTransformer(
80 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
81 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
82 | model.default_cfg = _cfg()
83 | if pretrained:
84 | checkpoint = torch.hub.load_state_dict_from_url(
85 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
86 | map_location="cpu", check_hash=True
87 | )
88 | model.load_state_dict(checkpoint["model"])
89 | return model
90 |
91 |
92 | @register_model
93 | def deit_base_patch16_224(pretrained=False, **kwargs):
94 | model = VisionTransformer(
95 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
96 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
97 | model.default_cfg = _cfg()
98 | if pretrained:
99 | checkpoint = torch.hub.load_state_dict_from_url(
100 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
101 | map_location="cpu", check_hash=True
102 | )
103 | model.load_state_dict(checkpoint["model"])
104 | return model
105 |
106 |
107 | @register_model
108 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
109 | model = DistilledVisionTransformer(
110 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
111 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
112 | model.default_cfg = _cfg()
113 | if pretrained:
114 | checkpoint = torch.hub.load_state_dict_from_url(
115 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
116 | map_location="cpu", check_hash=True
117 | )
118 | model.load_state_dict(checkpoint["model"])
119 | return model
120 |
121 |
122 | @register_model
123 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
124 | model = DistilledVisionTransformer(
125 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
126 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
127 | model.default_cfg = _cfg()
128 | if pretrained:
129 | checkpoint = torch.hub.load_state_dict_from_url(
130 | url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
131 | map_location="cpu", check_hash=True
132 | )
133 | model.load_state_dict(checkpoint["model"])
134 | return model
135 |
136 |
137 | @register_model
138 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
139 | model = DistilledVisionTransformer(
140 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
141 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
142 | model.default_cfg = _cfg()
143 | if pretrained:
144 | checkpoint = torch.hub.load_state_dict_from_url(
145 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
146 | map_location="cpu", check_hash=True
147 | )
148 | model.load_state_dict(checkpoint["model"])
149 | return model
150 |
151 |
152 | @register_model
153 | def deit_base_patch16_384(pretrained=False, **kwargs):
154 | model = VisionTransformer(
155 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
156 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
157 | model.default_cfg = _cfg()
158 | if pretrained:
159 | checkpoint = torch.hub.load_state_dict_from_url(
160 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
161 | map_location="cpu", check_hash=True
162 | )
163 | model.load_state_dict(checkpoint["model"])
164 | return model
165 |
166 |
167 | @register_model
168 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
169 | model = DistilledVisionTransformer(
170 | img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
171 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
172 | model.default_cfg = _cfg()
173 | if pretrained:
174 | checkpoint = torch.hub.load_state_dict_from_url(
175 | url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
176 | map_location="cpu", check_hash=True
177 | )
178 | model.load_state_dict(checkpoint["model"])
179 | return model
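180 | 
181 | 
182 | # Brief sketch added for clarity (not part of the original file): the distilled
183 | # variants return (cls_logits, dist_logits) in training mode and the average of
184 | # the two heads at inference, as implemented in DistilledVisionTransformer above.
185 | if __name__ == '__main__':
186 |     model = deit_tiny_distilled_patch16_224(pretrained=False)
187 |     x = torch.randn(1, 3, 224, 224)
188 |     cls_out, dist_out = model.train()(x)   # two heads while training
189 |     with torch.no_grad():
190 |         avg_out = model.eval()(x)          # averaged logits at inference
191 |     print(cls_out.shape, dist_out.shape, avg_out.shape)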
--------------------------------------------------------------------------------
/models/losses.py:
--------------------------------------------------------------------------------
1 | from numpy.lib.arraysetops import isin
2 | from timm import loss
3 | from timm.data.transforms_factory import transforms_imagenet_train
4 | import torch
5 | from torch.functional import Tensor
6 | import torch.nn as nn
7 |
8 | def binaray_entropy(prob, eps=1e-7):
9 | neg_entro = prob * prob.clamp(min=eps).log() + (1-prob) * (1-prob).clamp(min=eps).log()
10 | return - neg_entro
11 |
12 | class BgLoss(nn.Module):
13 | def __init__(self, base_criterion):
14 | super().__init__()
15 | self.base_criterion = base_criterion
16 |
17 | def forward(self, outputs, y) :
18 | assert isinstance(outputs, tuple) and len(outputs) == 2, 'err {} {}'.format(type(outputs), len(outputs))
19 | cls_pred, bg_pred = outputs
20 | if y.dim() == 2 :
21 | bsize, c = y.shape
22 | y_min = y.min(-1, keepdim=True)[0]
23 | y = torch.cat([y_min, y], dim=1)
24 | y_bg = torch.cat([y.new_ones(bsize,1), y.new_ones(bsize,c) * y_min], dim=1)
25 | y = nn.functional.normalize(y, 1, -1)
26 | y_bg = nn.functional.normalize(y_bg, 1, -1)
27 | else :
28 | y = y+1
29 | y_bg = y.new_zeros(y.shape)
30 | base_loss = self.base_criterion(cls_pred, y)
31 | bg_loss = self.base_criterion(bg_pred, y_bg)
32 |
33 | return (base_loss + bg_loss).mean()
34 |
35 |
36 | class AdaHeadLoss(nn.Module):
37 | def __init__(self, base_criterion, target_ratio=0.5, head_loss_ratio=2., diverse_ratio=0.1):
38 | super().__init__()
39 | self.base_criterion = base_criterion
40 | self.target_ratio = target_ratio
41 | self.head_loss_ratio = head_loss_ratio
42 | self.diverse_ratio = diverse_ratio
43 |
44 | def forward(self, outputs, y):
45 | '''
46 | head_select: (b, num_layers, num_head)
47 | '''
48 | assert len(outputs) >= 2
49 | x, head_select = outputs[:2]
50 | base_loss = self.base_criterion(x, y)
51 | head_mean = head_select.mean()
52 | flops_loss = (head_mean - self.target_ratio).abs().mean()
53 |
54 | head_mean = head_select.mean(0) # (num_layers, num_head)
55 | diverse_loss = (head_mean - self.target_ratio).abs().mean()
56 |
57 | head_loss = flops_loss + self.diverse_ratio * diverse_loss
58 |
59 | loss = base_loss + self.head_loss_ratio * head_loss
60 |
61 | return loss, dict(base_loss=base_loss, head_loss=head_loss)
62 |
63 |
64 | class AdaLoss(nn.Module):
65 | def __init__(self, base_criterion, head_target_ratio=0.5, layer_target_ratio=0.5, head_loss_ratio=2.,layer_loss_ratio=2., head_diverse_ratio=0.1, layer_diverse_ratio=0.1,
66 | head_entropy_weight=0.1, layer_entropy_weight=0.1,
67 | head_minimal_weight=0., head_minimal=0.,
68 | layer_minimal_weight=0., layer_minimal=0.,
69 | token_target_ratio=0.5, token_loss_ratio=2., token_minimal=0.1, token_minimal_weight=1.):
70 | super().__init__()
71 | self.base_criterion = base_criterion
72 | self.head_target_ratio = head_target_ratio
73 | self.layer_target_ratio = layer_target_ratio
74 |
75 | self.head_loss_ratio = head_loss_ratio
76 | self.layer_loss_ratio = layer_loss_ratio
77 |
78 | self.head_diverse_ratio = head_diverse_ratio
79 | self.layer_diverse_ratio = layer_diverse_ratio
80 |
81 | self.head_entropy_weight = head_entropy_weight
82 | self.layer_entropy_weight = layer_entropy_weight
83 |
84 | self.head_minimal_weight = head_minimal_weight
85 | self.head_minimal = head_minimal
86 |
87 | self.layer_minimal_weight = layer_minimal_weight
88 | self.layer_minimal = layer_minimal
89 |
90 | self.token_target_ratio = token_target_ratio
91 | self.token_loss_ratio = token_loss_ratio
92 | self.token_minimal = token_minimal
93 | self.token_minimal_weight = token_minimal_weight
94 |
95 | def forward(self, outputs, y):
96 | '''
97 |         head_select: (b, num_layers, num_heads); layer_select: (b, num_layers, 2); token_select: (b, num_layers, num_tokens)
98 | '''
99 |         assert len(outputs) >= 4  # x, head_select, layer_select, token_select, ..., logits_set
100 | x, head_select, layer_select, token_select = outputs[:4]
101 | logits_set = outputs[-1]
102 |
103 | base_loss = self.base_criterion(x, y)
104 | layer_loss = self._get_layer_loss(x, layer_select, logits_set)
105 | head_loss = self._get_head_loss(x, head_select, logits_set, layer_select)
106 | token_loss = self._get_token_loss(x, token_select)
107 |
108 | loss = base_loss + self.head_loss_ratio * head_loss + self.layer_loss_ratio * layer_loss + self.token_loss_ratio * token_loss
109 |
110 | return loss, dict(base_loss=base_loss, head_loss=head_loss, layer_loss=layer_loss, token_loss=token_loss)
111 |
112 | def _get_head_loss(self, x, head_select, logits_set, layer_select):
113 | eps = 1e-6
114 | if head_select is not None :
115 | if layer_select is not None :
116 | block_select = layer_select.sum(-1, keepdim=True)
117 | head_select_mask = (block_select > 0).type_as(block_select)
118 | head_select_mask = head_select_mask.expand(-1,-1,head_select.shape[-1])
119 | assert head_select.shape == head_select_mask.shape
120 | else :
121 | head_select_mask = head_select.new_ones(head_select.shape)
122 | head_mean = (head_select * head_select_mask).sum() / (head_select_mask.sum() + eps)
123 | head_flops_loss = (head_mean - self.head_target_ratio).abs().mean()
124 |
125 | if self.head_diverse_ratio > 0 :
126 | # head_mean = head_select.mean(0) # (num_layers, num_head)
127 | head_mean = (head_select * head_select_mask).sum(0) / (head_select_mask.sum(0) + eps)
128 | head_diverse_loss = (head_mean - self.head_target_ratio).abs().mean()
129 | else :
130 | head_diverse_loss = 0
131 |
132 | if self.head_minimal_weight > 0 :
133 | # head_per_layer = head_select.sum(-1) #(b, num_layers)
134 | # head_minimal_loss = (1 - head_per_layer).clamp(min=0.).sum(-1).mean()
135 | # head_mean = head_select.mean(0) # (num_layers, num_head)
136 | head_mean = (head_select * head_select_mask).sum(0) / (head_select_mask.sum(0) + eps)
137 | head_minimal_loss = (self.head_minimal - head_mean).clamp(min=0.).sum()
138 | else :
139 | head_minimal_loss = 0
140 |
141 | if self.head_entropy_weight > 0 :
142 | head_select_logits = logits_set['head_select_logits']
143 | head_entropy = binaray_entropy(head_select_logits.sigmoid()).mean()
144 | else :
145 | head_entropy = 0
146 |
147 | head_loss = head_flops_loss + self.head_diverse_ratio * head_diverse_loss - self.head_entropy_weight * head_entropy \
148 | + self.head_minimal_weight * head_minimal_loss
149 | else :
150 | head_loss = x.new_zeros(1).mean()
151 |
152 | return head_loss
153 |
154 | def _get_layer_loss(self, x, layer_select, logits_set):
155 | if layer_select is not None :
156 | layer_mean = layer_select.mean()
157 | layer_flops_loss = (layer_mean - self.layer_target_ratio).abs().mean()
158 |
159 | if self.layer_diverse_ratio > 0 :
160 | layer_mean = layer_select.mean((0,-1))
161 | layer_diverse_loss = (layer_mean - self.layer_target_ratio).abs().mean()
162 | else :
163 | layer_diverse_loss = 0
164 |
165 | if self.layer_entropy_weight > 0 :
166 | layer_select_logits = logits_set['layer_select_logits']
167 | layer_entropy = binaray_entropy(layer_select_logits.sigmoid()).mean()
168 | else :
169 | layer_entropy = 0
170 |
171 | if self.layer_minimal_weight > 0 :
172 | layer_mean = layer_select.mean(0) #(num_layers, 2)
173 | layer_minimal_loss = (self.layer_minimal - layer_mean).clamp(min=0.).sum()
174 | else :
175 | layer_minimal_loss = 0
176 |
177 | layer_loss = layer_flops_loss + self.layer_diverse_ratio * layer_diverse_loss - self.layer_entropy_weight * layer_entropy \
178 | + self.layer_minimal_weight * layer_minimal_loss
179 | else :
180 | layer_loss = x.new_zeros(1).mean()
181 |
182 | return layer_loss
183 |
184 | def _get_token_loss(self, x, token_select):
185 | """
186 |         token_select: tensor of shape (b, num_layers, num_tokens)
187 |
188 | """
189 | if token_select is not None :
190 | token_mean = token_select.mean()
191 | # token_flops_loss = (token_mean - self.token_target_ratio).abs().mean()
192 | # token_flops_loss = (token_mean - self.token_target_ratio).clamp(min=0.).mean()
193 | token_flops_loss = ((token_mean - self.token_target_ratio)**2).mean()
194 |
195 | if self.token_minimal_weight > 0 :
196 | token_mean = token_select.mean(-1)
197 | token_minimal_loss = (self.token_minimal - token_mean).clamp(min=0.).sum()
198 | else :
199 | token_minimal_loss = 0
200 |
201 | token_loss = token_flops_loss + self.token_minimal_weight * token_minimal_loss
202 | else :
203 | token_loss = x.new_zeros(1).mean()
204 |
205 | return token_loss
206 |
207 | def soft_cross_entropy(predicts, targets):
208 | student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
209 | targets_prob = torch.nn.functional.softmax(targets, dim=-1)
210 | return (- targets_prob * student_likelihood).sum(-1).mean()
211 |
212 | # TODO : hard or soft distill loss
213 | class TeacherLoss(nn.Module):
214 | def __init__(self, teacher_model, base_criterion, kd_ratio=1., tau=5., attn_ratio=.5, hidden_ratio=.1, pred_ratio=.5, keep_layers=1):
215 | super().__init__()
216 | self.kd_ratio = kd_ratio
217 | print('self tau', tau)
218 | self.tau = tau
219 | self.base_criterion = base_criterion
220 | self.teacher_model = teacher_model
221 | self.mse_loss = nn.MSELoss(reduction='none')
222 | self.l1_loss = nn.L1Loss(reduction='none')
223 | self.attn_ratio = attn_ratio
224 | self.hidden_ratio = hidden_ratio
225 | # self.layer_ratio = layer_ratio
226 | self.pred_ratio = pred_ratio
227 | self.teacher_model.eval()
228 | self.keep_layers = keep_layers
229 |
230 | def forward(self, x, outputs, y):
231 |         assert len(outputs) >= 6  # logits, head/layer/token selects, attn_list, hidden_list
232 | logits, head_select, layer_select, token_select, attn_list, hidden_list = outputs[:6]
233 | base_loss, meta_loss = self.base_criterion(outputs, y)
234 |
235 | with torch.no_grad():
236 | logits_teacher, attn_list_teacher, hidden_list_teacher = self.teacher_model(x, ret_attn_list=True)
237 |
238 | attn_loss = x.new_zeros(1).mean()
239 | hidden_loss = x.new_zeros(1).mean()
240 | if head_select is not None :
241 | head_select_start = len(attn_list) - head_select.shape[1]
242 | for i, (attn_s, attn_t, hidden_s, hidden_t) in enumerate(zip(attn_list, attn_list_teacher, hidden_list, hidden_list_teacher)) :
243 | assert attn_s.dim() == 4 and attn_t.dim() ==4
244 | if i >= self.keep_layers and layer_select is not None :
245 | this_select = layer_select[:,i-self.keep_layers].detach() #(b,2)
246 | this_select = (this_select > 0.5).float().unsqueeze(-1)
247 | attn_select = this_select[:,0]
248 | hidden_select = this_select[:,1]
249 | else :
250 | this_select = 1.
251 | attn_select = 1.
252 | hidden_select = 1.
253 |
254 | if head_select is not None and i >= head_select_start :
255 | attn_mask = head_select.detach()[:,i-head_select_start] #(b,head,1)
256 | if attn_mask.dim() == 2 :
257 | attn_mask = attn_mask.unsqueeze(-1)
258 | else :
259 | attn_mask = 1
260 |
261 | attn_s = attn_s[:,:,0] * attn_mask #(b,head,len)
262 | # TODO : normalize teacher attn
263 | attn_t = attn_t[:,:,0] * attn_mask
264 | hidden_s = hidden_s[:,0] #(b, c)
265 | hidden_t = hidden_t[:,0]
266 | # attn_s = torch.where(attn_s <=1e-2, attn_s.new_zeros(attn_s.shape), attn_s)
267 | # attn_t = torch.where(attn_t <=1e-2, attn_t.new_zeros(attn_t.shape), attn_t)
268 | # TODO : how to design attn loss
269 | if self.attn_ratio > 0 :
270 | this_attn_loss = self.l1_loss(attn_s, attn_t).sum(-1)
271 | this_attn_loss = (attn_select * this_attn_loss).mean()
272 | attn_loss +=this_attn_loss
273 | if self.hidden_ratio > 0 :
274 | this_hidden_loss = self.mse_loss(hidden_s, hidden_t)
275 | this_hidden_loss = (hidden_select * this_hidden_loss).mean()
276 | hidden_loss += this_hidden_loss
277 |
278 | if self.pred_ratio > 0 :
279 | T = self.tau
280 | pred_loss = soft_cross_entropy(logits/T, logits_teacher/T)
281 | else :
282 | pred_loss = 0
283 |
284 | loss = base_loss + self.attn_ratio*attn_loss+ self.hidden_ratio*hidden_loss + self.pred_ratio*pred_loss
285 | meta_loss.update(attn_loss=attn_loss, hidden_loss=hidden_loss, pred_loss=pred_loss)
286 | return loss, meta_loss
--------------------------------------------------------------------------------
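A minimal sketch of how AdaLoss consumes the tuple produced by the adaptive models, using random selection masks and hypothetical sizes (batch 4, 12 adaptive blocks, 6 heads, 197 tokens, 1000 classes); during training the masks and select logits come from the ada_* models rather than torch.rand:

import torch
import torch.nn as nn
from models.losses import AdaLoss

b, num_layers, num_heads, num_tokens, num_classes = 4, 12, 6, 197, 1000
criterion = AdaLoss(base_criterion=nn.CrossEntropyLoss())

logits = torch.randn(b, num_classes)
head_select = torch.rand(b, num_layers, num_heads)     # keep-probabilities per head
layer_select = torch.rand(b, num_layers, 2)            # keep-probabilities for (attention, MLP)
token_select = torch.rand(b, num_layers, num_tokens)   # keep-probabilities per token
logits_set = {
    'head_select_logits': torch.randn(b, num_layers, num_heads),
    'layer_select_logits': torch.randn(b, num_layers, 2),
}
outputs = (logits, head_select, layer_select, token_select, logits_set)

targets = torch.randint(0, num_classes, (b,))
loss, meta = criterion(outputs, targets)               # meta holds the base/head/layer/token terms
print(loss.item(), {k: float(v) for k, v in meta.items()})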
/models/token_performer.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from T2T-ViT (https://github.com/yitu-opensource/T2T-ViT).
3 | Uses the Performer as the T2T Transformer.
4 | """
5 | import math
6 | import torch
7 | import torch.nn as nn
8 |
9 | class Token_performer(nn.Module):
10 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1):
11 | super().__init__()
12 |         self.emb = in_dim * head_cnt  # head_cnt is 1 here, so emb is effectively just in_dim
13 | self.kqv = nn.Linear(dim, 3 * self.emb)
14 | self.dp = nn.Dropout(dp1)
15 | self.proj = nn.Linear(self.emb, self.emb)
16 | self.head_cnt = head_cnt
17 | self.norm1 = nn.LayerNorm(dim)
18 | self.norm2 = nn.LayerNorm(self.emb)
19 |         self.epsilon = 1e-8  # for numerical stability in the division below
20 |
21 | self.mlp = nn.Sequential(
22 | nn.Linear(self.emb, 1 * self.emb),
23 | nn.GELU(),
24 | nn.Linear(1 * self.emb, self.emb),
25 | nn.Dropout(dp2),
26 | )
27 |
28 | self.m = int(self.emb * kernel_ratio)
29 | self.w = torch.randn(self.m, self.emb)
30 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False)
31 |
32 | def prm_exp(self, x):
33 |         # part of this function is borrowed from https://github.com/lucidrains/performer-pytorch
34 | # and Simo Ryu (https://github.com/cloneofsimo)
35 | # ==== positive random features for gaussian kernels ====
36 | # x = (B, T, hs)
37 | # w = (m, hs)
38 | # return : x : B, T, m
39 |         # SM(x, y) = E_w[exp(w^T x - |x|^2/2) exp(w^T y - |y|^2/2)]
40 |         # therefore return exp(w^T x - |x|^2/2) / sqrt(m)
41 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2
42 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w)
43 |
44 | return torch.exp(wtx - xd) / math.sqrt(self.m)
45 |
46 | def single_attn(self, x):
47 | k, q, v = torch.split(self.kqv(x), self.emb, dim=-1)
48 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m)
49 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1)
50 | kptv = torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m)
51 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag
52 | # skip connection
53 | y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection
54 |
55 | return y
56 |
57 | def forward(self, x):
58 | x = self.single_attn(self.norm1(x))
59 | x = x + self.mlp(self.norm2(x))
60 | return x
61 |
62 |
--------------------------------------------------------------------------------
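A shape sketch for Token_performer with illustrative sizes (the T2T stage decides the real ones); the Performer kernel trick keeps the attention cost linear in the token count:

import torch
from models.token_performer import Token_performer

block = Token_performer(dim=64, in_dim=64)
tokens = torch.randn(2, 3136, 64)   # (batch, num_tokens, dim), e.g. a 56x56 soft-split grid
out = block(tokens)
print(out.shape)                    # torch.Size([2, 3136, 64])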
/models/token_transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from T2T-ViT (https://github.com/yitu-opensource/T2T-ViT).
3 | Uses the standard Transformer as the T2T Transformer.
4 | """
5 | import torch.nn as nn
6 | from timm.models.layers import DropPath
7 | from .transformer_block import Mlp
8 |
9 | class Attention(nn.Module):
10 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
11 | super().__init__()
12 | self.num_heads = num_heads
13 | self.in_dim = in_dim
14 | head_dim = dim // num_heads
15 | self.scale = qk_scale or head_dim ** -0.5
16 |
17 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias)
18 | self.attn_drop = nn.Dropout(attn_drop)
19 | self.proj = nn.Linear(in_dim, in_dim)
20 | self.proj_drop = nn.Dropout(proj_drop)
21 |
22 | def forward(self, x):
23 | B, N, C = x.shape
24 |
25 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4)
26 | q, k, v = qkv[0], qkv[1], qkv[2]
27 |
28 | attn = (q * self.scale) @ k.transpose(-2, -1)
29 | attn = attn.softmax(dim=-1)
30 | attn = self.attn_drop(attn)
31 |
32 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim)
33 | x = self.proj(x)
34 | x = self.proj_drop(x)
35 |
36 | # skip connection
37 |         x = v.squeeze(1) + x  # the input x has a different channel size than the output, so use v for the skip connection
38 |
39 | return x
40 |
41 | class Token_transformer(nn.Module):
42 |
43 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
44 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
45 | super().__init__()
46 | self.norm1 = norm_layer(dim)
47 | self.attn = Attention(
48 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
49 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
50 | self.norm2 = norm_layer(in_dim)
51 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop)
52 |
53 | def forward(self, x):
54 | x = self.attn(self.norm1(x))
55 | x = x + self.drop_path(self.mlp(self.norm2(x)))
56 | return x
57 |
--------------------------------------------------------------------------------
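A similar sketch for Token_transformer with made-up sizes: 147 = 3*7*7 stands in for an unfolded patch, and the block projects it down to in_dim channels:

import torch
from models.token_transformer import Token_transformer

block = Token_transformer(dim=3 * 7 * 7, in_dim=64, num_heads=1)
patches = torch.randn(2, 196, 3 * 7 * 7)   # (batch, num_tokens, unfolded patch dim)
out = block(patches)
print(out.shape)                           # torch.Size([2, 196, 64])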
/models/transformer_block.py:
--------------------------------------------------------------------------------
1 | """
2 | Borrowed from timm (https://github.com/rwightman/pytorch-image-models).
3 | """
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from timm.models.layers import DropPath
8 |
9 | class Mlp(nn.Module):
10 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
11 | super().__init__()
12 | out_features = out_features or in_features
13 | hidden_features = hidden_features or in_features
14 | self.fc1 = nn.Linear(in_features, hidden_features)
15 | self.act = act_layer()
16 | self.fc2 = nn.Linear(hidden_features, out_features)
17 | self.drop = nn.Dropout(drop)
18 |
19 | def forward(self, x):
20 | x = self.fc1(x)
21 | x = self.act(x)
22 | x = self.drop(x)
23 | x = self.fc2(x)
24 | x = self.drop(x)
25 | return x
26 |
27 | class Attention(nn.Module):
28 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
29 | super().__init__()
30 | self.num_heads = num_heads
31 | head_dim = dim // num_heads
32 |
33 | self.scale = qk_scale or head_dim ** -0.5
34 |
35 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
36 | self.attn_drop = nn.Dropout(attn_drop)
37 | self.proj = nn.Linear(dim, dim)
38 | self.proj_drop = nn.Dropout(proj_drop)
39 |
40 | def forward(self, x, ret_attn=False):
41 | B, N, C = x.shape
42 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
43 | q, k, v = qkv[0], qkv[1], qkv[2]
44 |
45 | attn = (q @ k.transpose(-2, -1)) * self.scale
46 | attn = attn.softmax(dim=-1)
47 | attn = self.attn_drop(attn)
48 |
49 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
50 | x = self.proj(x)
51 | x = self.proj_drop(x)
52 | if ret_attn :
53 | return x, attn
54 | return x
55 |
56 | class Block(nn.Module):
57 |
58 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
59 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
60 | super().__init__()
61 | if norm_layer == nn.GroupNorm :
62 | self.norm1 = norm_layer(num_heads, dim)
63 | self.norm2 = norm_layer(num_heads, dim)
64 | else:
65 | self.norm1 = norm_layer(dim)
66 | self.norm2 = norm_layer(dim)
67 | self.attn = Attention(
68 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
69 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
70 | mlp_hidden_dim = int(dim * mlp_ratio)
71 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
72 |
73 |     def forward(self, x, ret_attn=False):
74 |         # self.attn returns (x, attn) only when ret_attn is True
75 |         out = self.attn(self.norm1(x), ret_attn=ret_attn)
76 |         x_attn, attn = out if ret_attn else (out, None)
77 |         x = x + self.drop_path(x_attn)
78 |         x = x + self.drop_path(self.mlp(self.norm2(x)))
79 |         return (x, attn) if ret_attn else x
80 |
81 |
82 | def get_sinusoid_encoding(n_position, d_hid):
83 | ''' Sinusoid position encoding table '''
84 |
85 | def get_position_angle_vec(position):
86 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
87 |
88 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
89 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
90 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
91 |
92 | return torch.FloatTensor(sinusoid_table).unsqueeze(0)
93 |
--------------------------------------------------------------------------------
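A small sketch combining Block and get_sinusoid_encoding with assumed DeiT-S-like sizes (dim 384, 6 heads, 1 cls + 14x14 patch tokens); ret_attn=True exposes the attention maps used elsewhere for distillation:

import torch
from models.transformer_block import Block, get_sinusoid_encoding

blk = Block(dim=384, num_heads=6)
pos = get_sinusoid_encoding(n_position=197, d_hid=384)   # fixed table, shape (1, 197, 384)
x = torch.randn(2, 197, 384) + pos
out, attn = blk(x, ret_attn=True)
print(out.shape, attn.shape)                             # (2, 197, 384) and (2, 6, 197, 197)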
/saver.py:
--------------------------------------------------------------------------------
1 | """ Checkpoint Saver
2 |
3 | Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
4 |
5 | Hacked together by / Copyright 2020 Ross Wightman
6 | """
7 |
8 | import glob
9 | import operator
10 | import os
11 | import logging
12 |
13 | import torch
14 |
15 | from timm.utils.model import unwrap_model, get_state_dict
16 | from shutil import copyfile
17 |
18 | _logger = logging.getLogger(__name__)
19 |
20 |
21 | class MyCheckpointSaver:
22 | def __init__(
23 | self,
24 | model,
25 | optimizer,
26 | args=None,
27 | model_ema=None,
28 | amp_scaler=None,
29 | checkpoint_prefix='checkpoint',
30 | recovery_prefix='recovery',
31 | checkpoint_dir='',
32 | recovery_dir='',
33 | decreasing=False,
34 | max_history=10,
35 | unwrap_fn=unwrap_model):
36 |
37 | # objects to save state_dicts of
38 | self.model = model
39 | self.optimizer = optimizer
40 | self.args = args
41 | self.model_ema = model_ema
42 | self.amp_scaler = amp_scaler
43 |
44 | # state
45 |         self.checkpoint_files = []  # (filename, metric) tuples ordered from best to worst
46 | self.best_epoch = None
47 | self.best_metric = None
48 | self.curr_recovery_file = ''
49 | self.last_recovery_file = ''
50 |
51 | # config
52 | self.checkpoint_dir = checkpoint_dir
53 | self.recovery_dir = recovery_dir
54 | self.save_prefix = checkpoint_prefix
55 | self.recovery_prefix = recovery_prefix
56 | self.extension = '.pth.tar'
57 | self.decreasing = decreasing # a lower metric is better if True
58 | self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
59 | self.max_history = max_history
60 | self.unwrap_fn = unwrap_fn
61 | assert self.max_history >= 1
62 |
63 | def save_checkpoint(self, epoch, metric=None):
64 | assert epoch >= 0
65 | tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
66 | last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
67 | self._save(tmp_save_path, epoch, metric)
68 | if os.path.exists(last_save_path):
69 | os.remove(last_save_path) # required for Windows support.
70 | os.rename(tmp_save_path, last_save_path)
71 | worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
72 | if (len(self.checkpoint_files) < self.max_history
73 | or metric is None or self.cmp(metric, worst_file[1])):
74 | if len(self.checkpoint_files) >= self.max_history:
75 | self._cleanup_checkpoints(1)
76 | filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
77 | save_path = os.path.join(self.checkpoint_dir, filename)
78 | copyfile(last_save_path, save_path)
79 | self.checkpoint_files.append((save_path, metric))
80 | self.checkpoint_files = sorted(
81 | self.checkpoint_files, key=lambda x: x[1],
82 | reverse=not self.decreasing) # sort in descending order if a lower metric is not better
83 |
84 | checkpoints_str = "Current checkpoints:\n"
85 | for c in self.checkpoint_files:
86 | checkpoints_str += ' {}\n'.format(c)
87 | _logger.info(checkpoints_str)
88 |
89 | if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
90 | self.best_epoch = epoch
91 | self.best_metric = metric
92 | best_save_path = os.path.join(self.checkpoint_dir, 'model_best' + self.extension)
93 | if os.path.exists(best_save_path):
94 | os.remove(best_save_path)
95 | copyfile(last_save_path, best_save_path)
96 |
97 | return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
98 |
99 | def _save(self, save_path, epoch, metric=None):
100 | save_state = {
101 | 'epoch': epoch,
102 | 'arch': type(self.model).__name__.lower(),
103 | 'state_dict': get_state_dict(self.model, self.unwrap_fn),
104 | 'optimizer': self.optimizer.state_dict(),
105 | 'version': 2, # version < 2 increments epoch before save
106 | }
107 | if self.args is not None:
108 | save_state['arch'] = self.args.model
109 | save_state['args'] = self.args
110 | if self.amp_scaler is not None:
111 | save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
112 | if self.model_ema is not None:
113 | save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
114 | if metric is not None:
115 | save_state['metric'] = metric
116 | torch.save(save_state, save_path)
117 |
118 | def _cleanup_checkpoints(self, trim=0):
119 | trim = min(len(self.checkpoint_files), trim)
120 | delete_index = self.max_history - trim
121 | if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
122 | return
123 | to_delete = self.checkpoint_files[delete_index:]
124 | for d in to_delete:
125 | try:
126 | _logger.debug("Cleaning checkpoint: {}".format(d))
127 | os.remove(d[0])
128 | except Exception as e:
129 | _logger.error("Exception '{}' while deleting checkpoint".format(e))
130 | self.checkpoint_files = self.checkpoint_files[:delete_index]
131 |
132 | def save_recovery(self, epoch, batch_idx=0):
133 | assert epoch >= 0
134 | filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
135 | save_path = os.path.join(self.recovery_dir, filename)
136 | self._save(save_path, epoch)
137 | if os.path.exists(self.last_recovery_file):
138 | try:
139 | _logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
140 | os.remove(self.last_recovery_file)
141 | except Exception as e:
142 | _logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
143 | self.last_recovery_file = self.curr_recovery_file
144 | self.curr_recovery_file = save_path
145 |
146 | def find_recovery(self):
147 | recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
148 | files = glob.glob(recovery_path + '*' + self.extension)
149 | files = sorted(files)
150 | if len(files):
151 | return files[0]
152 | else:
153 | return ''
154 |
--------------------------------------------------------------------------------
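An illustrative use of MyCheckpointSaver with a toy model and a made-up accuracy metric (the directory name is arbitrary); it keeps at most max_history ranked checkpoints plus last/model_best copies:

import os
import torch
import torch.nn as nn
from saver import MyCheckpointSaver

model = nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ckpt_dir = 'output/example_ckpt'
os.makedirs(ckpt_dir, exist_ok=True)

saver = MyCheckpointSaver(model, optimizer, checkpoint_dir=ckpt_dir, max_history=3)
for epoch in range(5):
    top1 = 70.0 + epoch                                   # stand-in for the validation metric
    best_metric, best_epoch = saver.save_checkpoint(epoch, metric=top1)
print('best:', best_metric, 'at epoch', best_epoch)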
/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Borrowed from T2T-ViT (https://github.com/yitu-opensource/T2T-ViT).
3 | - load_for_transfer_learning: load pretrained parameters into a model for transfer learning.
4 | - get_mean_and_std: compute the mean and std of a dataset.
5 | - init_params: network parameter initialization.
6 | - progress_bar: progress bar mimicking xlua.progress.
7 | '''
8 | import os
9 | import sys
10 | import time
11 | import math
12 | import logging
13 | import torch
14 | import torch.nn as nn
15 | import torch.nn.functional as F
16 | import torch.nn.init as init
17 | from collections import OrderedDict
18 |
19 | _logger = logging.getLogger(__name__)
20 |
21 | def convert_qkv(origin_dict):
22 | model_dict = OrderedDict()
23 | for k,v in origin_dict.items():
24 | if 'qkv' in k :
25 | #bias
26 | if v.dim() == 1 :
27 | dim = v.shape[-1] // 3
28 | tmp_bias = v
29 | b_q, b_k, b_v = [None] *3 if tmp_bias is None else [tmp_bias[i*dim:(i+1)*dim] for i in range(3)]
30 | print('bias q k v', b_q.shape, b_k.shape, b_v.shape)
31 | model_dict[k.replace('qkv', 'query')] = b_q
32 | model_dict[k.replace('qkv', 'key')] = b_k
33 | model_dict[k.replace('qkv', 'value')] = b_v
34 | else :
35 | dim = v.shape[-1]
36 | tmp_weight = v
37 | w_q, w_k, w_v = [tmp_weight[i*dim:(i+1)*dim] for i in range(3)]
38 | model_dict[k.replace('qkv', 'query')] = w_q
39 | model_dict[k.replace('qkv', 'key')] = w_k
40 | model_dict[k.replace('qkv', 'value')] = w_v
41 | else :
42 | model_dict[k] = v
43 |
44 | return model_dict
45 |
46 | def ada_load_state_dict(checkpoint_path, model, use_qkv=False, strict=True):
47 | if not os.path.isfile(checkpoint_path) :
48 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
49 | raise FileNotFoundError()
50 |
51 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
52 | if 'state_dict_ema' in checkpoint :
53 | checkpoint = checkpoint['state_dict_ema']
54 | elif 'state_dict' in checkpoint :
55 | checkpoint = checkpoint['state_dict']
56 | elif 'model' in checkpoint :
57 | checkpoint = checkpoint['model'] # load deit type model
58 |
59 | if not use_qkv :
60 | checkpoint = convert_qkv(checkpoint)
61 |
62 | new_state_dict = OrderedDict()
63 | for k, v in checkpoint.items():
64 | # strip `module.` prefix
65 | name = k[7:] if k.startswith('module') else k
66 | new_state_dict[name] = v
67 | checkpoint = new_state_dict
68 |
69 | info = model.load_state_dict(checkpoint, strict=strict)
70 | if not strict :
71 | print('state dict load info', info)
72 |
73 |
74 | def deit_load_state_dict(checkpoint_path, model, strict=True):
75 | if not os.path.isfile(checkpoint_path) :
76 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
77 | raise FileNotFoundError()
78 |
79 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
80 | checkpoint = checkpoint['model']
81 |
82 | new_state_dict = OrderedDict()
83 | for k, v in checkpoint.items():
84 | # strip `module.` prefix
85 | name = k[7:] if k.startswith('module') else k
86 | new_state_dict[name] = v
87 | checkpoint = new_state_dict
88 |
89 | info = model.load_state_dict(checkpoint, strict=strict)
90 | if not strict :
91 | print('state dict load info', info)
92 |
93 |
94 | def resize_pos_embed(posemb, posemb_new): # example: 224:(14x14+1)-> 384: (24x24+1)
95 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
96 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
97 | ntok_new = posemb_new.shape[1]
98 |     # keep the cls-token embedding as-is; only the patch-grid part is interpolated
99 |     posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
100 |     ntok_new -= 1
103 | gs_old = int(math.sqrt(len(posemb_grid))) # 14
104 | gs_new = int(math.sqrt(ntok_new)) # 24
105 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
106 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) # [1, 196, dim]->[1, 14, 14, dim]->[1, dim, 14, 14]
107 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bicubic') # [1, dim, 14, 14] -> [1, dim, 24, 24]
108 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) # [1, dim, 24, 24] -> [1, 24*24, dim]
109 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) # [1, 24*24+1, dim]
110 | return posemb
111 |
112 | def load_state_dict(checkpoint_path, model, use_ema=False, num_classes=1000, del_posemb=False):
113 | if checkpoint_path and os.path.isfile(checkpoint_path):
114 | checkpoint = torch.load(checkpoint_path, map_location='cpu')
115 | state_dict_key = 'state_dict'
116 | if isinstance(checkpoint, dict):
117 | if use_ema and 'state_dict_ema' in checkpoint:
118 | state_dict_key = 'state_dict_ema'
119 | if state_dict_key and state_dict_key in checkpoint:
120 | new_state_dict = OrderedDict()
121 | for k, v in checkpoint[state_dict_key].items():
122 | # strip `module.` prefix
123 | name = k[7:] if k.startswith('module') else k
124 | new_state_dict[name] = v
125 | state_dict = new_state_dict
126 | else:
127 | state_dict = checkpoint
128 | _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
129 | if num_classes != 1000:
130 | # completely discard fully connected for all other differences between pretrained and created model
131 | del state_dict['head' + '.weight']
132 | del state_dict['head' + '.bias']
133 |
134 |         if del_posemb:
135 |             del state_dict['pos_embed']
136 |         else:
137 |             old_posemb = state_dict['pos_embed']
138 |             if model.pos_embed.shape != old_posemb.shape:  # resize the position embedding by interpolation
139 |                 new_posemb = resize_pos_embed(old_posemb, model.pos_embed)
140 |                 state_dict['pos_embed'] = new_posemb
141 |
142 | return state_dict
143 | else:
144 | _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
145 | raise FileNotFoundError()
146 |
147 |
148 |
149 | def load_for_transfer_learning(model, checkpoint_path, use_ema=False, strict=True, num_classes=1000):
150 |     state_dict = load_state_dict(checkpoint_path, model, use_ema=use_ema, num_classes=num_classes)
151 | model.load_state_dict(state_dict, strict=strict)
152 |
153 |
154 | def get_mean_and_std(dataset):
155 | '''Compute the mean and std value of dataset.'''
156 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
157 | mean = torch.zeros(3)
158 | std = torch.zeros(3)
159 | print('==> Computing mean and std..')
160 | for inputs, targets in dataloader:
161 | for i in range(3):
162 | mean[i] += inputs[:,i,:,:].mean()
163 | std[i] += inputs[:,i,:,:].std()
164 | mean.div_(len(dataset))
165 | std.div_(len(dataset))
166 | return mean, std
167 |
168 | def init_params(net):
169 |     '''Init layer parameters.'''
170 |     for m in net.modules():
171 |         if isinstance(m, nn.Conv2d):
172 |             init.kaiming_normal_(m.weight, mode='fan_out')
173 |             if m.bias is not None:
174 |                 init.constant_(m.bias, 0)
175 |         elif isinstance(m, nn.BatchNorm2d):
176 |             init.constant_(m.weight, 1)
177 |             init.constant_(m.bias, 0)
178 |         elif isinstance(m, nn.Linear):
179 |             init.normal_(m.weight, std=1e-3)
180 |             if m.bias is not None:
181 |                 init.constant_(m.bias, 0)
182 |
183 |
184 | # _, term_width = os.popen('stty size', 'r').read().split()
185 | # term_width = int(term_width)
186 | term_width = 100 # since we're not using transfer_learning.py or progress bar, just set a constant to make no-terminal job submission happy
187 |
188 | TOTAL_BAR_LENGTH = 65.
189 | last_time = time.time()
190 | begin_time = last_time
191 | def progress_bar(current, total, msg=None):
192 | global last_time, begin_time
193 | if current == 0:
194 | begin_time = time.time() # Reset for new bar.
195 |
196 | cur_len = int(TOTAL_BAR_LENGTH*current/total)
197 | rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1
198 |
199 | sys.stdout.write(' [')
200 | for i in range(cur_len):
201 | sys.stdout.write('=')
202 | sys.stdout.write('>')
203 | for i in range(rest_len):
204 | sys.stdout.write('.')
205 | sys.stdout.write(']')
206 |
207 | cur_time = time.time()
208 | step_time = cur_time - last_time
209 | last_time = cur_time
210 | tot_time = cur_time - begin_time
211 |
212 | L = []
213 | L.append(' Step: %s' % format_time(step_time))
214 | L.append(' | Tot: %s' % format_time(tot_time))
215 | if msg:
216 | L.append(' | ' + msg)
217 |
218 | msg = ''.join(L)
219 | sys.stdout.write(msg)
220 | for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
221 | sys.stdout.write(' ')
222 |
223 | # Go back to the center of the bar.
224 | for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
225 | sys.stdout.write('\b')
226 | sys.stdout.write(' %d/%d ' % (current+1, total))
227 |
228 | if current < total-1:
229 | sys.stdout.write('\r')
230 | else:
231 | sys.stdout.write('\n')
232 | sys.stdout.flush()
233 |
234 | def format_time(seconds):
235 | days = int(seconds / 3600/24)
236 | seconds = seconds - days*3600*24
237 | hours = int(seconds / 3600)
238 | seconds = seconds - hours*3600
239 | minutes = int(seconds / 60)
240 | seconds = seconds - minutes*60
241 | secondsf = int(seconds)
242 | seconds = seconds - secondsf
243 | millis = int(seconds*1000)
244 |
245 | f = ''
246 | i = 1
247 | if days > 0:
248 | f += str(days) + 'D'
249 | i += 1
250 | if hours > 0 and i <= 2:
251 | f += str(hours) + 'h'
252 | i += 1
253 | if minutes > 0 and i <= 2:
254 | f += str(minutes) + 'm'
255 | i += 1
256 | if secondsf > 0 and i <= 2:
257 | f += str(secondsf) + 's'
258 | i += 1
259 | if millis > 0 and i <= 2:
260 | f += str(millis) + 'ms'
261 | i += 1
262 | if f == '':
263 | f = '0ms'
264 | return f
265 |
--------------------------------------------------------------------------------
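Finally, a sketch of resize_pos_embed with assumed DeiT-S-like dimensions, mapping a 224-px position embedding (14x14 grid + cls token) to a 384-px one (24x24 grid + cls token); only the shape of the second argument is used:

import torch
from utils import resize_pos_embed

posemb_224 = torch.randn(1, 14 * 14 + 1, 384)
posemb_384_target = torch.zeros(1, 24 * 24 + 1, 384)
posemb_384 = resize_pos_embed(posemb_224, posemb_384_target)
print(posemb_384.shape)                                   # torch.Size([1, 577, 384])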