├── LICENSE
├── README.md
├── configs
│   ├── deit_b_transmix.yaml
│   └── deit_s_transmix.yaml
├── distributed_train.sh
├── pic1.png
├── pic2.png
├── requirements.txt
├── timm
│   └── models
│       └── vision_transformer.py
├── train.py
└── transmix.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 ByteDance
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TransMix: Attend to Mix for Vision Transformers
2 |
3 | This repository contains the official implementation of the paper [*TransMix: Attend to Mix for Vision Transformers*](https://arxiv.org/abs/2111.09833), CVPR 2022.
4 |
5 | ![](pic1.png)
6 |
7 | ![](pic2.png)
8 |
13 | # Key Feature
14 |
15 | Improve your Vision Transformer (ViT) by ~1% top-1 accuracy on ImageNet with negligible extra computational cost, simply by adding a ```--transmix``` flag.
16 |
17 | # Getting Started
18 |
19 | First, clone the repo:
20 | ```shell
21 | git clone https://github.com/Beckschen/TransMix.git
22 | ```
23 |
24 | Then install the required packages: [PyTorch](https://pytorch.org/) version 1.7.1,
25 | [torchvision](https://pytorch.org/vision/stable/index.html) version 0.8.2,
26 | [timm](https://github.com/rwightman/pytorch-image-models) version 0.5.4,
27 | and ```pyyaml```. To install them all, simply run
28 | ```
29 | pip3 install -r requirements.txt
30 | ```
31 |
32 | Download and extract the [ImageNet](https://imagenet.stanford.edu/) dataset into the ```data``` folder. Assuming you're using
33 | 8 GPUs for training, simply run
34 | ```shell
35 | bash ./distributed_train.sh 8 data/ --config $YOUR_CONFIG_PATH_HERE
36 | ```
37 |
38 | By default, all of our config files enable training with TransMix.
39 | If you want to enable TransMix when training your own model,
40 | add a ```--transmix``` flag to your training script. For example:
41 | ```shell
42 | python3 -m torch.distributed.launch --nproc_per_node=8 train.py data/ --config $YOUR_CONFIG_PATH_HERE --transmix
43 | ```
44 |
45 | Alternatively, you can specify ```transmix: True``` in your ```yaml``` config file, as we did in [deit_s_transmix](configs/deit_s_transmix.yaml).
46 |
47 | To evaluate a model trained with TransMix, please refer to the [timm](https://github.com/rwightman/pytorch-image-models#train-validation-inference-scripts) validation scripts.
48 | Validation accuracy is also reported during training.
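
For a quick sanity check in Python, here is a minimal sketch (paths are hypothetical; note that the ```*_return_attn``` models in this repo return ```(logits, attn)```, so only the logits are kept):
```python
import torch
import timm
from torchvision import datasets, transforms

# load a trained checkpoint (path is hypothetical; point it at your own run)
model = timm.create_model('deit_small_patch16_224_return_attn',
                          checkpoint_path='output/train/model_best.pth.tar')
model.eval().cuda()  # assumes a GPU is available

# approximate ImageNet eval transform; timm's validate.py handles the exact
# crop_pct / interpolation for each model config
tf = transforms.Compose([
    transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('data/val', tf), batch_size=64, num_workers=8)

correct = total = 0
with torch.no_grad():
    for images, targets in loader:
        out = model(images.cuda())
        logits = out[0] if isinstance(out, tuple) else out  # drop the attention map
        correct += (logits.argmax(1).cpu() == targets).sum().item()
        total += targets.numel()
print(f'top-1 accuracy: {100 * correct / total:.2f}%')
```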
49 |
50 | # Model Zoo
51 |
52 | Coming soon!
53 |
54 | ## Acknowledgement
55 | This repository is built using the [Timm](https://github.com/rwightman/pytorch-image-models) library and
56 | the [DeiT](https://github.com/facebookresearch/deit) repository.
57 |
58 | ## License
59 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
60 |
61 | ## Cite This Paper
62 | If you find our code helpful for your research, please use the following BibTeX entry to cite our paper:
63 |
64 | ```
65 | @InProceedings{transmix,
66 | title = {TransMix: Attend to Mix for Vision Transformers},
67 | author = {Chen, Jie-Neng and Sun, Shuyang and He, Ju and Torr, Philip and Yuille, Alan and Bai, Song},
68 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
69 | month = {June},
70 | year = {2022}
71 | }
72 | ```
73 |
--------------------------------------------------------------------------------
/configs/deit_b_transmix.yaml:
--------------------------------------------------------------------------------
1 | model: "deit_base_patch16_224_return_attn"
2 |
3 | img_size: 224
4 | decay_epochs: 30
5 | opt: adamw
6 | num_classes: 1000
7 | mixup: 0.8
8 | cutmix: 1.0
9 | drop_path: 0.1
10 | dist_bn: ""
11 | model_ema: True
12 | aa: rand-m9-mstd0.5-inc1
13 | pin_mem: False
14 | model_ema_decay: 0.99996
15 | no_prefetcher: True
16 | transmix: True # enable transmix
17 | mixup_switch_prob: 0.8
18 | min_lr: 1e-5
19 | lr: 1e-3
20 | warmup_lr: 1e-6
21 | weight_decay: 5e-2
22 | warmup_epochs: 5
23 | workers: 8
24 | total_batch_size: 256
25 |
--------------------------------------------------------------------------------
/configs/deit_s_transmix.yaml:
--------------------------------------------------------------------------------
1 | model: deit_small_patch16_224_return_attn
2 | warmup_lr: 1e-6
3 | img_size: 224
4 | decay_epochs: 30
5 | opt: adamw
6 | num_classes: 1000
7 | mixup: 0.8
8 | cutmix: 1.0
9 | drop_path: 0.1
10 | dist_bn: ""
11 | model_ema: True
12 | aa: rand-m9-mstd0.5-inc1
13 | pin_mem: False
14 | model_ema_decay: 0.99996
15 | no_prefetcher: True
16 | transmix: True # enable transmix
17 | mixup_switch_prob: 0.8
18 | min_lr: 1e-5
19 | lr: 1e-3
20 | weight_decay: 3e-2
21 | warmup_epochs: 20
22 | workers: 8
23 | total_batch_size: 256
24 |
--------------------------------------------------------------------------------
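
The yaml files above are presumably merged into ```train.py```'s argparse defaults following timm's standard ```--config``` pattern; the sketch below is illustrative only (argument names abbreviated, the repo's actual ```train.py``` may differ in details):
```python
# Illustrative sketch of timm-style yaml config handling (not the repo's exact train.py).
import argparse
import yaml

config_parser = argparse.ArgumentParser(description='Config file', add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str,
                           help='path to a yaml config, e.g. configs/deit_s_transmix.yaml')

parser = argparse.ArgumentParser(description='Training (illustrative subset of arguments)')
parser.add_argument('--model', default='deit_small_patch16_224_return_attn', type=str)
parser.add_argument('--transmix', action='store_true', default=False)
parser.add_argument('--mixup-switch-prob', type=float, default=0.5)  # dest: mixup_switch_prob
# ... remaining hyper-parameters (lr, weight_decay, warmup_epochs, ...)


def _parse_args():
    # Parse --config first, use the yaml values to override the argparse defaults,
    # then parse the remaining CLI flags so they still take precedence.
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            parser.set_defaults(**yaml.safe_load(f))
    return parser.parse_args(remaining)


if __name__ == '__main__':
    print(vars(_parse_args()))
```
With this pattern, a hyphenated key such as ```mixup-switch-prob``` would not map onto the ```mixup_switch_prob``` destination and would effectively be ignored, which is why the config keys use underscores.
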
/distributed_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | NUM_PROC=$1
3 | shift
4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@"
5 |
6 |
--------------------------------------------------------------------------------
/pic1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Beckschen/TransMix/0e4d31eb34772f9d12cc450678c4bf4ca89ba828/pic1.png
--------------------------------------------------------------------------------
/pic2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Beckschen/TransMix/0e4d31eb34772f9d12cc450678c4bf4ca89ba828/pic2.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | timm==0.5.4
4 | pyyaml
5 |
--------------------------------------------------------------------------------
/timm/models/vision_transformer.py:
--------------------------------------------------------------------------------
1 | """ Vision Transformer (ViT) in PyTorch
2 |
3 | A PyTorch implement of Vision Transformers as described in:
4 |
5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
6 | - https://arxiv.org/abs/2010.11929
7 |
8 | `How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
9 | - https://arxiv.org/abs/2106.10270
10 |
11 | The official jax code is released and available at https://github.com/google-research/vision_transformer
12 |
13 | DeiT model defs and weights from https://github.com/facebookresearch/deit,
14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
15 |
16 | Acknowledgments:
17 | * The paper authors for releasing code and weights, thanks!
18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
19 | for some einops/einsum fun
20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert
22 |
23 | Hacked together by / Copyright 2021 Ross Wightman
24 | """
25 | import math
26 | import logging
27 | from functools import partial
28 | from collections import OrderedDict
29 | from copy import deepcopy
30 |
31 | import torch
32 | import torch.nn as nn
33 | import torch.nn.functional as F
34 |
35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv
37 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
38 | from timm.models.registry import register_model
39 |
40 | _logger = logging.getLogger(__name__)
41 |
42 |
43 | def _cfg(url='', **kwargs):
44 | return {
45 | 'url': url,
46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head',
50 | **kwargs
51 | }
52 |
53 |
54 | default_cfgs = {
55 | # patch models (weights from official Google JAX impl)
56 | 'vit_tiny_patch16_224': _cfg(
57 | url='https://storage.googleapis.com/vit_models/augreg/'
58 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
59 | 'vit_tiny_patch16_384': _cfg(
60 | url='https://storage.googleapis.com/vit_models/augreg/'
61 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
62 | input_size=(3, 384, 384), crop_pct=1.0),
63 | 'vit_small_patch32_224': _cfg(
64 | url='https://storage.googleapis.com/vit_models/augreg/'
65 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
66 | 'vit_small_patch32_384': _cfg(
67 | url='https://storage.googleapis.com/vit_models/augreg/'
68 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
69 | input_size=(3, 384, 384), crop_pct=1.0),
70 | 'vit_small_patch16_224': _cfg(
71 | url='https://storage.googleapis.com/vit_models/augreg/'
72 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
73 | 'vit_small_patch16_384': _cfg(
74 | url='https://storage.googleapis.com/vit_models/augreg/'
75 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
76 | input_size=(3, 384, 384), crop_pct=1.0),
77 | 'vit_base_patch32_224': _cfg(
78 | url='https://storage.googleapis.com/vit_models/augreg/'
79 | 'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'),
80 | 'vit_base_patch32_384': _cfg(
81 | url='https://storage.googleapis.com/vit_models/augreg/'
82 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
83 | input_size=(3, 384, 384), crop_pct=1.0),
84 | 'vit_base_patch16_224': _cfg(
85 | url='https://storage.googleapis.com/vit_models/augreg/'
86 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
87 | 'vit_base_patch16_384': _cfg(
88 | url='https://storage.googleapis.com/vit_models/augreg/'
89 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
90 | input_size=(3, 384, 384), crop_pct=1.0),
91 | 'vit_base_patch8_224': _cfg(
92 | url='https://storage.googleapis.com/vit_models/augreg/'
93 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
94 | 'vit_large_patch32_224': _cfg(
95 | url='', # no official model weights for this combo, only for in21k
96 | ),
97 | 'vit_large_patch32_384': _cfg(
98 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
99 | input_size=(3, 384, 384), crop_pct=1.0),
100 | 'vit_large_patch16_224': _cfg(
101 | url='https://storage.googleapis.com/vit_models/augreg/'
102 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'),
103 | 'vit_large_patch16_384': _cfg(
104 | url='https://storage.googleapis.com/vit_models/augreg/'
105 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz',
106 | input_size=(3, 384, 384), crop_pct=1.0),
107 |
108 | # patch models, imagenet21k (weights from official Google JAX impl)
109 | 'vit_tiny_patch16_224_in21k': _cfg(
110 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
111 | num_classes=21843),
112 | 'vit_small_patch32_224_in21k': _cfg(
113 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
114 | num_classes=21843),
115 | 'vit_small_patch16_224_in21k': _cfg(
116 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz',
117 | num_classes=21843),
118 | 'vit_base_patch32_224_in21k': _cfg(
119 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz',
120 | num_classes=21843),
121 | 'vit_base_patch16_224_in21k': _cfg(
122 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
123 | num_classes=21843),
124 | 'vit_base_patch8_224_in21k': _cfg(
125 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz',
126 | num_classes=21843),
127 | 'vit_large_patch32_224_in21k': _cfg(
128 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
129 | num_classes=21843),
130 | 'vit_large_patch16_224_in21k': _cfg(
131 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz',
132 | num_classes=21843),
133 | 'vit_huge_patch14_224_in21k': _cfg(
134 | url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
135 | hf_hub='timm/vit_huge_patch14_224_in21k',
136 | num_classes=21843),
137 |
138 | # SAM trained models (https://arxiv.org/abs/2106.01548)
139 | 'vit_base_patch32_sam_224': _cfg(
140 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'),
141 | 'vit_base_patch16_sam_224': _cfg(
142 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'),
143 |
144 | # deit models (FB weights)
145 | 'deit_tiny_patch16_224': _cfg(
146 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth',
147 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
148 | 'deit_small_patch16_224': _cfg(
149 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',
150 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
151 | 'deit_base_patch16_224': _cfg(
152 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',
153 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
154 | 'deit_base_patch16_384': _cfg(
155 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
156 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0),
157 | 'deit_tiny_distilled_patch16_224': _cfg(
158 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
159 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
160 | 'deit_small_distilled_patch16_224': _cfg(
161 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
162 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
163 | 'deit_base_distilled_patch16_224': _cfg(
164 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
165 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')),
166 | 'deit_base_distilled_patch16_384': _cfg(
167 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
168 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0,
169 | classifier=('head', 'head_dist')),
170 |
171 | # ViT ImageNet-21K-P pretraining by MILL
172 | 'vit_base_patch16_224_miil_in21k': _cfg(
173 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
174 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
175 | ),
176 | 'vit_base_patch16_224_miil': _cfg(
177 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
178 | '/vit_base_patch16_224_1k_miil_84_4.pth',
179 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
180 | ),
181 | }
182 |
183 |
184 | class Attention(nn.Module):
185 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., return_attn=False):
186 | super().__init__()
187 | self.num_heads = num_heads
188 | head_dim = dim // num_heads
189 | self.scale = head_dim ** -0.5
190 |
191 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
192 | self.attn_drop = nn.Dropout(attn_drop)
193 | self.proj = nn.Linear(dim, dim)
194 | self.proj_drop = nn.Dropout(proj_drop)
195 |
196 | self.return_attn = return_attn
197 |
198 | def forward(self, x):
199 | B, N, C = x.shape
200 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
201 | q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
202 |
203 | attn = (q @ k.transpose(-2, -1)) * self.scale
204 | attn = attn.softmax(dim=-1)
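        # keep a detached copy of the post-softmax attention map; it is returned
        # alongside the output when return_attn=True (consumed by TransMix)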
205 | attn_softmax = attn.detach().clone()
206 | attn = self.attn_drop(attn)
207 |
208 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
209 | x = self.proj(x)
210 | x = self.proj_drop(x)
211 | if self.return_attn:
212 | return x, attn_softmax
213 | return x
214 |
215 |
216 | class Block(nn.Module):
217 |
218 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
219 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, return_attn=False):
220 | super().__init__()
221 | self.norm1 = norm_layer(dim)
222 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, return_attn=return_attn)
223 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
224 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
225 | self.norm2 = norm_layer(dim)
226 | mlp_hidden_dim = int(dim * mlp_ratio)
227 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
228 | self.return_attn = return_attn # modified by jchen
229 | def forward(self, x):
230 | if self.return_attn:
231 | res = x
232 | x, attn = self.attn(self.norm1(x))
233 | x = res + self.drop_path(x)
234 | x = x + self.drop_path(self.mlp(self.norm2(x)))
235 | return x, attn
236 |
237 | x = x + self.drop_path(self.attn(self.norm1(x)))
238 | x = x + self.drop_path(self.mlp(self.norm2(x)))
239 | return x
240 |
241 |
242 | class VisionTransformer(nn.Module):
243 | """ Vision Transformer
244 |
245 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
246 | - https://arxiv.org/abs/2010.11929
247 |
248 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
249 | - https://arxiv.org/abs/2012.12877
250 | """
251 |
252 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
253 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
254 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
255 | act_layer=None, weight_init='', return_attn=False):
256 | """
257 | Args:
258 | img_size (int, tuple): input image size
259 | patch_size (int, tuple): patch size
260 | in_chans (int): number of input channels
261 | num_classes (int): number of classes for classification head
262 | embed_dim (int): embedding dimension
263 | depth (int): depth of transformer
264 | num_heads (int): number of attention heads
265 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim
266 | qkv_bias (bool): enable bias for qkv if True
267 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
268 | distilled (bool): model includes a distillation token and head as in DeiT models
269 | drop_rate (float): dropout rate
270 | attn_drop_rate (float): attention dropout rate
271 | drop_path_rate (float): stochastic depth rate
272 | embed_layer (nn.Module): patch embedding layer
273 | norm_layer: (nn.Module): normalization layer
274 | weight_init: (str): weight init scheme
275 | """
276 | super().__init__()
277 | self.return_attn = return_attn
278 | self.num_classes = num_classes
279 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
280 | self.num_tokens = 2 if distilled else 1
281 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
282 | act_layer = act_layer or nn.GELU
283 |
284 | self.patch_embed = embed_layer(
285 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
286 | num_patches = self.patch_embed.num_patches
287 |
288 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
289 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
290 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
291 | self.pos_drop = nn.Dropout(p=drop_rate)
292 |
293 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
294 | self.blocks = nn.Sequential(*[
295 | Block(
296 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
297 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
298 |                 return_attn=return_attn and i==depth-1)  # only the final block returns its attention map (used by TransMix)
299 | for i in range(depth)])
300 | self.norm = norm_layer(embed_dim)
301 |
302 | # Representation layer
303 | if representation_size and not distilled:
304 | self.num_features = representation_size
305 | self.pre_logits = nn.Sequential(OrderedDict([
306 | ('fc', nn.Linear(embed_dim, representation_size)),
307 | ('act', nn.Tanh())
308 | ]))
309 | else:
310 | self.pre_logits = nn.Identity()
311 |
312 | # Classifier head(s)
313 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
314 | self.head_dist = None
315 | if distilled:
316 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
317 |
318 | self.init_weights(weight_init)
319 |
320 | def init_weights(self, mode=''):
321 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
322 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
323 | trunc_normal_(self.pos_embed, std=.02)
324 | if self.dist_token is not None:
325 | trunc_normal_(self.dist_token, std=.02)
326 | if mode.startswith('jax'):
327 | # leave cls token as zeros to match jax impl
328 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
329 | else:
330 | trunc_normal_(self.cls_token, std=.02)
331 | self.apply(_init_vit_weights)
332 |
333 | def _init_weights(self, m):
334 | # this fn left here for compat with downstream users
335 | _init_vit_weights(m)
336 |
337 | @torch.jit.ignore()
338 | def load_pretrained(self, checkpoint_path, prefix=''):
339 | _load_weights(self, checkpoint_path, prefix)
340 |
341 | @torch.jit.ignore
342 | def no_weight_decay(self):
343 | return {'pos_embed', 'cls_token', 'dist_token'}
344 |
345 | def get_classifier(self):
346 | if self.dist_token is None:
347 | return self.head
348 | else:
349 | return self.head, self.head_dist
350 |
351 | def reset_classifier(self, num_classes, global_pool=''):
352 | self.num_classes = num_classes
353 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
354 | if self.num_tokens == 2:
355 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
356 |
357 | def forward_features(self, x):
358 | x = self.patch_embed(x)
359 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
360 | if self.dist_token is None:
361 | x = torch.cat((cls_token, x), dim=1)
362 | else:
363 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
364 | x = self.pos_drop(x + self.pos_embed)
365 | x = self.blocks(x)
366 | if self.return_attn:
367 | x, attn = x[0], x[1]
368 |             attn = torch.mean(attn[:, :, 0, 1:], dim=1)  # head-averaged attention from the cls token to the patch tokens, shape (B, num_patches)
369 | x = self.norm(x)
370 | if self.dist_token is None:
371 | return self.pre_logits(x[:, 0]), attn
372 | else:
373 | return x[:, 0], x[:, 1], attn
374 | else:
375 | x = self.norm(x)
376 | if self.dist_token is None:
377 | return self.pre_logits(x[:, 0])
378 | else:
379 | return x[:, 0], x[:, 1]
380 |
381 | def forward(self, x):
382 | x = self.forward_features(x)
383 | if self.return_attn:
384 | if self.head_dist is not None:
385 | x, x_dist, attn = self.head(x[0]), self.head_dist(x[1]), x[2] # x must be a tuple
386 |                 if self.training and not torch.jit.is_scripting():
387 |                     return x, x_dist, attn
388 |                 else:
389 |                     # during inference, return the average of both classifier predictions
390 |                     return (x + x_dist) / 2, attn
391 | else:
392 | x, attn = x[0], x[1]
393 | x = self.head(x)
394 | return x, attn
395 | else:
396 | if self.head_dist is not None:
397 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
398 |                 if self.training and not torch.jit.is_scripting():
399 |                     return x, x_dist
400 |                 else:
401 |                     # during inference, return the average of both classifier predictions
402 |                     return (x + x_dist) / 2
403 | else:
404 | x = self.head(x)
405 | return x
406 |
407 |
408 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
409 | """ ViT weight initialization
410 | * When called without n, head_bias, jax_impl args it will behave exactly the same
411 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
412 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
413 | """
414 | if isinstance(module, nn.Linear):
415 | if name.startswith('head'):
416 | nn.init.zeros_(module.weight)
417 | nn.init.constant_(module.bias, head_bias)
418 | elif name.startswith('pre_logits'):
419 | lecun_normal_(module.weight)
420 | nn.init.zeros_(module.bias)
421 | else:
422 | if jax_impl:
423 | nn.init.xavier_uniform_(module.weight)
424 | if module.bias is not None:
425 | if 'mlp' in name:
426 | nn.init.normal_(module.bias, std=1e-6)
427 | else:
428 | nn.init.zeros_(module.bias)
429 | else:
430 | trunc_normal_(module.weight, std=.02)
431 | if module.bias is not None:
432 | nn.init.zeros_(module.bias)
433 | elif jax_impl and isinstance(module, nn.Conv2d):
434 | # NOTE conv was left to pytorch default in my original init
435 | lecun_normal_(module.weight)
436 | if module.bias is not None:
437 | nn.init.zeros_(module.bias)
438 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
439 | nn.init.zeros_(module.bias)
440 | nn.init.ones_(module.weight)
441 |
442 |
443 | @torch.no_grad()
444 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
445 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation
446 | """
447 | import numpy as np
448 |
449 | def _n2p(w, t=True):
450 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
451 | w = w.flatten()
452 | if t:
453 | if w.ndim == 4:
454 | w = w.transpose([3, 2, 0, 1])
455 | elif w.ndim == 3:
456 | w = w.transpose([2, 0, 1])
457 | elif w.ndim == 2:
458 | w = w.transpose([1, 0])
459 | return torch.from_numpy(w)
460 |
461 | w = np.load(checkpoint_path)
462 | if not prefix and 'opt/target/embedding/kernel' in w:
463 | prefix = 'opt/target/'
464 |
465 | if hasattr(model.patch_embed, 'backbone'):
466 | # hybrid
467 | backbone = model.patch_embed.backbone
468 | stem_only = not hasattr(backbone, 'stem')
469 | stem = backbone if stem_only else backbone.stem
470 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
471 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
472 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
473 | if not stem_only:
474 | for i, stage in enumerate(backbone.stages):
475 | for j, block in enumerate(stage.blocks):
476 | bp = f'{prefix}block{i + 1}/unit{j + 1}/'
477 | for r in range(3):
478 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
479 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
480 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
481 | if block.downsample is not None:
482 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
483 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
484 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
485 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
486 | else:
487 | embed_conv_w = adapt_input_conv(
488 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
489 | model.patch_embed.proj.weight.copy_(embed_conv_w)
490 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
491 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
492 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
493 | if pos_embed_w.shape != model.pos_embed.shape:
494 | pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
495 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
496 | model.pos_embed.copy_(pos_embed_w)
497 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
498 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
499 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
500 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
501 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
502 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
503 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
504 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
505 | for i, block in enumerate(model.blocks.children()):
506 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
507 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
508 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
509 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
510 | block.attn.qkv.weight.copy_(torch.cat([
511 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
512 | block.attn.qkv.bias.copy_(torch.cat([
513 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
514 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
515 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
516 | for r in range(2):
517 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
518 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
519 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
520 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
521 |
522 |
523 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
524 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
525 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
526 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
527 | ntok_new = posemb_new.shape[1]
528 | if num_tokens:
529 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
530 | ntok_new -= num_tokens
531 | else:
532 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
533 | gs_old = int(math.sqrt(len(posemb_grid)))
534 | if not len(gs_new): # backwards compatibility
535 | gs_new = [int(math.sqrt(ntok_new))] * 2
536 | assert len(gs_new) >= 2
537 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
538 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
539 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
540 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
541 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
542 | return posemb
543 |
544 |
545 | def checkpoint_filter_fn(state_dict, model):
546 | """ convert patch embedding weight from manual patchify + linear proj to conv"""
547 | out_dict = {}
548 | if 'model' in state_dict:
549 | # For deit models
550 | state_dict = state_dict['model']
551 | for k, v in state_dict.items():
552 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
553 | # For old models that I trained prior to conv based patchification
554 | O, I, H, W = model.patch_embed.proj.weight.shape
555 | v = v.reshape(O, -1, H, W)
556 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
557 | # To resize pos embedding when using model at different size from pretrained weights
558 | v = resize_pos_embed(
559 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
560 | out_dict[k] = v
561 | return out_dict
562 |
563 |
564 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
565 | default_cfg = default_cfg or default_cfgs[variant]
566 | if kwargs.get('features_only', None):
567 | raise RuntimeError('features_only not implemented for Vision Transformer models.')
568 |
569 | # NOTE this extra code to support handling of repr size for in21k pretrained models
570 | default_num_classes = default_cfg['num_classes']
571 | num_classes = kwargs.get('num_classes', default_num_classes)
572 | repr_size = kwargs.pop('representation_size', None)
573 | if repr_size is not None and num_classes != default_num_classes:
574 | # Remove representation layer if fine-tuning. This may not always be the desired action,
575 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
576 | _logger.warning("Removing representation layer for fine-tuning.")
577 | repr_size = None
578 |
579 | model = build_model_with_cfg(
580 | VisionTransformer, variant, pretrained,
581 | default_cfg=default_cfg,
582 | representation_size=repr_size,
583 | pretrained_filter_fn=checkpoint_filter_fn,
584 | pretrained_custom_load='npz' in default_cfg['url'],
585 | **kwargs)
586 | return model
587 |
588 |
589 |
590 | @register_model
591 | def deit_small_patch16_224_return_attn(pretrained=False, **kwargs):
592 |     """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877),
593 |     with an extra output: the class-token attention map (used by TransMix).
594 |     ImageNet-1k weights from https://github.com/facebookresearch/deit.
595 |     """
596 |
597 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, return_attn=True, **kwargs)
598 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
599 | return model
600 |
601 | @register_model
602 | def deit_base_patch16_224_return_attn(pretrained=False, **kwargs):
603 |     """ DeiT-base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877),
604 |     with an extra output: the class-token attention map (used by TransMix).
605 |     ImageNet-1k weights from https://github.com/facebookresearch/deit.
606 |     """
607 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, return_attn=True, **kwargs)
608 | model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
609 | return model
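
# Example (illustrative): the *_return_attn variants above return a tuple from
# forward() instead of bare logits:
#
#     model = deit_small_patch16_224_return_attn()
#     logits, attn = model(torch.randn(2, 3, 224, 224))
#     # logits.shape -> (2, 1000)
#     # attn.shape   -> (2, 196), head-averaged cls-token attention over the 14x14 patches
#
# TransMix uses this attention map to re-weight the label mixing ratio.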
610 |
611 | @register_model
612 | def vit_tiny_patch16_224(pretrained=False, **kwargs):
613 | """ ViT-Tiny (Vit-Ti/16)
614 | """
615 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
616 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
617 | return model
618 |
619 |
620 | @register_model
621 | def vit_tiny_patch16_384(pretrained=False, **kwargs):
622 | """ ViT-Tiny (Vit-Ti/16) @ 384x384.
623 | """
624 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
625 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
626 | return model
627 |
628 |
629 | @register_model
630 | def vit_small_patch32_224(pretrained=False, **kwargs):
631 | """ ViT-Small (ViT-S/32)
632 | """
633 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
634 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
635 | return model
636 |
637 |
638 | @register_model
639 | def vit_small_patch32_384(pretrained=False, **kwargs):
640 | """ ViT-Small (ViT-S/32) at 384x384.
641 | """
642 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
643 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
644 | return model
645 |
646 |
647 | @register_model
648 | def vit_small_patch16_224(pretrained=False, **kwargs):
649 | """ ViT-Small (ViT-S/16)
650 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
651 | """
652 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
653 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
654 | return model
655 |
656 |
657 | @register_model
658 | def vit_small_patch16_384(pretrained=False, **kwargs):
659 | """ ViT-Small (ViT-S/16)
660 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
661 | """
662 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
663 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
664 | return model
665 |
666 |
667 | @register_model
668 | def vit_base_patch32_224(pretrained=False, **kwargs):
669 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
670 | ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
671 | """
672 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
673 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
674 | return model
675 |
676 |
677 | @register_model
678 | def vit_base_patch32_384(pretrained=False, **kwargs):
679 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
680 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
681 | """
682 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
683 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
684 | return model
685 |
686 |
687 | @register_model
688 | def vit_base_patch16_224(pretrained=False, **kwargs):
689 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
690 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
691 | """
692 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
693 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
694 | return model
695 |
696 |
697 | @register_model
698 | def vit_base_patch16_384(pretrained=False, **kwargs):
699 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
700 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
701 | """
702 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
703 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
704 | return model
705 |
706 |
707 | @register_model
708 | def vit_base_patch8_224(pretrained=False, **kwargs):
709 | """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
710 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
711 | """
712 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
713 | model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
714 | return model
715 |
716 |
717 | @register_model
718 | def vit_large_patch32_224(pretrained=False, **kwargs):
719 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
720 | """
721 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
722 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
723 | return model
724 |
725 |
726 | @register_model
727 | def vit_large_patch32_384(pretrained=False, **kwargs):
728 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
729 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
730 | """
731 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
732 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
733 | return model
734 |
735 |
736 | @register_model
737 | def vit_large_patch16_224(pretrained=False, **kwargs):
738 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
739 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
740 | """
741 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
742 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
743 | return model
744 |
745 |
746 | @register_model
747 | def vit_large_patch16_384(pretrained=False, **kwargs):
748 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
749 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
750 | """
751 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
752 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
753 | return model
754 |
755 |
756 | @register_model
757 | def vit_base_patch16_sam_224(pretrained=False, **kwargs):
758 | """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
759 | """
760 | # NOTE original SAM weights release worked with representation_size=768
761 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
762 | model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs)
763 | return model
764 |
765 |
766 | @register_model
767 | def vit_base_patch32_sam_224(pretrained=False, **kwargs):
768 | """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
769 | """
770 | # NOTE original SAM weights release worked with representation_size=768
771 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs)
772 | model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs)
773 | return model
774 |
775 |
776 | @register_model
777 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
778 | """ ViT-Tiny (Vit-Ti/16).
779 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
780 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
781 | """
782 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
783 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
784 | return model
785 |
786 |
787 | @register_model
788 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs):
789 | """ ViT-Small (ViT-S/16)
790 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
791 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
792 | """
793 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
794 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
795 | return model
796 |
797 |
798 | @register_model
799 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs):
800 | """ ViT-Small (ViT-S/16)
801 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
802 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
803 | """
804 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
805 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
806 | return model
807 |
808 |
809 | @register_model
810 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
811 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
812 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
813 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
814 | """
815 | model_kwargs = dict(
816 | patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
817 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
818 | return model
819 |
820 |
821 | @register_model
822 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
823 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
824 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
825 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
826 | """
827 | model_kwargs = dict(
828 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
829 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
830 | return model
831 |
832 |
833 | @register_model
834 | def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
835 | """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
836 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
837 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
838 | """
839 | model_kwargs = dict(
840 | patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
841 | model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
842 | return model
843 |
844 |
845 | @register_model
846 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
847 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
848 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
849 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
850 | """
851 | model_kwargs = dict(
852 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
853 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
854 | return model
855 |
856 |
857 | @register_model
858 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
859 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
860 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
861 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
862 | """
863 | model_kwargs = dict(
864 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
865 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
866 | return model
867 |
868 |
869 | @register_model
870 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
871 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
872 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
873 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
874 | """
875 | model_kwargs = dict(
876 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
877 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
878 | return model
879 |
880 |
881 | @register_model
882 | def deit_tiny_patch16_224(pretrained=False, **kwargs):
883 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
884 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
885 | """
886 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
887 | model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
888 | return model
889 |
890 |
891 | @register_model
892 | def deit_small_patch16_224(pretrained=False, **kwargs):
893 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
894 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
895 | """
896 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
897 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
898 | return model
899 |
900 |
901 | @register_model
902 | def deit_base_patch16_224(pretrained=False, **kwargs):
903 | """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
904 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
905 | """
906 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
907 | model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
908 | return model
909 |
910 |
911 | @register_model
912 | def deit_base_patch16_384(pretrained=False, **kwargs):
913 | """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
914 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
915 | """
916 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
917 | model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
918 | return model
919 |
920 |
921 | @register_model
922 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
923 | """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
924 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
925 | """
926 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
927 | model = _create_vision_transformer(
928 | 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
929 | return model
930 |
931 |
932 | @register_model
933 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
934 | """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
935 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
936 | """
937 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
938 | model = _create_vision_transformer(
939 | 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
940 | return model
941 |
942 |
943 | @register_model
944 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
945 | """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
946 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
947 | """
948 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
949 | model = _create_vision_transformer(
950 | 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
951 | return model
952 |
953 |
954 | @register_model
955 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
956 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
957 | ImageNet-1k weights from https://github.com/facebookresearch/deit.
958 | """
959 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
960 | model = _create_vision_transformer(
961 | 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
962 | return model
963 |
964 |
965 | @register_model
966 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs):
967 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
968 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
969 | """
970 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
971 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs)
972 | return model
973 |
974 |
975 | @register_model
976 | def vit_base_patch16_224_miil(pretrained=False, **kwargs):
977 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
978 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
979 | """
980 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
981 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
982 | return model
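# The entry points above are registered by name in timm's model registry; a minimal, hypothetical
# usage sketch (any of the names registered in this file would work):
#
#   import timm
#   model = timm.create_model('deit_small_patch16_224', pretrained=False, num_classes=1000)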
983 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """ ImageNet Training Script for TransMix.
3 | This script was modified from an early version of the PyTorch Image Models (timm)
4 | (https://github.com/rwightman/pytorch-image-models)
5 | Hacked together by Jieneng Chen and Shuyang Sun / Copyright 2022 ByteDance
6 | """
7 | import argparse
8 | import time
9 | import yaml
10 | import os
11 | import logging
12 | from collections import OrderedDict
13 | from contextlib import suppress
14 | from datetime import datetime
15 |
16 | import torch
17 | import torch.nn as nn
18 | import torchvision.utils
19 | from torch.nn.parallel import DistributedDataParallel as NativeDDP
20 |
21 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
22 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
23 | convert_splitbn_model, model_parameters
24 | from timm.utils import *
25 | from timm.loss import *
26 | from timm.optim import create_optimizer_v2, optimizer_kwargs
27 | from timm.scheduler import create_scheduler
28 | from timm.utils import ApexScaler, NativeScaler
29 |
30 | try:
31 | from apex import amp
32 | from apex.parallel import DistributedDataParallel as ApexDDP
33 | from apex.parallel import convert_syncbn_model
34 | has_apex = True
35 | except ImportError:
36 | has_apex = False
37 |
38 | has_native_amp = False
39 | try:
40 | if getattr(torch.cuda.amp, 'autocast') is not None:
41 | has_native_amp = True
42 | except AttributeError:
43 | pass
44 |
45 | try:
46 | import wandb
47 | has_wandb = True
48 | except ImportError:
49 | has_wandb = False
50 |
51 | torch.backends.cudnn.benchmark = True
52 | _logger = logging.getLogger('train')
53 |
54 | # The first arg parser parses out only the --config argument, this argument is used to
55 | # load a yaml file containing key-values that override the defaults for the main parser below
56 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
57 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
58 | help='YAML config file specifying default arguments')
59 |
60 |
61 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
62 |
63 | # Dataset parameters
64 | parser.add_argument('data_dir', default='./data/', type=str, metavar='DIR',  # NOTE: upstream timm defines no default for this positional arg
65 | help='path to dataset')
66 | parser.add_argument('--dataset', '-d', metavar='NAME', default='',
67 | help='dataset type (default: ImageFolder/ImageTar if empty)')
68 | parser.add_argument('--train-split', metavar='NAME', default='train',
69 | help='dataset train split (default: train)')
70 | parser.add_argument('--val-split', metavar='NAME', default='validation',
71 | help='dataset validation split (default: validation)')
72 | parser.add_argument('--dataset-download', action='store_true', default=False,
73 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
74 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
75 | help='path to class to idx mapping file (default: "")')
76 |
77 | # Model parameters
78 | parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
79 |                     help='Name of model to train (default: "resnet50")')
80 | parser.add_argument('--pretrained', action='store_true', default=False,
81 | help='Start with pretrained version of specified network (if avail)')
82 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
83 | help='Initialize model from this checkpoint (default: none)')
84 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
85 | help='Resume full model and optimizer state from checkpoint (default: none)')
86 | parser.add_argument('--no-resume-opt', action='store_true', default=False,
87 | help='prevent resume of optimizer state when resuming model')
88 | parser.add_argument('--num-classes', type=int, default=None, metavar='N',
89 | help='number of label classes (Model default if None)')
90 | parser.add_argument('--gp', default=None, type=str, metavar='POOL',
91 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
92 | parser.add_argument('--img-size', type=int, default=None, metavar='N',
93 | help='Image patch size (default: None => model default)')
94 | parser.add_argument('--input-size', default=None, nargs=3, type=int,
95 | metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
96 | parser.add_argument('--crop-pct', default=None, type=float,
97 | metavar='N', help='Input image center crop percent (for validation only)')
98 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
99 | help='Override mean pixel value of dataset')
100 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
101 |                     help='Override std deviation of dataset')
102 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
103 | help='Image resize interpolation type (overrides model)')
104 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
105 | help='input batch size for training (default: 128)')
106 | parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
107 | help='validation batch size override (default: None)')
108 |
109 | # Optimizer parameters
110 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
111 |                     help='Optimizer (default: "sgd")')
112 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
113 | help='Optimizer Epsilon (default: None, use opt default)')
114 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
115 | help='Optimizer Betas (default: None, use opt default)')
116 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
117 | help='Optimizer momentum (default: 0.9)')
118 | parser.add_argument('--weight-decay', type=float, default=2e-5,
119 | help='weight decay (default: 2e-5)')
120 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
121 | help='Clip gradient norm (default: None, no clipping)')
122 | parser.add_argument('--clip-mode', type=str, default='norm',
123 | help='Gradient clipping mode. One of ("norm", "value", "agc")')
124 |
125 |
126 | # Learning rate schedule parameters
127 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
128 |                     help='LR scheduler (default: "cosine")')
129 | parser.add_argument('--lr', type=float, default=0.05, metavar='LR',
130 | help='learning rate (default: 0.05)')
131 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
132 | help='learning rate noise on/off epoch percentages')
133 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
134 | help='learning rate noise limit percent (default: 0.67)')
135 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
136 | help='learning rate noise std-dev (default: 1.0)')
137 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
138 | help='learning rate cycle len multiplier (default: 1.0)')
139 | parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
140 | help='amount to decay each learning rate cycle (default: 0.5)')
141 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
142 | help='learning rate cycle limit, cycles enabled if > 1')
143 | parser.add_argument('--lr-k-decay', type=float, default=1.0,
144 | help='learning rate k-decay for cosine/poly (default: 1.0)')
145 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
146 | help='warmup learning rate (default: 0.0001)')
147 | parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', # 1e-5 for vit
148 |                     help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-6)')
149 | parser.add_argument('--epochs', type=int, default=300, metavar='N',
150 | help='number of epochs to train (default: 300)')
151 | parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
152 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
153 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
154 | help='manual epoch number (useful on restarts)')
155 | parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
156 | help='epoch interval to decay LR')
157 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
158 | help='epochs to warmup LR, if scheduler supports')
159 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
160 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
161 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
162 |                     help='patience epochs for Plateau LR scheduler (default: 10)')
163 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
164 | help='LR decay rate (default: 0.1)')
165 |
166 | # Augmentation & regularization parameters
167 | parser.add_argument('--no-aug', action='store_true', default=False,
168 | help='Disable all training augmentation, override other train aug args')
169 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
170 | help='Random resize scale (default: 0.08 1.0)')
171 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
172 | help='Random resize aspect ratio (default: 0.75 1.33)')
173 | parser.add_argument('--hflip', type=float, default=0.5,
174 | help='Horizontal flip training aug probability')
175 | parser.add_argument('--vflip', type=float, default=0.,
176 | help='Vertical flip training aug probability')
177 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
178 | help='Color jitter factor (default: 0.4)')
179 | parser.add_argument('--aa', type=str, default=None, metavar='NAME',
180 |                     help='Use AutoAugment policy. "v0" or "original". (default: None)')
181 | parser.add_argument('--aug-repeats', type=int, default=0,
182 | help='Number of augmentation repetitions (distributed training only) (default: 0)')
183 | parser.add_argument('--aug-splits', type=int, default=0,
184 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
185 | parser.add_argument('--jsd-loss', action='store_true', default=False,
186 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
187 | parser.add_argument('--bce-loss', action='store_true', default=False,
188 | help='Enable BCE loss w/ Mixup/CutMix use.')
189 | parser.add_argument('--bce-target-thresh', type=float, default=None,
190 | help='Threshold for binarizing softened BCE targets (default: None, disabled)')
191 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
192 | help='Random erase prob (default: 0.)')
193 | parser.add_argument('--remode', type=str, default='pixel',
194 | help='Random erase mode (default: "pixel")')
195 | parser.add_argument('--recount', type=int, default=1,
196 | help='Random erase count (default: 1)')
197 | parser.add_argument('--resplit', action='store_true', default=False,
198 | help='Do not random erase first (clean) augmentation split')
199 | parser.add_argument('--mixup', type=float, default=0.0,
200 | help='mixup alpha, mixup enabled if > 0. (default: 0.)')
201 | parser.add_argument('--cutmix', type=float, default=0.0,
202 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
203 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
204 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
205 | parser.add_argument('--mixup-prob', type=float, default=1.0,
206 | help='Probability of performing mixup or cutmix when either/both is enabled')
207 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
208 | help='Probability of switching to cutmix when both mixup and cutmix enabled')
209 | parser.add_argument('--mixup-mode', type=str, default='batch',
210 | help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
211 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
212 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
213 | parser.add_argument('--smoothing', type=float, default=0.1,
214 | help='Label smoothing (default: 0.1)')
215 | parser.add_argument('--train-interpolation', type=str, default='random',
216 | help='Training interpolation (random, bilinear, bicubic default: "random")')
217 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
218 | help='Dropout rate (default: 0.)')
219 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
220 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
221 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT',
222 | help='Drop path rate (default: None)')
223 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
224 | help='Drop block rate (default: None)')
225 |
226 | # Batch norm parameters (only works with gen_efficientnet based models currently)
227 | parser.add_argument('--bn-momentum', type=float, default=None,
228 | help='BatchNorm momentum override (if not None)')
229 | parser.add_argument('--bn-eps', type=float, default=None,
230 | help='BatchNorm epsilon override (if not None)')
231 | parser.add_argument('--sync-bn', action='store_true',
232 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
233 | parser.add_argument('--dist-bn', type=str, default='reduce',
234 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
235 | parser.add_argument('--split-bn', action='store_true',
236 | help='Enable separate BN layers per augmentation split.')
237 |
238 | # Model Exponential Moving Average
239 | parser.add_argument('--model-ema', action='store_true', default=False,
240 | help='Enable tracking moving average of model weights')
241 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
242 | help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
243 | parser.add_argument('--model-ema-decay', type=float, default=0.9998,
244 | help='decay factor for model weights moving average (default: 0.9998)')
245 |
246 | # Misc
247 | parser.add_argument('--seed', type=int, default=42, metavar='S',
248 | help='random seed (default: 42)')
249 | parser.add_argument('--worker-seeding', type=str, default='all',
250 | help='worker seed mode (default: all)')
251 | parser.add_argument('--log-interval', type=int, default=50, metavar='N',
252 | help='how many batches to wait before logging training status')
253 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
254 | help='how many batches to wait before writing recovery checkpoint')
255 | parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
256 | help='number of checkpoints to keep (default: 10)')
257 | parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
258 | help='how many training processes to use (default: 4)')
259 | parser.add_argument('--save-images', action='store_true', default=False,
260 |                     help='save images of input batches every log interval for debugging')
261 | parser.add_argument('--amp', action='store_true', default=False,
262 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
263 | parser.add_argument('--apex-amp', action='store_true', default=False,
264 | help='Use NVIDIA Apex AMP mixed precision')
265 | parser.add_argument('--native-amp', action='store_true', default=False,
266 | help='Use Native Torch AMP mixed precision')
267 | parser.add_argument('--no-ddp-bb', action='store_true', default=False,
268 | help='Force broadcast buffers for native DDP to off.')
269 | parser.add_argument('--channels-last', action='store_true', default=False,
270 | help='Use channels_last memory layout')
271 | parser.add_argument('--pin-mem', action='store_true', default=False,
272 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
273 | parser.add_argument('--no-prefetcher', action='store_true', default=False,
274 | help='disable fast prefetcher')
275 | parser.add_argument('--output', default='', type=str, metavar='PATH',
276 | help='path to output folder (default: none, current dir)')
277 | parser.add_argument('--experiment', default='', type=str, metavar='NAME',
278 | help='name of train experiment, name of sub-folder for output')
279 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
280 |                     help='Best metric (default: "top1")')
281 | parser.add_argument('--tta', type=int, default=0, metavar='N',
282 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
283 | parser.add_argument("--local_rank", default=0, type=int)
284 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
285 | help='use the multi-epochs-loader to save time at the beginning of every epoch')
286 | parser.add_argument('--torchscript', dest='torchscript', action='store_true',
287 | help='convert model torchscript for inference')
288 | parser.add_argument('--fuser', default='', type=str,
289 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
290 | parser.add_argument('--log-wandb', action='store_true', default=False,
291 | help='log training and validation metrics to wandb')
292 | # new flags we added to timm.
293 | parser.add_argument('--transmix', action='store_true', default=False, help='enable TransMix attention-based re-weighting of CutMix labels')
294 | parser.add_argument('--total-batch-size', type=int, default=None,
295 |                     help='total batch size across all processes (default: None); per-process batch-size = total-batch-size / world_size')
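# Worked example (hypothetical values): `--total-batch-size 1024` with 8 processes (world_size = 8)
# yields a per-process batch size of 1024 // 8 = 128 (see the division before create_loader in main()).
# `--transmix` assumes the model's forward pass also returns the class-token attention map; see how
# train_one_epoch() unpacks `(output, attn)` and calls mixup_fn.transmix_label() below.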
296 | def _parse_args():
297 | # Do we have a config file to parse?
298 | args_config, remaining = config_parser.parse_known_args()
299 | if args_config.config:
300 | with open(args_config.config, 'r') as f:
301 | cfg = yaml.safe_load(f)
302 | parser.set_defaults(**cfg)
303 |
304 | # The main arg parser parses the rest of the args, the usual
305 | # defaults will have been overridden if config file specified.
306 | args = parser.parse_args(remaining)
307 |
308 | # Cache the args as a text string to save them in the output dir later
309 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
310 | return args, args_text
311 |
312 |
313 | def main():
314 | setup_default_logging()
315 | args, args_text = _parse_args()
316 | print(args)
317 |
318 | if args.log_wandb:
319 | if has_wandb:
320 | wandb.init(project=args.experiment, config=args)
321 | else:
322 | _logger.warning("You've requested to log metrics to wandb but package not found. "
323 | "Metrics not being logged to wandb, try `pip install wandb`")
324 |
325 | args.prefetcher = not args.no_prefetcher
326 | args.distributed = False
327 | if 'WORLD_SIZE' in os.environ:
328 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
329 | args.device = 'cuda:0'
330 | args.world_size = 1
331 | args.rank = 0 # global rank
332 | if args.distributed:
333 | args.device = 'cuda:%d' % args.local_rank
334 | torch.cuda.set_device(args.local_rank)
335 | torch.distributed.init_process_group(backend='nccl', init_method='env://')
336 | args.world_size = torch.distributed.get_world_size()
337 | args.rank = torch.distributed.get_rank()
338 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
339 | % (args.rank, args.world_size))
340 | else:
341 |         _logger.info('Training with a single process on 1 GPU.')
342 | assert args.rank >= 0
343 |
344 | # resolve AMP arguments based on PyTorch / Apex availability
345 | use_amp = None
346 | if args.amp:
347 | # `--amp` chooses native amp before apex (APEX ver not actively maintained)
348 | if has_native_amp:
349 | args.native_amp = True
350 | elif has_apex:
351 | args.apex_amp = True
352 | if args.apex_amp and has_apex:
353 | use_amp = 'apex'
354 | elif args.native_amp and has_native_amp:
355 | use_amp = 'native'
356 | elif args.apex_amp or args.native_amp:
357 |         _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
358 |                         "Install NVIDIA apex or upgrade to PyTorch 1.6")
359 |
360 | random_seed(args.seed, args.rank)
361 |
362 | model = create_model(
363 | args.model,
364 | pretrained=args.pretrained,
365 | num_classes=args.num_classes,
366 | drop_rate=args.drop,
367 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
368 | drop_path_rate=args.drop_path,
369 | drop_block_rate=args.drop_block,
370 | global_pool=args.gp,
371 | bn_momentum=args.bn_momentum,
372 | bn_eps=args.bn_eps,
373 | scriptable=args.torchscript,
374 | checkpoint_path=args.initial_checkpoint)
375 | if args.num_classes is None:
376 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
377 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
378 |
379 | if args.local_rank == 0:
380 | _logger.info(
381 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
382 |
383 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
384 |
385 | # setup augmentation batch splits for contrastive loss or split bn
386 | num_aug_splits = 0
387 | if args.aug_splits > 0:
388 | assert args.aug_splits > 1, 'A split of 1 makes no sense'
389 | num_aug_splits = args.aug_splits
390 |
391 | # enable split bn (separate bn stats per batch-portion)
392 | if args.split_bn:
393 | assert num_aug_splits > 1 or args.resplit
394 | model = convert_splitbn_model(model, max(num_aug_splits, 2))
395 |
396 | # move model to GPU, enable channels last layout if set
397 | model.cuda()
398 | if args.channels_last:
399 | model = model.to(memory_format=torch.channels_last)
400 |
401 | # setup synchronized BatchNorm for distributed training
402 | if args.distributed and args.sync_bn:
403 | assert not args.split_bn
404 | if has_apex and use_amp == 'apex':
405 | # Apex SyncBN preferred unless native amp is activated
406 | model = convert_syncbn_model(model)
407 | else:
408 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
409 | if args.local_rank == 0:
410 | _logger.info(
411 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
412 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
413 |
414 | if args.torchscript:
415 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
416 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
417 | model = torch.jit.script(model)
418 |
419 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
420 |
421 | # setup automatic mixed-precision (AMP) loss scaling and op casting
422 | amp_autocast = suppress # do nothing
423 | loss_scaler = None
424 | if use_amp == 'apex':
425 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
426 | loss_scaler = ApexScaler()
427 | if args.local_rank == 0:
428 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
429 | elif use_amp == 'native':
430 | amp_autocast = torch.cuda.amp.autocast
431 | loss_scaler = NativeScaler()
432 | if args.local_rank == 0:
433 | _logger.info('Using native Torch AMP. Training in mixed precision.')
434 | else:
435 | if args.local_rank == 0:
436 | _logger.info('AMP not enabled. Training in float32.')
437 |
438 | # optionally resume from a checkpoint
439 | resume_epoch = None
440 | if args.resume:
441 | resume_epoch = resume_checkpoint(
442 | model, args.resume,
443 | optimizer=None if args.no_resume_opt else optimizer,
444 | loss_scaler=None if args.no_resume_opt else loss_scaler,
445 | log_info=args.local_rank == 0)
446 |
447 | # setup exponential moving average of model weights, SWA could be used here too
448 | model_ema = None
449 | if args.model_ema:
450 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
451 | model_ema = ModelEmaV2(
452 | model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
453 | if args.resume:
454 | load_checkpoint(model_ema.module, args.resume, use_ema=True)
455 |
456 | # setup distributed training
457 | if args.distributed:
458 | if has_apex and use_amp == 'apex':
459 | # Apex DDP preferred unless native amp is activated
460 | if args.local_rank == 0:
461 | _logger.info("Using NVIDIA APEX DistributedDataParallel.")
462 | model = ApexDDP(model, delay_allreduce=True)
463 | else:
464 | if args.local_rank == 0:
465 | _logger.info("Using native Torch DistributedDataParallel.")
466 | model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
467 | # NOTE: EMA model does not need to be wrapped by DDP
468 |
469 | # setup learning rate schedule and starting epoch
470 | lr_scheduler, num_epochs = create_scheduler(args, optimizer)
471 | start_epoch = 0
472 | if args.start_epoch is not None:
473 | # a specified start_epoch will always override the resume epoch
474 | start_epoch = args.start_epoch
475 | elif resume_epoch is not None:
476 | start_epoch = resume_epoch
477 | if lr_scheduler is not None and start_epoch > 0:
478 | lr_scheduler.step(start_epoch)
479 |
480 | if args.local_rank == 0:
481 | _logger.info('Scheduled epochs: {}'.format(num_epochs))
482 |
483 | # create the train and eval datasets
484 | dataset_train = create_dataset(
485 | args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
486 | class_map=args.class_map,
487 | download=args.dataset_download,
488 | batch_size=args.batch_size,
489 | repeats=args.epoch_repeats)
490 | dataset_eval = create_dataset(
491 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
492 | class_map=args.class_map,
493 | download=args.dataset_download,
494 | batch_size=args.batch_size)
495 |
496 | # setup mixup / cutmix
497 | collate_fn = None
498 | mixup_fn = None
499 | mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
500 | if mixup_active:
501 | mixup_args = dict(
502 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
503 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
504 | label_smoothing=args.smoothing, num_classes=args.num_classes)
505 | if args.transmix:
506 |             # use the TransMix variant of Mixup; mixing happens in the train loop, so run with --no-prefetcher so that mixup_fn is applied
507 | from transmix import Mixup_transmix
508 | mixup_fn = Mixup_transmix(**mixup_args)
509 | else:
510 | if args.prefetcher:
511 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
512 | collate_fn = FastCollateMixup(**mixup_args)
513 | else:
514 | mixup_fn = Mixup(**mixup_args)
515 |
516 |
517 | # wrap dataset in AugMix helper
518 | if num_aug_splits > 1:
519 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
520 |
521 |     # create data loaders w/ augmentation pipeline
522 | train_interpolation = args.train_interpolation
523 | if args.no_aug or not train_interpolation:
524 | train_interpolation = data_config['interpolation']
525 |
526 | if args.total_batch_size:
527 | args.batch_size = args.total_batch_size // args.world_size
528 |
529 | loader_train = create_loader(
530 | dataset_train,
531 | input_size=data_config['input_size'],
532 | batch_size=args.batch_size,
533 | is_training=True,
534 | use_prefetcher=args.prefetcher,
535 | no_aug=args.no_aug,
536 | re_prob=args.reprob,
537 | re_mode=args.remode,
538 | re_count=args.recount,
539 | re_split=args.resplit,
540 | scale=args.scale,
541 | ratio=args.ratio,
542 | hflip=args.hflip,
543 | vflip=args.vflip,
544 | color_jitter=args.color_jitter,
545 | auto_augment=args.aa,
546 | num_aug_repeats=args.aug_repeats,
547 | num_aug_splits=num_aug_splits,
548 | interpolation=train_interpolation,
549 | mean=data_config['mean'],
550 | std=data_config['std'],
551 | num_workers=args.workers,
552 | distributed=args.distributed,
553 | collate_fn=collate_fn,
554 | pin_memory=args.pin_mem,
555 | use_multi_epochs_loader=args.use_multi_epochs_loader,
556 | worker_seeding=args.worker_seeding,
557 | )
558 |
559 | loader_eval = create_loader(
560 | dataset_eval,
561 | input_size=data_config['input_size'],
562 | batch_size=args.validation_batch_size or args.batch_size,
563 | is_training=False,
564 | use_prefetcher=args.prefetcher,
565 | interpolation=data_config['interpolation'],
566 | mean=data_config['mean'],
567 | std=data_config['std'],
568 | num_workers=args.workers,
569 | distributed=args.distributed,
570 | crop_pct=data_config['crop_pct'],
571 | pin_memory=args.pin_mem,
572 | )
573 |
574 | # setup loss function
575 | if args.jsd_loss:
576 | assert num_aug_splits > 1 # JSD only valid with aug splits set
577 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
578 | elif mixup_active:
579 | # smoothing is handled with mixup target transform which outputs sparse, soft targets
580 | if args.bce_loss:
581 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
582 | else:
583 | train_loss_fn = SoftTargetCrossEntropy()
584 | elif args.smoothing:
585 | if args.bce_loss:
586 | train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
587 | else:
588 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
589 | else:
590 | train_loss_fn = nn.CrossEntropyLoss()
591 | train_loss_fn = train_loss_fn.cuda()
592 | validate_loss_fn = nn.CrossEntropyLoss().cuda()
593 |
594 | # setup checkpoint saver and eval metric tracking
595 | eval_metric = args.eval_metric
596 | best_metric = None
597 | best_epoch = None
598 | saver = None
599 | output_dir = None
600 | if args.rank == 0:
601 | if args.experiment:
602 | exp_name = args.experiment
603 | else:
604 | exp_name = '-'.join([
605 | datetime.now().strftime("%Y%m%d-%H%M%S"),
606 | safe_model_name(args.model),
607 | str(data_config['input_size'][-1])
608 | ])
609 | output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
610 | decreasing = True if eval_metric == 'loss' else False
611 | saver = CheckpointSaver(
612 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
613 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
614 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
615 | f.write(args_text)
616 |
617 | try:
618 | for epoch in range(start_epoch, num_epochs):
619 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
620 | loader_train.sampler.set_epoch(epoch)
621 |
622 | train_metrics = train_one_epoch(
623 | epoch, model, loader_train, optimizer, train_loss_fn, args,
624 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
625 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
626 |
627 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
628 | if args.local_rank == 0:
629 | _logger.info("Distributing BatchNorm running means and vars")
630 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
631 |
632 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
633 |
634 | if model_ema is not None and not args.model_ema_force_cpu:
635 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
636 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
637 | ema_eval_metrics = validate(
638 | model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
639 | eval_metrics = ema_eval_metrics
640 |
641 | if lr_scheduler is not None:
642 | # step LR for next epoch
643 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
644 |
645 | if output_dir is not None:
646 | update_summary(
647 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
648 | write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
649 |
650 | if saver is not None:
651 | # save proper checkpoint with eval metric
652 | save_metric = eval_metrics[eval_metric]
653 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
654 |
655 | except KeyboardInterrupt:
656 | pass
657 | if best_metric is not None:
658 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
659 |
660 |
661 | def train_one_epoch(
662 | epoch, model, loader, optimizer, loss_fn, args,
663 | lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
664 | loss_scaler=None, model_ema=None, mixup_fn=None):
665 |
666 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
667 | if args.prefetcher and loader.mixup_enabled:
668 | loader.mixup_enabled = False
669 | elif mixup_fn is not None:
670 | mixup_fn.mixup_enabled = False
671 |
672 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
673 | batch_time_m = AverageMeter()
674 | data_time_m = AverageMeter()
675 | losses_m = AverageMeter()
676 |
677 | model.train()
678 |
679 | end = time.time()
680 | last_idx = len(loader) - 1
681 | num_updates = epoch * len(loader)
682 | for batch_idx, (input, target) in enumerate(loader):
683 | last_batch = batch_idx == last_idx
684 | data_time_m.update(time.time() - end)
685 | if not args.prefetcher:
686 | input, target = input.cuda(), target.cuda()
687 | if mixup_fn is not None:
688 | input, target = mixup_fn(input, target) # target (B, K), or target is tuple under transmix
689 |
690 | if args.channels_last:
691 | input = input.contiguous(memory_format=torch.channels_last)
692 |
693 | with amp_autocast():
694 | output = model(input)
695 | if args.transmix:
696 |                 (output, attn) = output  # attention of the class token over the image patches, shape (B, hw)
697 |                 if isinstance(target, tuple):  # target is (mixed_target, y1, y2, box) when cutmix was applied
698 | target = mixup_fn.transmix_label(target, attn, input.shape)
699 | loss = loss_fn(output, target)
700 |
701 |
702 | if not args.distributed:
703 | losses_m.update(loss.item(), input.size(0))
704 |
705 | optimizer.zero_grad()
706 | if loss_scaler is not None:
707 | loss_scaler(
708 | loss, optimizer,
709 | clip_grad=args.clip_grad, clip_mode=args.clip_mode,
710 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
711 | create_graph=second_order)
712 | else:
713 | loss.backward(create_graph=second_order)
714 | if args.clip_grad is not None:
715 | dispatch_clip_grad(
716 | model_parameters(model, exclude_head='agc' in args.clip_mode),
717 | value=args.clip_grad, mode=args.clip_mode)
718 | optimizer.step()
719 |
720 | if model_ema is not None:
721 | model_ema.update(model)
722 |
723 | torch.cuda.synchronize()
724 | num_updates += 1
725 | batch_time_m.update(time.time() - end)
726 | if last_batch or batch_idx % args.log_interval == 0:
727 | lrl = [param_group['lr'] for param_group in optimizer.param_groups]
728 | lr = sum(lrl) / len(lrl)
729 |
730 | if args.distributed:
731 | reduced_loss = reduce_tensor(loss.data, args.world_size)
732 | losses_m.update(reduced_loss.item(), input.size(0))
733 |
734 | if args.local_rank == 0:
735 | _logger.info(
736 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
737 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
738 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
739 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
740 | 'LR: {lr:.3e} '
741 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
742 | epoch,
743 | batch_idx, len(loader),
744 | 100. * batch_idx / last_idx,
745 | loss=losses_m,
746 | batch_time=batch_time_m,
747 | rate=input.size(0) * args.world_size / batch_time_m.val,
748 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
749 | lr=lr,
750 | data_time=data_time_m))
751 |
752 | if args.save_images and output_dir:
753 | torchvision.utils.save_image(
754 | input,
755 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
756 | padding=0,
757 | normalize=True)
758 |
759 | if saver is not None and args.recovery_interval and (
760 | last_batch or (batch_idx + 1) % args.recovery_interval == 0):
761 | saver.save_recovery(epoch, batch_idx=batch_idx)
762 |
763 | if lr_scheduler is not None:
764 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
765 |
766 | end = time.time()
767 | # end for
768 |
769 | if hasattr(optimizer, 'sync_lookahead'):
770 | optimizer.sync_lookahead()
771 |
772 | return OrderedDict([('loss', losses_m.avg)])
773 |
774 |
775 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
776 | batch_time_m = AverageMeter()
777 | losses_m = AverageMeter()
778 | top1_m = AverageMeter()
779 | top5_m = AverageMeter()
780 |
781 | model.eval()
782 |
783 | end = time.time()
784 | last_idx = len(loader) - 1
785 | with torch.no_grad():
786 | for batch_idx, (input, target) in enumerate(loader):
787 | last_batch = batch_idx == last_idx
788 | if not args.prefetcher:
789 | input = input.cuda()
790 | target = target.cuda()
791 | if args.channels_last:
792 | input = input.contiguous(memory_format=torch.channels_last)
793 |
794 | with amp_autocast():
795 | output = model(input)
796 | if isinstance(output, (tuple, list)):
797 | output = output[0]
798 |
799 | # augmentation reduction
800 | reduce_factor = args.tta
801 | if reduce_factor > 1:
802 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
803 | target = target[0:target.size(0):reduce_factor]
804 |
805 | loss = loss_fn(output, target)
806 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
807 |
808 | if args.distributed:
809 | reduced_loss = reduce_tensor(loss.data, args.world_size)
810 | acc1 = reduce_tensor(acc1, args.world_size)
811 | acc5 = reduce_tensor(acc5, args.world_size)
812 | else:
813 | reduced_loss = loss.data
814 |
815 | torch.cuda.synchronize()
816 |
817 | losses_m.update(reduced_loss.item(), input.size(0))
818 | top1_m.update(acc1.item(), output.size(0))
819 | top5_m.update(acc5.item(), output.size(0))
820 |
821 | batch_time_m.update(time.time() - end)
822 | end = time.time()
823 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
824 | log_name = 'Test' + log_suffix
825 | _logger.info(
826 | '{0}: [{1:>4d}/{2}] '
827 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
828 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
829 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
830 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
831 | log_name, batch_idx, last_idx, batch_time=batch_time_m,
832 | loss=losses_m, top1=top1_m, top5=top5_m))
833 |
834 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
835 |
836 | return metrics
837 |
838 |
839 | if __name__ == '__main__':
840 | main()
--------------------------------------------------------------------------------
/transmix.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | from timm.data.mixup import Mixup, cutmix_bbox_and_lam, one_hot
5 |
6 | def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda', return_y1y2=False):
7 | off_value = smoothing / num_classes
8 | on_value = 1. - smoothing + off_value
9 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
10 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
11 | if return_y1y2:
12 | return y1 * lam + y2 * (1. - lam), y1.clone(), y2.clone()
13 | else:
14 | return y1 * lam + y2 * (1. - lam)
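# Worked example for mixup_target() above (hypothetical numbers): with num_classes=1000 and
# smoothing=0.1, off_value = 1e-4 and on_value = 0.9001; for lam = 0.7 the mixed label puts
# roughly 0.630 on image A's class, 0.270 on image B's class, and 1e-4 everywhere else.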
15 |
16 |
17 | class Mixup_transmix(Mixup):
18 |     """ Acts like timm's Mixup(), but additionally exposes the information needed by transmix_label().
19 |     Mixup/CutMix that applies different params to each element or the whole batch; per-batch is the default.
20 |
21 | Args:
22 | mixup_alpha (float): mixup alpha value, mixup is active if > 0.
23 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
24 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
25 | prob (float): probability of applying mixup or cutmix per batch or element
26 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active
27 |         mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
28 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
29 | label_smoothing (float): apply label smoothing to the mixed target tensor
30 | num_classes (int): number of classes for target
31 |         (note: there is no extra TransMix flag here; label re-weighting is applied by calling transmix_label() with the attention map)
32 | """
33 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
34 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
35 | self.mixup_alpha = mixup_alpha
36 | self.cutmix_alpha = cutmix_alpha
37 | self.cutmix_minmax = cutmix_minmax
38 | if self.cutmix_minmax is not None:
39 | assert len(self.cutmix_minmax) == 2
40 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
41 | self.cutmix_alpha = 1.0
42 | self.mix_prob = prob
43 | self.switch_prob = switch_prob
44 | self.label_smoothing = label_smoothing
45 | self.num_classes = num_classes
46 | self.mode = mode
47 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
48 |         self.mixup_enabled = True  # set to False to disable mixing (intended to be set by the train loop)
49 |
50 | def _mix_batch(self, x):
51 | lam, use_cutmix = self._params_per_batch()
52 |
53 | if lam == 1.:
54 | return 1.
55 | if use_cutmix:
56 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
57 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
58 |             x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]  # CutMix: paste the crop from the flipped batch into x
59 |             return lam, (yl, yh, xl, xh)  # also return the crop box so __call__ can hand it to transmix_label()
60 | else:
61 | x_flipped = x.flip(0).mul_(1. - lam)
62 | x.mul_(lam).add_(x_flipped)
63 |
64 | return lam
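    # NOTE on the return value of _mix_batch() above (it differs from timm's base Mixup):
    #   cutmix branch -> (lam, (yl, yh, xl, xh)), so that __call__ can forward the crop box;
    #   mixup branch  -> a plain float lam; lam == 1. means no mixing was applied to this batch.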
65 |
66 |
67 | def transmix_label(self, target, attn, input_shape, ratio=0.5):
68 |         """Re-weight the mixed label using the class-token attention map (TransMix).
69 |         args:
70 |             attn (torch.Tensor): class-token attention over the patches of the last Transformer block, shape (N, hw)
71 |             target (tuple): (target, y1, y2, box), as returned by __call__ when cutmix was applied
72 |                 target (torch.Tensor): label mixed by area ratio (ignored here and recomputed below)
73 |                 y1 (torch.Tensor): smoothed one-hot label for image A (background image), shape (N, K)
74 |                 y2 (torch.Tensor): smoothed one-hot label for image B (cropped patch), shape (N, K)
75 |                 box (tuple): crop box (yl, yh, xl, xh)
76 |             input_shape (tuple): input batch shape (N, C, H, W); `ratio` is currently unused, an equal 0.5 blend is hard-coded
77 |         returns:
78 |             target (torch.Tensor): re-weighted soft label with shape (N, K)
79 |         """
80 | # the placeholder _ is the area-based target
81 | (_, y1, y2, box) = target
82 | lam0 = (box[1]-box[0]) * (box[3]-box[2]) / (input_shape[2] * input_shape[3])
83 | mask = torch.zeros((input_shape[2], input_shape[3])).cuda()
84 | mask[box[0]:box[1], box[2]:box[3]] = 1
85 | mask = nn.Upsample(size=int(math.sqrt(attn.shape[1])))(mask.unsqueeze(0).unsqueeze(0)).int()
86 | mask = mask.view(1, -1).repeat(len(attn), 1) # (b, hw)
87 | w1, w2 = torch.sum((1-mask) * attn, dim=1), torch.sum(mask * attn, dim=1)
88 | lam1 = w2 / (w1+w2) # (b, )
89 |         lam = (lam0 + lam1) / 2  # equal blend of the area-based lambda (scalar) and the attention-based lambda (B,)
90 | target = y1 * (1. - lam).unsqueeze(1) + y2 * lam.unsqueeze(1)
91 | return target
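    # Worked example for transmix_label() above (hypothetical numbers): with a 224x224 input and a
    # 112x112 crop box, lam0 = 112*112 / (224*224) = 0.25; if 40% of the class-token attention mass
    # falls inside the box, lam1 = 0.40, so lam = (0.25 + 0.40) / 2 = 0.325 and the pasted patch's
    # label y2 receives weight 0.325 while the background label y1 receives 0.675.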
92 |
93 | def __call__(self, x, target):
94 | assert len(x) % 2 == 0, 'Batch size should be even when using this'
95 | assert self.mode == 'batch', 'Mixup mode is batch by default'
96 |         lam = self._mix_batch(x)  # (lam, box) tuple if cutmix was applied, a float otherwise
97 |         if isinstance(lam, tuple):
98 |             lam, box = lam  # lam is a scalar float, box is (yl, yh, xl, xh)
99 |             use_cutmix = True
100 |         else:  # lam is a plain float (mixup, or no mixing when lam == 1.)
101 |             use_cutmix = False
102 |
103 |         mixed_target, y1, y2 = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device, return_y1y2=True)  # area-based soft target plus the two smoothed one-hot labels
104 | if use_cutmix:
105 | return x, (mixed_target, y1, y2, box)
106 | else:
107 | return x, mixed_target
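# Minimal smoke test (a sketch, not part of the training pipeline). It assumes a CUDA device, since
# transmix_label() builds its crop mask with .cuda(), and it substitutes a random tensor for the real
# class-token attention map; it only illustrates the tuple contract that train.py consumes.
if __name__ == '__main__':
    mixup_fn = Mixup_transmix(mixup_alpha=0.8, cutmix_alpha=1.0, switch_prob=1.0,
                              label_smoothing=0.1, num_classes=10)
    x = torch.randn(4, 3, 224, 224, device='cuda')
    y = torch.randint(0, 10, (4,), device='cuda')
    x, target = mixup_fn(x, y)
    if isinstance(target, tuple):  # cutmix branch: re-weight the label with the (fake) attention map
        attn = torch.rand(4, 196, device='cuda')  # 196 = 14 x 14 patches, e.g. ViT-B/16 at 224x224
        target = mixup_fn.transmix_label(target, attn, x.shape)
    print(target.shape)  # torch.Size([4, 10])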
--------------------------------------------------------------------------------