├── LICENSE ├── README.md ├── configs ├── deit_b_transmix.yaml └── deit_s_transmix.yaml ├── distributed_train.sh ├── pic1.png ├── pic2.png ├── requirements.txt ├── timm └── models │ └── vision_transformer.py ├── train.py └── transmix.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 ByteDance 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransMix: Attend to Mix for Vision Transformers 2 | 3 | This repository includes the official project for the paper: [*TransMix: Attend to Mix for Vision Transformers*](https://arxiv.org/abs/2111.09833), CVPR 2022 4 | 5 |

6 | <!-- figures: pic1.png, pic2.png --> 7 | 8 | 9 | 10 | 11 |
12 | 13 | # Key Feature 14 | 15 | Improve your Vision Transformer (ViT) by ~1% top-1 accuracy on ImageNet with minimal computational cost and a simple ```--transmix``` flag. 16 | 17 | # Getting Started 18 | 19 | First, clone the repo: 20 | ```shell 21 | git clone https://github.com/Beckschen/TransMix.git 22 | ``` 23 | 24 | Then install the required packages: [PyTorch](https://pytorch.org/) version 1.7.1, 25 | [torchvision](https://pytorch.org/vision/stable/index.html) version 0.8.2, 26 | [timm](https://github.com/rwightman/pytorch-image-models) version 0.5.4, 27 | and ```pyyaml```. To install all of these packages, simply run 28 | ``` 29 | pip3 install -r requirements.txt 30 | ``` 31 | 32 | Download and extract the [ImageNet](https://imagenet.stanford.edu/) dataset into the ```data``` folder. If you're training with 33 | 8 GPUs, simply run 34 | ```shell 35 | bash ./distributed_train.sh 8 data/ --config $YOUR_CONFIG_PATH_HERE 36 | ``` 37 | 38 | By default, all of our config files enable training with TransMix. 39 | To enable TransMix when training your own model, 40 | add the ```--transmix``` flag to your training command. For example: 41 | ```shell 42 | python3 -m torch.distributed.launch --nproc_per_node=8 train.py data/ --config $YOUR_CONFIG_PATH_HERE --transmix 43 | ``` 44 | 45 | Alternatively, specify ```transmix: True``` in your ```yaml``` config file, as in [deit_s_transmix](configs/deit_s_transmix.yaml). 46 | 47 | To evaluate a model trained with TransMix, please refer to the [timm](https://github.com/rwightman/pytorch-image-models#train-validation-inference-scripts) validation script. 48 | Validation accuracy is also reported during training. A sketch of the label re-weighting performed by ```--transmix``` is given after the License section below. 49 | 50 | # Model Zoo 51 | 52 | Coming soon! 53 | 54 | ## Acknowledgement 55 | This repository is built on the [timm](https://github.com/rwightman/pytorch-image-models) library and 56 | the [DeiT](https://github.com/facebookresearch/deit) repository. 57 | 58 | ## License 59 | This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
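## How TransMix Re-weights Labels (Sketch)

The ```*_return_attn``` model variants in ```timm/models/vision_transformer.py``` return the class-token attention over image patches, and TransMix uses it to decide how much of the mixed label should come from the CutMix-pasted region. The snippet below is a minimal illustrative sketch of that idea from the paper, not the exact code in ```transmix.py```; the function name ```transmix_lam``` and the tensor layout are assumptions made for illustration.

```python
import torch

def transmix_lam(attn: torch.Tensor, patch_mask: torch.Tensor) -> torch.Tensor:
    """Illustrative sketch of the TransMix mixing weight.

    attn:       (B, N) class-token attention over the N image patches,
                e.g. the second output of the `*_return_attn` models.
    patch_mask: (B, N) binary CutMix mask downsampled to the patch grid,
                1 where a patch was pasted from the second image.
    Returns a per-sample lambda: the attention mass inside the pasted region.
    """
    lam = (attn * patch_mask).sum(dim=1) / attn.sum(dim=1).clamp(min=1e-6)
    return lam  # mixed target ~= lam * y_pasted + (1 - lam) * y_base
```

Compared with vanilla CutMix, which sets the mixing weight to the area ratio of the pasted box, this attention-based weight lets salient pasted content contribute more to the label.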
60 | 61 | ## Cite This Paper 62 | If you find our code helpful for your research, please use the following BibTeX entry to cite our paper: 63 | 64 | ``` 65 | @InProceedings{transmix, 66 | title = {TransMix: Attend to Mix for Vision Transformers}, 67 | author = {Chen, Jie-Neng and Sun, Shuyang and He, Ju and Torr, Philip and Yuille, Alan and Bai, Song}, 68 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 69 | month = {June}, 70 | year = {2022} 71 | } 72 | ``` 73 | -------------------------------------------------------------------------------- /configs/deit_b_transmix.yaml: -------------------------------------------------------------------------------- 1 | model: "deit_base_patch16_224_return_attn" 2 | 3 | img_size: 224 4 | decay_epochs: 30 5 | opt: adamw 6 | num_classes: 1000 7 | mixup: 0.8 8 | cutmix: 1.0 9 | drop_path: 0.1 10 | dist_bn: "" 11 | model_ema: True 12 | aa: rand-m9-mstd0.5-inc1 13 | pin_mem: False 14 | model_ema_decay: 0.99996 15 | no_prefetcher: True 16 | transmix: True # enable transmix 17 | mixup_switch_prob: 0.8 18 | min_lr: 1e-5 19 | lr: 1e-3 20 | warmup_lr: 1e-6 21 | weight_decay: 5e-2 22 | warmup_epochs: 5 23 | workers: 8 24 | total_batch_size: 256 25 | -------------------------------------------------------------------------------- /configs/deit_s_transmix.yaml: -------------------------------------------------------------------------------- 1 | model: deit_small_patch16_224_return_attn 2 | warmup_lr: 1e-6 3 | img_size: 224 4 | decay_epochs: 30 5 | opt: adamw 6 | num_classes: 1000 7 | mixup: 0.8 8 | cutmix: 1.0 9 | drop_path: 0.1 10 | dist_bn: "" 11 | model_ema: True 12 | aa: rand-m9-mstd0.5-inc1 13 | pin_mem: False 14 | model_ema_decay: 0.99996 15 | no_prefetcher: True 16 | transmix: True # enable transmix 17 | mixup_switch_prob: 0.8 18 | min_lr: 1e-5 19 | lr: 1e-3 20 | weight_decay: 3e-2 21 | warmup_epochs: 20 22 | workers: 8 23 | total_batch_size: 256 24 | -------------------------------------------------------------------------------- /distributed_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | NUM_PROC=$1 3 | shift 4 | python3 -m torch.distributed.launch --nproc_per_node=$NUM_PROC train.py "$@" 5 | 6 | -------------------------------------------------------------------------------- /pic1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beckschen/TransMix/0e4d31eb34772f9d12cc450678c4bf4ca89ba828/pic1.png -------------------------------------------------------------------------------- /pic2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Beckschen/TransMix/0e4d31eb34772f9d12cc450678c4bf4ca89ba828/pic2.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1 2 | torchvision==0.8.2 3 | timm==0.5.4 4 | pyyaml 5 | -------------------------------------------------------------------------------- /timm/models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | """ Vision Transformer (ViT) in PyTorch 2 | 3 | A PyTorch implementation of Vision Transformers as described in: 4 | 5 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' 6 | - https://arxiv.org/abs/2010.11929 7 | 8 | `How to train your ViT?
Data, Augmentation, and Regularization in Vision Transformers` 9 | - https://arxiv.org/abs/2106.10270 10 | 11 | The official jax code is released and available at https://github.com/google-research/vision_transformer 12 | 13 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 14 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 15 | 16 | Acknowledgments: 17 | * The paper authors for releasing code and weights, thanks! 18 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out 19 | for some einops/einsum fun 20 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 21 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 22 | 23 | Hacked together by / Copyright 2021 Ross Wightman 24 | """ 25 | import math 26 | import logging 27 | from functools import partial 28 | from collections import OrderedDict 29 | from copy import deepcopy 30 | 31 | import torch 32 | import torch.nn as nn 33 | import torch.nn.functional as F 34 | 35 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD 36 | from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv 37 | from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ 38 | from timm.models.registry import register_model 39 | 40 | _logger = logging.getLogger(__name__) 41 | 42 | 43 | def _cfg(url='', **kwargs): 44 | return { 45 | 'url': url, 46 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 47 | 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 48 | 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, 49 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 50 | **kwargs 51 | } 52 | 53 | 54 | default_cfgs = { 55 | # patch models (weights from official Google JAX impl) 56 | 'vit_tiny_patch16_224': _cfg( 57 | url='https://storage.googleapis.com/vit_models/augreg/' 58 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 59 | 'vit_tiny_patch16_384': _cfg( 60 | url='https://storage.googleapis.com/vit_models/augreg/' 61 | 'Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 62 | input_size=(3, 384, 384), crop_pct=1.0), 63 | 'vit_small_patch32_224': _cfg( 64 | url='https://storage.googleapis.com/vit_models/augreg/' 65 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 66 | 'vit_small_patch32_384': _cfg( 67 | url='https://storage.googleapis.com/vit_models/augreg/' 68 | 'S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 69 | input_size=(3, 384, 384), crop_pct=1.0), 70 | 'vit_small_patch16_224': _cfg( 71 | url='https://storage.googleapis.com/vit_models/augreg/' 72 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 73 | 'vit_small_patch16_384': _cfg( 74 | url='https://storage.googleapis.com/vit_models/augreg/' 75 | 'S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 76 | input_size=(3, 384, 384), crop_pct=1.0), 77 | 'vit_base_patch32_224': _cfg( 78 | url='https://storage.googleapis.com/vit_models/augreg/' 79 | 
'B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz'), 80 | 'vit_base_patch32_384': _cfg( 81 | url='https://storage.googleapis.com/vit_models/augreg/' 82 | 'B_32-i21k-300ep-lr_0.001-aug_light1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz', 83 | input_size=(3, 384, 384), crop_pct=1.0), 84 | 'vit_base_patch16_224': _cfg( 85 | url='https://storage.googleapis.com/vit_models/augreg/' 86 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 87 | 'vit_base_patch16_384': _cfg( 88 | url='https://storage.googleapis.com/vit_models/augreg/' 89 | 'B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz', 90 | input_size=(3, 384, 384), crop_pct=1.0), 91 | 'vit_base_patch8_224': _cfg( 92 | url='https://storage.googleapis.com/vit_models/augreg/' 93 | 'B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 94 | 'vit_large_patch32_224': _cfg( 95 | url='', # no official model weights for this combo, only for in21k 96 | ), 97 | 'vit_large_patch32_384': _cfg( 98 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 99 | input_size=(3, 384, 384), crop_pct=1.0), 100 | 'vit_large_patch16_224': _cfg( 101 | url='https://storage.googleapis.com/vit_models/augreg/' 102 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz'), 103 | 'vit_large_patch16_384': _cfg( 104 | url='https://storage.googleapis.com/vit_models/augreg/' 105 | 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', 106 | input_size=(3, 384, 384), crop_pct=1.0), 107 | 108 | # patch models, imagenet21k (weights from official Google JAX impl) 109 | 'vit_tiny_patch16_224_in21k': _cfg( 110 | url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz', 111 | num_classes=21843), 112 | 'vit_small_patch32_224_in21k': _cfg( 113 | url='https://storage.googleapis.com/vit_models/augreg/S_32-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 114 | num_classes=21843), 115 | 'vit_small_patch16_224_in21k': _cfg( 116 | url='https://storage.googleapis.com/vit_models/augreg/S_16-i21k-300ep-lr_0.001-aug_light1-wd_0.03-do_0.0-sd_0.0.npz', 117 | num_classes=21843), 118 | 'vit_base_patch32_224_in21k': _cfg( 119 | url='https://storage.googleapis.com/vit_models/augreg/B_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.03-do_0.0-sd_0.0.npz', 120 | num_classes=21843), 121 | 'vit_base_patch16_224_in21k': _cfg( 122 | url='https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 123 | num_classes=21843), 124 | 'vit_base_patch8_224_in21k': _cfg( 125 | url='https://storage.googleapis.com/vit_models/augreg/B_8-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0.npz', 126 | num_classes=21843), 127 | 'vit_large_patch32_224_in21k': _cfg( 128 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 129 | num_classes=21843), 130 | 'vit_large_patch16_224_in21k': _cfg( 131 | url='https://storage.googleapis.com/vit_models/augreg/L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1.npz', 132 | num_classes=21843), 133 | 'vit_huge_patch14_224_in21k': _cfg( 134 | 
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz', 135 | hf_hub='timm/vit_huge_patch14_224_in21k', 136 | num_classes=21843), 137 | 138 | # SAM trained models (https://arxiv.org/abs/2106.01548) 139 | 'vit_base_patch32_sam_224': _cfg( 140 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_32.npz'), 141 | 'vit_base_patch16_sam_224': _cfg( 142 | url='https://storage.googleapis.com/vit_models/sam/ViT-B_16.npz'), 143 | 144 | # deit models (FB weights) 145 | 'deit_tiny_patch16_224': _cfg( 146 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', 147 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 148 | 'deit_small_patch16_224': _cfg( 149 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', 150 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 151 | 'deit_base_patch16_224': _cfg( 152 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', 153 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 154 | 'deit_base_patch16_384': _cfg( 155 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 156 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), 157 | 'deit_tiny_distilled_patch16_224': _cfg( 158 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', 159 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 160 | 'deit_small_distilled_patch16_224': _cfg( 161 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', 162 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 163 | 'deit_base_distilled_patch16_224': _cfg( 164 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', 165 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), 166 | 'deit_base_distilled_patch16_384': _cfg( 167 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 168 | mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, 169 | classifier=('head', 'head_dist')), 170 | 171 | # ViT ImageNet-21K-P pretraining by MILL 172 | 'vit_base_patch16_224_miil_in21k': _cfg( 173 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', 174 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, 175 | ), 176 | 'vit_base_patch16_224_miil': _cfg( 177 | url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm' 178 | '/vit_base_patch16_224_1k_miil_84_4.pth', 179 | mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', 180 | ), 181 | } 182 | 183 | 184 | class Attention(nn.Module): 185 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., return_attn=False): 186 | super().__init__() 187 | self.num_heads = num_heads 188 | head_dim = dim // num_heads 189 | self.scale = head_dim ** -0.5 190 | 191 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 192 | self.attn_drop = nn.Dropout(attn_drop) 193 | self.proj = nn.Linear(dim, dim) 194 | self.proj_drop = nn.Dropout(proj_drop) 195 | 196 | self.return_attn = return_attn 197 | 198 | def forward(self, x): 199 | B, N, C = x.shape 200 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 201 | 
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 202 | 203 | attn = (q @ k.transpose(-2, -1)) * self.scale 204 | attn = attn.softmax(dim=-1) 205 | attn_softmax = attn.detach().clone() 206 | attn = self.attn_drop(attn) 207 | 208 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 209 | x = self.proj(x) 210 | x = self.proj_drop(x) 211 | if self.return_attn: 212 | return x, attn_softmax 213 | return x 214 | 215 | 216 | class Block(nn.Module): 217 | 218 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 219 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, return_attn=False): 220 | super().__init__() 221 | self.norm1 = norm_layer(dim) 222 | self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, return_attn=return_attn) 223 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 224 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 225 | self.norm2 = norm_layer(dim) 226 | mlp_hidden_dim = int(dim * mlp_ratio) 227 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 228 | self.return_attn = return_attn # modified by jchen 229 | def forward(self, x): 230 | if self.return_attn: 231 | res = x 232 | x, attn = self.attn(self.norm1(x)) 233 | x = res + self.drop_path(x) 234 | x = x + self.drop_path(self.mlp(self.norm2(x))) 235 | return x, attn 236 | 237 | x = x + self.drop_path(self.attn(self.norm1(x))) 238 | x = x + self.drop_path(self.mlp(self.norm2(x))) 239 | return x 240 | 241 | 242 | class VisionTransformer(nn.Module): 243 | """ Vision Transformer 244 | 245 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` 246 | - https://arxiv.org/abs/2010.11929 247 | 248 | Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` 249 | - https://arxiv.org/abs/2012.12877 250 | """ 251 | 252 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 253 | num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, 254 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, 255 | act_layer=None, weight_init='', return_attn=False): 256 | """ 257 | Args: 258 | img_size (int, tuple): input image size 259 | patch_size (int, tuple): patch size 260 | in_chans (int): number of input channels 261 | num_classes (int): number of classes for classification head 262 | embed_dim (int): embedding dimension 263 | depth (int): depth of transformer 264 | num_heads (int): number of attention heads 265 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 266 | qkv_bias (bool): enable bias for qkv if True 267 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 268 | distilled (bool): model includes a distillation token and head as in DeiT models 269 | drop_rate (float): dropout rate 270 | attn_drop_rate (float): attention dropout rate 271 | drop_path_rate (float): stochastic depth rate 272 | embed_layer (nn.Module): patch embedding layer 273 | norm_layer: (nn.Module): normalization layer 274 | weight_init: (str): weight init scheme 275 | """ 276 | super().__init__() 277 | self.return_attn = return_attn 278 | self.num_classes = num_classes 279 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 280 | 
self.num_tokens = 2 if distilled else 1 281 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 282 | act_layer = act_layer or nn.GELU 283 | 284 | self.patch_embed = embed_layer( 285 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 286 | num_patches = self.patch_embed.num_patches 287 | 288 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 289 | self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None 290 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) 291 | self.pos_drop = nn.Dropout(p=drop_rate) 292 | 293 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 294 | self.blocks = nn.Sequential(*[ 295 | Block( 296 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 297 | attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, 298 | return_attn=return_attn and i==depth-1) 299 | for i in range(depth)]) 300 | self.norm = norm_layer(embed_dim) 301 | 302 | # Representation layer 303 | if representation_size and not distilled: 304 | self.num_features = representation_size 305 | self.pre_logits = nn.Sequential(OrderedDict([ 306 | ('fc', nn.Linear(embed_dim, representation_size)), 307 | ('act', nn.Tanh()) 308 | ])) 309 | else: 310 | self.pre_logits = nn.Identity() 311 | 312 | # Classifier head(s) 313 | self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() 314 | self.head_dist = None 315 | if distilled: 316 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 317 | 318 | self.init_weights(weight_init) 319 | 320 | def init_weights(self, mode=''): 321 | assert mode in ('jax', 'jax_nlhb', 'nlhb', '') 322 | head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
323 | trunc_normal_(self.pos_embed, std=.02) 324 | if self.dist_token is not None: 325 | trunc_normal_(self.dist_token, std=.02) 326 | if mode.startswith('jax'): 327 | # leave cls token as zeros to match jax impl 328 | named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) 329 | else: 330 | trunc_normal_(self.cls_token, std=.02) 331 | self.apply(_init_vit_weights) 332 | 333 | def _init_weights(self, m): 334 | # this fn left here for compat with downstream users 335 | _init_vit_weights(m) 336 | 337 | @torch.jit.ignore() 338 | def load_pretrained(self, checkpoint_path, prefix=''): 339 | _load_weights(self, checkpoint_path, prefix) 340 | 341 | @torch.jit.ignore 342 | def no_weight_decay(self): 343 | return {'pos_embed', 'cls_token', 'dist_token'} 344 | 345 | def get_classifier(self): 346 | if self.dist_token is None: 347 | return self.head 348 | else: 349 | return self.head, self.head_dist 350 | 351 | def reset_classifier(self, num_classes, global_pool=''): 352 | self.num_classes = num_classes 353 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 354 | if self.num_tokens == 2: 355 | self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() 356 | 357 | def forward_features(self, x): 358 | x = self.patch_embed(x) 359 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks 360 | if self.dist_token is None: 361 | x = torch.cat((cls_token, x), dim=1) 362 | else: 363 | x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) 364 | x = self.pos_drop(x + self.pos_embed) 365 | x = self.blocks(x) 366 | if self.return_attn: 367 | x, attn = x[0], x[1] 368 | attn = torch.mean(attn[:, :, 0, 1:], dim=1) # attn from cls_token to images 369 | x = self.norm(x) 370 | if self.dist_token is None: 371 | return self.pre_logits(x[:, 0]), attn 372 | else: 373 | return x[:, 0], x[:, 1], attn 374 | else: 375 | x = self.norm(x) 376 | if self.dist_token is None: 377 | return self.pre_logits(x[:, 0]) 378 | else: 379 | return x[:, 0], x[:, 1] 380 | 381 | def forward(self, x): 382 | x = self.forward_features(x) 383 | if self.return_attn: 384 | if self.head_dist is not None: 385 | x, x_dist, attn = self.head(x[0]), self.head_dist(x[1]), x[2] # x must be a tuple 386 | if self.training and not torch.jit.is_scripting(): 387 | # during inference, return the average of both classifier predictions 388 | return x, x_dist, attn 389 | else: 390 | return (x + x_dist) / 2, attn 391 | else: 392 | x, attn = x[0], x[1] 393 | x = self.head(x) 394 | return x, attn 395 | else: 396 | if self.head_dist is not None: 397 | x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple 398 | if self.training and not torch.jit.is_scripting(): 399 | # during inference, return the average of both classifier predictions 400 | return x, x_dist 401 | else: 402 | return (x + x_dist) / 2 403 | else: 404 | x = self.head(x) 405 | return x 406 | 407 | 408 | def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): 409 | """ ViT weight initialization 410 | * When called without n, head_bias, jax_impl args it will behave exactly the same 411 | as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). 
412 | * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl 413 | """ 414 | if isinstance(module, nn.Linear): 415 | if name.startswith('head'): 416 | nn.init.zeros_(module.weight) 417 | nn.init.constant_(module.bias, head_bias) 418 | elif name.startswith('pre_logits'): 419 | lecun_normal_(module.weight) 420 | nn.init.zeros_(module.bias) 421 | else: 422 | if jax_impl: 423 | nn.init.xavier_uniform_(module.weight) 424 | if module.bias is not None: 425 | if 'mlp' in name: 426 | nn.init.normal_(module.bias, std=1e-6) 427 | else: 428 | nn.init.zeros_(module.bias) 429 | else: 430 | trunc_normal_(module.weight, std=.02) 431 | if module.bias is not None: 432 | nn.init.zeros_(module.bias) 433 | elif jax_impl and isinstance(module, nn.Conv2d): 434 | # NOTE conv was left to pytorch default in my original init 435 | lecun_normal_(module.weight) 436 | if module.bias is not None: 437 | nn.init.zeros_(module.bias) 438 | elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): 439 | nn.init.zeros_(module.bias) 440 | nn.init.ones_(module.weight) 441 | 442 | 443 | @torch.no_grad() 444 | def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): 445 | """ Load weights from .npz checkpoints for official Google Brain Flax implementation 446 | """ 447 | import numpy as np 448 | 449 | def _n2p(w, t=True): 450 | if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: 451 | w = w.flatten() 452 | if t: 453 | if w.ndim == 4: 454 | w = w.transpose([3, 2, 0, 1]) 455 | elif w.ndim == 3: 456 | w = w.transpose([2, 0, 1]) 457 | elif w.ndim == 2: 458 | w = w.transpose([1, 0]) 459 | return torch.from_numpy(w) 460 | 461 | w = np.load(checkpoint_path) 462 | if not prefix and 'opt/target/embedding/kernel' in w: 463 | prefix = 'opt/target/' 464 | 465 | if hasattr(model.patch_embed, 'backbone'): 466 | # hybrid 467 | backbone = model.patch_embed.backbone 468 | stem_only = not hasattr(backbone, 'stem') 469 | stem = backbone if stem_only else backbone.stem 470 | stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel']))) 471 | stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale'])) 472 | stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias'])) 473 | if not stem_only: 474 | for i, stage in enumerate(backbone.stages): 475 | for j, block in enumerate(stage.blocks): 476 | bp = f'{prefix}block{i + 1}/unit{j + 1}/' 477 | for r in range(3): 478 | getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel'])) 479 | getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale'])) 480 | getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias'])) 481 | if block.downsample is not None: 482 | block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel'])) 483 | block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale'])) 484 | block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias'])) 485 | embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) 486 | else: 487 | embed_conv_w = adapt_input_conv( 488 | model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) 489 | model.patch_embed.proj.weight.copy_(embed_conv_w) 490 | model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) 491 | model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) 492 | pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) 493 | if pos_embed_w.shape != model.pos_embed.shape: 494 | pos_embed_w = resize_pos_embed( # resize pos 
embedding when different size from pretrained weights 495 | pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 496 | model.pos_embed.copy_(pos_embed_w) 497 | model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) 498 | model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) 499 | if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: 500 | model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) 501 | model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) 502 | if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: 503 | model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) 504 | model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) 505 | for i, block in enumerate(model.blocks.children()): 506 | block_prefix = f'{prefix}Transformer/encoderblock_{i}/' 507 | mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' 508 | block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) 509 | block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) 510 | block.attn.qkv.weight.copy_(torch.cat([ 511 | _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) 512 | block.attn.qkv.bias.copy_(torch.cat([ 513 | _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) 514 | block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) 515 | block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) 516 | for r in range(2): 517 | getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel'])) 518 | getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias'])) 519 | block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) 520 | block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) 521 | 522 | 523 | def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): 524 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 525 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 526 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 527 | ntok_new = posemb_new.shape[1] 528 | if num_tokens: 529 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 530 | ntok_new -= num_tokens 531 | else: 532 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 533 | gs_old = int(math.sqrt(len(posemb_grid))) 534 | if not len(gs_new): # backwards compatibility 535 | gs_new = [int(math.sqrt(ntok_new))] * 2 536 | assert len(gs_new) >= 2 537 | _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) 538 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 539 | posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) 540 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) 541 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 542 | return posemb 543 | 544 | 545 | def checkpoint_filter_fn(state_dict, model): 546 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 547 | out_dict = {} 548 | if 'model' in state_dict: 549 | # For deit models 550 | state_dict = state_dict['model'] 551 | for k, v in state_dict.items(): 552 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 553 | # For old models that I trained prior to conv based patchification 554 | O, I, H, W = model.patch_embed.proj.weight.shape 555 | v = v.reshape(O, -1, H, W) 556 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 557 | # To resize pos embedding when using model at different size from pretrained weights 558 | v = resize_pos_embed( 559 | v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) 560 | out_dict[k] = v 561 | return out_dict 562 | 563 | 564 | def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): 565 | default_cfg = default_cfg or default_cfgs[variant] 566 | if kwargs.get('features_only', None): 567 | raise RuntimeError('features_only not implemented for Vision Transformer models.') 568 | 569 | # NOTE this extra code to support handling of repr size for in21k pretrained models 570 | default_num_classes = default_cfg['num_classes'] 571 | num_classes = kwargs.get('num_classes', default_num_classes) 572 | repr_size = kwargs.pop('representation_size', None) 573 | if repr_size is not None and num_classes != default_num_classes: 574 | # Remove representation layer if fine-tuning. This may not always be the desired action, 575 | # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 576 | _logger.warning("Removing representation layer for fine-tuning.") 577 | repr_size = None 578 | 579 | model = build_model_with_cfg( 580 | VisionTransformer, variant, pretrained, 581 | default_cfg=default_cfg, 582 | representation_size=repr_size, 583 | pretrained_filter_fn=checkpoint_filter_fn, 584 | pretrained_custom_load='npz' in default_cfg['url'], 585 | **kwargs) 586 | return model 587 | 588 | 589 | 590 | @register_model 591 | def deit_small_patch16_224_return_attn(pretrained=False, **kwargs): 592 | 593 | """ an extra output for the class attention 594 | DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 595 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 
596 | """ 597 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, return_attn=True, **kwargs) 598 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) 599 | return model 600 | 601 | @register_model 602 | def deit_base_patch16_224_return_attn(pretrained=False, **kwargs): 603 | """ an extra output for the class attention 604 | DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 605 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 606 | """ 607 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, return_attn=True, **kwargs) 608 | model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) 609 | return model 610 | 611 | @register_model 612 | def vit_tiny_patch16_224(pretrained=False, **kwargs): 613 | """ ViT-Tiny (Vit-Ti/16) 614 | """ 615 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 616 | model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 617 | return model 618 | 619 | 620 | @register_model 621 | def vit_tiny_patch16_384(pretrained=False, **kwargs): 622 | """ ViT-Tiny (Vit-Ti/16) @ 384x384. 623 | """ 624 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 625 | model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) 626 | return model 627 | 628 | 629 | @register_model 630 | def vit_small_patch32_224(pretrained=False, **kwargs): 631 | """ ViT-Small (ViT-S/32) 632 | """ 633 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 634 | model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) 635 | return model 636 | 637 | 638 | @register_model 639 | def vit_small_patch32_384(pretrained=False, **kwargs): 640 | """ ViT-Small (ViT-S/32) at 384x384. 641 | """ 642 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 643 | model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) 644 | return model 645 | 646 | 647 | @register_model 648 | def vit_small_patch16_224(pretrained=False, **kwargs): 649 | """ ViT-Small (ViT-S/16) 650 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 651 | """ 652 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 653 | model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) 654 | return model 655 | 656 | 657 | @register_model 658 | def vit_small_patch16_384(pretrained=False, **kwargs): 659 | """ ViT-Small (ViT-S/16) 660 | NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper 661 | """ 662 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 663 | model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) 664 | return model 665 | 666 | 667 | @register_model 668 | def vit_base_patch32_224(pretrained=False, **kwargs): 669 | """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 670 | ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. 
671 | """ 672 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 673 | model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) 674 | return model 675 | 676 | 677 | @register_model 678 | def vit_base_patch32_384(pretrained=False, **kwargs): 679 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 680 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 681 | """ 682 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 683 | model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) 684 | return model 685 | 686 | 687 | @register_model 688 | def vit_base_patch16_224(pretrained=False, **kwargs): 689 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 690 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 691 | """ 692 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 693 | model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) 694 | return model 695 | 696 | 697 | @register_model 698 | def vit_base_patch16_384(pretrained=False, **kwargs): 699 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 700 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 701 | """ 702 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 703 | model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) 704 | return model 705 | 706 | 707 | @register_model 708 | def vit_base_patch8_224(pretrained=False, **kwargs): 709 | """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). 710 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 711 | """ 712 | model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 713 | model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) 714 | return model 715 | 716 | 717 | @register_model 718 | def vit_large_patch32_224(pretrained=False, **kwargs): 719 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. 720 | """ 721 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 722 | model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) 723 | return model 724 | 725 | 726 | @register_model 727 | def vit_large_patch32_384(pretrained=False, **kwargs): 728 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 729 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 730 | """ 731 | model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) 732 | model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) 733 | return model 734 | 735 | 736 | @register_model 737 | def vit_large_patch16_224(pretrained=False, **kwargs): 738 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 
739 | ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 740 | """ 741 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 742 | model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) 743 | return model 744 | 745 | 746 | @register_model 747 | def vit_large_patch16_384(pretrained=False, **kwargs): 748 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 749 | ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. 750 | """ 751 | model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 752 | model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) 753 | return model 754 | 755 | 756 | @register_model 757 | def vit_base_patch16_sam_224(pretrained=False, **kwargs): 758 | """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 759 | """ 760 | # NOTE original SAM weights release worked with representation_size=768 761 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs) 762 | model = _create_vision_transformer('vit_base_patch16_sam_224', pretrained=pretrained, **model_kwargs) 763 | return model 764 | 765 | 766 | @register_model 767 | def vit_base_patch32_sam_224(pretrained=False, **kwargs): 768 | """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 769 | """ 770 | # NOTE original SAM weights release worked with representation_size=768 771 | model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=0, **kwargs) 772 | model = _create_vision_transformer('vit_base_patch32_sam_224', pretrained=pretrained, **model_kwargs) 773 | return model 774 | 775 | 776 | @register_model 777 | def vit_tiny_patch16_224_in21k(pretrained=False, **kwargs): 778 | """ ViT-Tiny (Vit-Ti/16). 779 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 780 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 781 | """ 782 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 783 | model = _create_vision_transformer('vit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 784 | return model 785 | 786 | 787 | @register_model 788 | def vit_small_patch32_224_in21k(pretrained=False, **kwargs): 789 | """ ViT-Small (ViT-S/16) 790 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 791 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 792 | """ 793 | model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) 794 | model = _create_vision_transformer('vit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 795 | return model 796 | 797 | 798 | @register_model 799 | def vit_small_patch16_224_in21k(pretrained=False, **kwargs): 800 | """ ViT-Small (ViT-S/16) 801 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
802 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 803 | """ 804 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 805 | model = _create_vision_transformer('vit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 806 | return model 807 | 808 | 809 | @register_model 810 | def vit_base_patch32_224_in21k(pretrained=False, **kwargs): 811 | """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). 812 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 813 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 814 | """ 815 | model_kwargs = dict( 816 | patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) 817 | model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 818 | return model 819 | 820 | 821 | @register_model 822 | def vit_base_patch16_224_in21k(pretrained=False, **kwargs): 823 | """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 824 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 825 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 826 | """ 827 | model_kwargs = dict( 828 | patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 829 | model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 830 | return model 831 | 832 | 833 | @register_model 834 | def vit_base_patch8_224_in21k(pretrained=False, **kwargs): 835 | """ ViT-Base model (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). 836 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 837 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 838 | """ 839 | model_kwargs = dict( 840 | patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) 841 | model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) 842 | return model 843 | 844 | 845 | @register_model 846 | def vit_large_patch32_224_in21k(pretrained=False, **kwargs): 847 | """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). 848 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 849 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 850 | """ 851 | model_kwargs = dict( 852 | patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) 853 | model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) 854 | return model 855 | 856 | 857 | @register_model 858 | def vit_large_patch16_224_in21k(pretrained=False, **kwargs): 859 | """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). 860 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
861 | NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer 862 | """ 863 | model_kwargs = dict( 864 | patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) 865 | model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) 866 | return model 867 | 868 | 869 | @register_model 870 | def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): 871 | """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). 872 | ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 873 | NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights 874 | """ 875 | model_kwargs = dict( 876 | patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) 877 | model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) 878 | return model 879 | 880 | 881 | @register_model 882 | def deit_tiny_patch16_224(pretrained=False, **kwargs): 883 | """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 884 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 885 | """ 886 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 887 | model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) 888 | return model 889 | 890 | 891 | @register_model 892 | def deit_small_patch16_224(pretrained=False, **kwargs): 893 | """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 894 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 895 | """ 896 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 897 | model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) 898 | return model 899 | 900 | 901 | @register_model 902 | def deit_base_patch16_224(pretrained=False, **kwargs): 903 | """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 904 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 905 | """ 906 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 907 | model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) 908 | return model 909 | 910 | 911 | @register_model 912 | def deit_base_patch16_384(pretrained=False, **kwargs): 913 | """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 914 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 915 | """ 916 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 917 | model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) 918 | return model 919 | 920 | 921 | @register_model 922 | def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): 923 | """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 924 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 
925 | """ 926 | model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) 927 | model = _create_vision_transformer( 928 | 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 929 | return model 930 | 931 | 932 | @register_model 933 | def deit_small_distilled_patch16_224(pretrained=False, **kwargs): 934 | """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 935 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 936 | """ 937 | model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) 938 | model = _create_vision_transformer( 939 | 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 940 | return model 941 | 942 | 943 | @register_model 944 | def deit_base_distilled_patch16_224(pretrained=False, **kwargs): 945 | """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). 946 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 947 | """ 948 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 949 | model = _create_vision_transformer( 950 | 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 951 | return model 952 | 953 | 954 | @register_model 955 | def deit_base_distilled_patch16_384(pretrained=False, **kwargs): 956 | """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). 957 | ImageNet-1k weights from https://github.com/facebookresearch/deit. 958 | """ 959 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) 960 | model = _create_vision_transformer( 961 | 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) 962 | return model 963 | 964 | 965 | @register_model 966 | def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): 967 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 968 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 969 | """ 970 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 971 | model = _create_vision_transformer('vit_base_patch16_224_miil_in21k', pretrained=pretrained, **model_kwargs) 972 | return model 973 | 974 | 975 | @register_model 976 | def vit_base_patch16_224_miil(pretrained=False, **kwargs): 977 | """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). 978 | Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K 979 | """ 980 | model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) 981 | model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) 982 | return model 983 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ ImageNet Training Script for TransMix. 
3 | This script was modified from an early version of the PyTorch Image Models (timm) 4 | (https://github.com/rwightman/pytorch-image-models) 5 | Hacked together by Jieneng Chen and Shuyang Sun / Copyright 2022 ByteDance 6 | """ 7 | import argparse 8 | import time 9 | import yaml 10 | import os 11 | import logging 12 | from collections import OrderedDict 13 | from contextlib import suppress 14 | from datetime import datetime 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torchvision.utils 19 | from torch.nn.parallel import DistributedDataParallel as NativeDDP 20 | 21 | from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset 22 | from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ 23 | convert_splitbn_model, model_parameters 24 | from timm.utils import * 25 | from timm.loss import * 26 | from timm.optim import create_optimizer_v2, optimizer_kwargs 27 | from timm.scheduler import create_scheduler 28 | from timm.utils import ApexScaler, NativeScaler 29 | 30 | try: 31 | from apex import amp 32 | from apex.parallel import DistributedDataParallel as ApexDDP 33 | from apex.parallel import convert_syncbn_model 34 | has_apex = True 35 | except ImportError: 36 | has_apex = False 37 | 38 | has_native_amp = False 39 | try: 40 | if getattr(torch.cuda.amp, 'autocast') is not None: 41 | has_native_amp = True 42 | except AttributeError: 43 | pass 44 | 45 | try: 46 | import wandb 47 | has_wandb = True 48 | except ImportError: 49 | has_wandb = False 50 | 51 | torch.backends.cudnn.benchmark = True 52 | _logger = logging.getLogger('train') 53 | 54 | # The first arg parser parses out only the --config argument, this argument is used to 55 | # load a yaml file containing key-values that override the defaults for the main parser below 56 | config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False) 57 | parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', 58 | help='YAML config file specifying default arguments') 59 | 60 | 61 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') 62 | 63 | # Dataset parameters 64 | parser.add_argument('data_dir', default='./data/', type=str, metavar='DIR', # there is no default in timm!!! 
65 | help='path to dataset') 66 | parser.add_argument('--dataset', '-d', metavar='NAME', default='', 67 | help='dataset type (default: ImageFolder/ImageTar if empty)') 68 | parser.add_argument('--train-split', metavar='NAME', default='train', 69 | help='dataset train split (default: train)') 70 | parser.add_argument('--val-split', metavar='NAME', default='validation', 71 | help='dataset validation split (default: validation)') 72 | parser.add_argument('--dataset-download', action='store_true', default=False, 73 | help='Allow download of dataset for torch/ and tfds/ datasets that support it.') 74 | parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', 75 | help='path to class to idx mapping file (default: "")') 76 | 77 | # Model parameters 78 | parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', 79 | help='Name of model to train (default: "resnet50"') 80 | parser.add_argument('--pretrained', action='store_true', default=False, 81 | help='Start with pretrained version of specified network (if avail)') 82 | parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', 83 | help='Initialize model from this checkpoint (default: none)') 84 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 85 | help='Resume full model and optimizer state from checkpoint (default: none)') 86 | parser.add_argument('--no-resume-opt', action='store_true', default=False, 87 | help='prevent resume of optimizer state when resuming model') 88 | parser.add_argument('--num-classes', type=int, default=None, metavar='N', 89 | help='number of label classes (Model default if None)') 90 | parser.add_argument('--gp', default=None, type=str, metavar='POOL', 91 | help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') 92 | parser.add_argument('--img-size', type=int, default=None, metavar='N', 93 | help='Image patch size (default: None => model default)') 94 | parser.add_argument('--input-size', default=None, nargs=3, type=int, 95 | metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') 96 | parser.add_argument('--crop-pct', default=None, type=float, 97 | metavar='N', help='Input image center crop percent (for validation only)') 98 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', 99 | help='Override mean pixel value of dataset') 100 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', 101 | help='Override std deviation of of dataset') 102 | parser.add_argument('--interpolation', default='', type=str, metavar='NAME', 103 | help='Image resize interpolation type (overrides model)') 104 | parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', 105 | help='input batch size for training (default: 128)') 106 | parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', 107 | help='validation batch size override (default: None)') 108 | 109 | # Optimizer parameters 110 | parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', 111 | help='Optimizer (default: "sgd"') 112 | parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', 113 | help='Optimizer Epsilon (default: None, use opt default)') 114 | parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', 115 | help='Optimizer Betas (default: None, use opt default)') 116 | parser.add_argument('--momentum', type=float, default=0.9, metavar='M', 117 | help='Optimizer momentum (default: 0.9)') 118 | parser.add_argument('--weight-decay', type=float, default=2e-5, 119 | help='weight decay (default: 2e-5)') 120 | parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', 121 | help='Clip gradient norm (default: None, no clipping)') 122 | parser.add_argument('--clip-mode', type=str, default='norm', 123 | help='Gradient clipping mode. 
One of ("norm", "value", "agc")') 124 | 125 | 126 | # Learning rate schedule parameters 127 | parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', 128 | help='LR scheduler (default: "step"') 129 | parser.add_argument('--lr', type=float, default=0.05, metavar='LR', 130 | help='learning rate (default: 0.05)') 131 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 132 | help='learning rate noise on/off epoch percentages') 133 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 134 | help='learning rate noise limit percent (default: 0.67)') 135 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 136 | help='learning rate noise std-dev (default: 1.0)') 137 | parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', 138 | help='learning rate cycle len multiplier (default: 1.0)') 139 | parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', 140 | help='amount to decay each learning rate cycle (default: 0.5)') 141 | parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', 142 | help='learning rate cycle limit, cycles enabled if > 1') 143 | parser.add_argument('--lr-k-decay', type=float, default=1.0, 144 | help='learning rate k-decay for cosine/poly (default: 1.0)') 145 | parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', 146 | help='warmup learning rate (default: 0.0001)') 147 | parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', # 1e-5 for vit 148 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 149 | parser.add_argument('--epochs', type=int, default=300, metavar='N', 150 | help='number of epochs to train (default: 300)') 151 | parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 152 | help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') 153 | parser.add_argument('--start-epoch', default=None, type=int, metavar='N', 154 | help='manual epoch number (useful on restarts)') 155 | parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', 156 | help='epoch interval to decay LR') 157 | parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', 158 | help='epochs to warmup LR, if scheduler supports') 159 | parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', 160 | help='epochs to cooldown LR at min_lr, after cyclic schedule ends') 161 | parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', 162 | help='patience epochs for Plateau LR scheduler (default: 10') 163 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 164 | help='LR decay rate (default: 0.1)') 165 | 166 | # Augmentation & regularization parameters 167 | parser.add_argument('--no-aug', action='store_true', default=False, 168 | help='Disable all training augmentation, override other train aug args') 169 | parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', 170 | help='Random resize scale (default: 0.08 1.0)') 171 | parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', 172 | help='Random resize aspect ratio (default: 0.75 1.33)') 173 | parser.add_argument('--hflip', type=float, default=0.5, 174 | help='Horizontal flip training aug probability') 175 | parser.add_argument('--vflip', type=float, default=0., 176 | help='Vertical flip training aug 
probability') 177 | parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', 178 | help='Color jitter factor (default: 0.4)') 179 | parser.add_argument('--aa', type=str, default=None, metavar='NAME', 180 | help='Use AutoAugment policy. "v0" or "original". (default: None)'), 181 | parser.add_argument('--aug-repeats', type=int, default=0, 182 | help='Number of augmentation repetitions (distributed training only) (default: 0)') 183 | parser.add_argument('--aug-splits', type=int, default=0, 184 | help='Number of augmentation splits (default: 0, valid: 0 or >=2)') 185 | parser.add_argument('--jsd-loss', action='store_true', default=False, 186 | help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') 187 | parser.add_argument('--bce-loss', action='store_true', default=False, 188 | help='Enable BCE loss w/ Mixup/CutMix use.') 189 | parser.add_argument('--bce-target-thresh', type=float, default=None, 190 | help='Threshold for binarizing softened BCE targets (default: None, disabled)') 191 | parser.add_argument('--reprob', type=float, default=0., metavar='PCT', 192 | help='Random erase prob (default: 0.)') 193 | parser.add_argument('--remode', type=str, default='pixel', 194 | help='Random erase mode (default: "pixel")') 195 | parser.add_argument('--recount', type=int, default=1, 196 | help='Random erase count (default: 1)') 197 | parser.add_argument('--resplit', action='store_true', default=False, 198 | help='Do not random erase first (clean) augmentation split') 199 | parser.add_argument('--mixup', type=float, default=0.0, 200 | help='mixup alpha, mixup enabled if > 0. (default: 0.)') 201 | parser.add_argument('--cutmix', type=float, default=0.0, 202 | help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') 203 | parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, 204 | help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') 205 | parser.add_argument('--mixup-prob', type=float, default=1.0, 206 | help='Probability of performing mixup or cutmix when either/both is enabled') 207 | parser.add_argument('--mixup-switch-prob', type=float, default=0.5, 208 | help='Probability of switching to cutmix when both mixup and cutmix enabled') 209 | parser.add_argument('--mixup-mode', type=str, default='batch', 210 | help='How to apply mixup/cutmix params. 
Per "batch", "pair", or "elem"') 211 | parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', 212 | help='Turn off mixup after this epoch, disabled if 0 (default: 0)') 213 | parser.add_argument('--smoothing', type=float, default=0.1, 214 | help='Label smoothing (default: 0.1)') 215 | parser.add_argument('--train-interpolation', type=str, default='random', 216 | help='Training interpolation (random, bilinear, bicubic default: "random")') 217 | parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', 218 | help='Dropout rate (default: 0.)') 219 | parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', 220 | help='Drop connect rate, DEPRECATED, use drop-path (default: None)') 221 | parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', 222 | help='Drop path rate (default: None)') 223 | parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', 224 | help='Drop block rate (default: None)') 225 | 226 | # Batch norm parameters (only works with gen_efficientnet based models currently) 227 | parser.add_argument('--bn-momentum', type=float, default=None, 228 | help='BatchNorm momentum override (if not None)') 229 | parser.add_argument('--bn-eps', type=float, default=None, 230 | help='BatchNorm epsilon override (if not None)') 231 | parser.add_argument('--sync-bn', action='store_true', 232 | help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') 233 | parser.add_argument('--dist-bn', type=str, default='reduce', 234 | help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') 235 | parser.add_argument('--split-bn', action='store_true', 236 | help='Enable separate BN layers per augmentation split.') 237 | 238 | # Model Exponential Moving Average 239 | parser.add_argument('--model-ema', action='store_true', default=False, 240 | help='Enable tracking moving average of model weights') 241 | parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, 242 | help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') 243 | parser.add_argument('--model-ema-decay', type=float, default=0.9998, 244 | help='decay factor for model weights moving average (default: 0.9998)') 245 | 246 | # Misc 247 | parser.add_argument('--seed', type=int, default=42, metavar='S', 248 | help='random seed (default: 42)') 249 | parser.add_argument('--worker-seeding', type=str, default='all', 250 | help='worker seed mode (default: all)') 251 | parser.add_argument('--log-interval', type=int, default=50, metavar='N', 252 | help='how many batches to wait before logging training status') 253 | parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', 254 | help='how many batches to wait before writing recovery checkpoint') 255 | parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', 256 | help='number of checkpoints to keep (default: 10)') 257 | parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', 258 | help='how many training processes to use (default: 4)') 259 | parser.add_argument('--save-images', action='store_true', default=False, 260 | help='save images of input bathes every log interval for debugging') 261 | parser.add_argument('--amp', action='store_true', default=False, 262 | help='use NVIDIA Apex AMP or Native AMP for mixed precision training') 263 | parser.add_argument('--apex-amp', action='store_true', default=False, 264 | help='Use NVIDIA Apex AMP mixed precision') 265 | parser.add_argument('--native-amp', action='store_true', default=False, 266 | help='Use Native Torch AMP mixed precision') 267 | parser.add_argument('--no-ddp-bb', action='store_true', default=False, 268 | help='Force broadcast buffers for native DDP to off.') 269 | parser.add_argument('--channels-last', action='store_true', default=False, 270 | help='Use channels_last memory layout') 271 | parser.add_argument('--pin-mem', action='store_true', default=False, 272 | help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') 273 | parser.add_argument('--no-prefetcher', action='store_true', default=False, 274 | help='disable fast prefetcher') 275 | parser.add_argument('--output', default='', type=str, metavar='PATH', 276 | help='path to output folder (default: none, current dir)') 277 | parser.add_argument('--experiment', default='', type=str, metavar='NAME', 278 | help='name of train experiment, name of sub-folder for output') 279 | parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', 280 | help='Best metric (default: "top1"') 281 | parser.add_argument('--tta', type=int, default=0, metavar='N', 282 | help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') 283 | parser.add_argument("--local_rank", default=0, type=int) 284 | parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, 285 | help='use the multi-epochs-loader to save time at the beginning of every epoch') 286 | parser.add_argument('--torchscript', dest='torchscript', action='store_true', 287 | help='convert model torchscript for inference') 288 | parser.add_argument('--fuser', default='', type=str, 289 | help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") 290 | parser.add_argument('--log-wandb', action='store_true', default=False, 291 | help='log training and validation metrics to wandb') 292 | # new flags we added to timm. 
293 | parser.add_argument('--transmix', action='store_true', default=False, help='') 294 | parser.add_argument('--total-batch-size', type=int, default=None, 295 | help='input batch size for training (default: None), batch-size = total-batch-size / world_size') 296 | def _parse_args(): 297 | # Do we have a config file to parse? 298 | args_config, remaining = config_parser.parse_known_args() 299 | if args_config.config: 300 | with open(args_config.config, 'r') as f: 301 | cfg = yaml.safe_load(f) 302 | parser.set_defaults(**cfg) 303 | 304 | # The main arg parser parses the rest of the args, the usual 305 | # defaults will have been overridden if config file specified. 306 | args = parser.parse_args(remaining) 307 | 308 | # Cache the args as a text string to save them in the output dir later 309 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 310 | return args, args_text 311 | 312 | 313 | def main(): 314 | setup_default_logging() 315 | args, args_text = _parse_args() 316 | print(args) 317 | 318 | if args.log_wandb: 319 | if has_wandb: 320 | wandb.init(project=args.experiment, config=args) 321 | else: 322 | _logger.warning("You've requested to log metrics to wandb but package not found. " 323 | "Metrics not being logged to wandb, try `pip install wandb`") 324 | 325 | args.prefetcher = not args.no_prefetcher 326 | args.distributed = False 327 | if 'WORLD_SIZE' in os.environ: 328 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 329 | args.device = 'cuda:0' 330 | args.world_size = 1 331 | args.rank = 0 # global rank 332 | if args.distributed: 333 | args.device = 'cuda:%d' % args.local_rank 334 | torch.cuda.set_device(args.local_rank) 335 | torch.distributed.init_process_group(backend='nccl', init_method='env://') 336 | args.world_size = torch.distributed.get_world_size() 337 | args.rank = torch.distributed.get_rank() 338 | _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' 339 | % (args.rank, args.world_size)) 340 | else: 341 | _logger.info('Training with a single process on 1 GPUs.') 342 | assert args.rank >= 0 343 | 344 | # resolve AMP arguments based on PyTorch / Apex availability 345 | use_amp = None 346 | if args.amp: 347 | # `--amp` chooses native amp before apex (APEX ver not actively maintained) 348 | if has_native_amp: 349 | args.native_amp = True 350 | elif has_apex: 351 | args.apex_amp = True 352 | if args.apex_amp and has_apex: 353 | use_amp = 'apex' 354 | elif args.native_amp and has_native_amp: 355 | use_amp = 'native' 356 | elif args.apex_amp or args.native_amp: 357 | _logger.warning("Neither APEX or native Torch AMP is available, using float32. " 358 | "Install NVIDA apex or upgrade to PyTorch 1.6") 359 | 360 | random_seed(args.seed, args.rank) 361 | 362 | model = create_model( 363 | args.model, 364 | pretrained=args.pretrained, 365 | num_classes=args.num_classes, 366 | drop_rate=args.drop, 367 | drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path 368 | drop_path_rate=args.drop_path, 369 | drop_block_rate=args.drop_block, 370 | global_pool=args.gp, 371 | bn_momentum=args.bn_momentum, 372 | bn_eps=args.bn_eps, 373 | scriptable=args.torchscript, 374 | checkpoint_path=args.initial_checkpoint) 375 | if args.num_classes is None: 376 | assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' 
377 | args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly 378 | 379 | if args.local_rank == 0: 380 | _logger.info( 381 | f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') 382 | 383 | data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) 384 | 385 | # setup augmentation batch splits for contrastive loss or split bn 386 | num_aug_splits = 0 387 | if args.aug_splits > 0: 388 | assert args.aug_splits > 1, 'A split of 1 makes no sense' 389 | num_aug_splits = args.aug_splits 390 | 391 | # enable split bn (separate bn stats per batch-portion) 392 | if args.split_bn: 393 | assert num_aug_splits > 1 or args.resplit 394 | model = convert_splitbn_model(model, max(num_aug_splits, 2)) 395 | 396 | # move model to GPU, enable channels last layout if set 397 | model.cuda() 398 | if args.channels_last: 399 | model = model.to(memory_format=torch.channels_last) 400 | 401 | # setup synchronized BatchNorm for distributed training 402 | if args.distributed and args.sync_bn: 403 | assert not args.split_bn 404 | if has_apex and use_amp == 'apex': 405 | # Apex SyncBN preferred unless native amp is activated 406 | model = convert_syncbn_model(model) 407 | else: 408 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 409 | if args.local_rank == 0: 410 | _logger.info( 411 | 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' 412 | 'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.') 413 | 414 | if args.torchscript: 415 | assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' 416 | assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' 417 | model = torch.jit.script(model) 418 | 419 | optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) 420 | 421 | # setup automatic mixed-precision (AMP) loss scaling and op casting 422 | amp_autocast = suppress # do nothing 423 | loss_scaler = None 424 | if use_amp == 'apex': 425 | model, optimizer = amp.initialize(model, optimizer, opt_level='O1') 426 | loss_scaler = ApexScaler() 427 | if args.local_rank == 0: 428 | _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.') 429 | elif use_amp == 'native': 430 | amp_autocast = torch.cuda.amp.autocast 431 | loss_scaler = NativeScaler() 432 | if args.local_rank == 0: 433 | _logger.info('Using native Torch AMP. Training in mixed precision.') 434 | else: 435 | if args.local_rank == 0: 436 | _logger.info('AMP not enabled. 
Training in float32.') 437 | 438 | # optionally resume from a checkpoint 439 | resume_epoch = None 440 | if args.resume: 441 | resume_epoch = resume_checkpoint( 442 | model, args.resume, 443 | optimizer=None if args.no_resume_opt else optimizer, 444 | loss_scaler=None if args.no_resume_opt else loss_scaler, 445 | log_info=args.local_rank == 0) 446 | 447 | # setup exponential moving average of model weights, SWA could be used here too 448 | model_ema = None 449 | if args.model_ema: 450 | # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper 451 | model_ema = ModelEmaV2( 452 | model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None) 453 | if args.resume: 454 | load_checkpoint(model_ema.module, args.resume, use_ema=True) 455 | 456 | # setup distributed training 457 | if args.distributed: 458 | if has_apex and use_amp == 'apex': 459 | # Apex DDP preferred unless native amp is activated 460 | if args.local_rank == 0: 461 | _logger.info("Using NVIDIA APEX DistributedDataParallel.") 462 | model = ApexDDP(model, delay_allreduce=True) 463 | else: 464 | if args.local_rank == 0: 465 | _logger.info("Using native Torch DistributedDataParallel.") 466 | model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) 467 | # NOTE: EMA model does not need to be wrapped by DDP 468 | 469 | # setup learning rate schedule and starting epoch 470 | lr_scheduler, num_epochs = create_scheduler(args, optimizer) 471 | start_epoch = 0 472 | if args.start_epoch is not None: 473 | # a specified start_epoch will always override the resume epoch 474 | start_epoch = args.start_epoch 475 | elif resume_epoch is not None: 476 | start_epoch = resume_epoch 477 | if lr_scheduler is not None and start_epoch > 0: 478 | lr_scheduler.step(start_epoch) 479 | 480 | if args.local_rank == 0: 481 | _logger.info('Scheduled epochs: {}'.format(num_epochs)) 482 | 483 | # create the train and eval datasets 484 | dataset_train = create_dataset( 485 | args.dataset, root=args.data_dir, split=args.train_split, is_training=True, 486 | class_map=args.class_map, 487 | download=args.dataset_download, 488 | batch_size=args.batch_size, 489 | repeats=args.epoch_repeats) 490 | dataset_eval = create_dataset( 491 | args.dataset, root=args.data_dir, split=args.val_split, is_training=False, 492 | class_map=args.class_map, 493 | download=args.dataset_download, 494 | batch_size=args.batch_size) 495 | 496 | # setup mixup / cutmix 497 | collate_fn = None 498 | mixup_fn = None 499 | mixup_active = args.mixup > 0 or args.cutmix > 0. 
or args.cutmix_minmax is not None 500 | if mixup_active: 501 | mixup_args = dict( 502 | mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, 503 | prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, 504 | label_smoothing=args.smoothing, num_classes=args.num_classes) 505 | if args.transmix: 506 | # wrap mixup_fn with TransMix helper, disable args.prefetcher 507 | from transmix import Mixup_transmix 508 | mixup_fn = Mixup_transmix(**mixup_args) 509 | else: 510 | if args.prefetcher: 511 | assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) 512 | collate_fn = FastCollateMixup(**mixup_args) 513 | else: 514 | mixup_fn = Mixup(**mixup_args) 515 | 516 | 517 | # wrap dataset in AugMix helper 518 | if num_aug_splits > 1: 519 | dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) 520 | 521 | # create data loaders w/ augmentation pipeiine 522 | train_interpolation = args.train_interpolation 523 | if args.no_aug or not train_interpolation: 524 | train_interpolation = data_config['interpolation'] 525 | 526 | if args.total_batch_size: 527 | args.batch_size = args.total_batch_size // args.world_size 528 | 529 | loader_train = create_loader( 530 | dataset_train, 531 | input_size=data_config['input_size'], 532 | batch_size=args.batch_size, 533 | is_training=True, 534 | use_prefetcher=args.prefetcher, 535 | no_aug=args.no_aug, 536 | re_prob=args.reprob, 537 | re_mode=args.remode, 538 | re_count=args.recount, 539 | re_split=args.resplit, 540 | scale=args.scale, 541 | ratio=args.ratio, 542 | hflip=args.hflip, 543 | vflip=args.vflip, 544 | color_jitter=args.color_jitter, 545 | auto_augment=args.aa, 546 | num_aug_repeats=args.aug_repeats, 547 | num_aug_splits=num_aug_splits, 548 | interpolation=train_interpolation, 549 | mean=data_config['mean'], 550 | std=data_config['std'], 551 | num_workers=args.workers, 552 | distributed=args.distributed, 553 | collate_fn=collate_fn, 554 | pin_memory=args.pin_mem, 555 | use_multi_epochs_loader=args.use_multi_epochs_loader, 556 | worker_seeding=args.worker_seeding, 557 | ) 558 | 559 | loader_eval = create_loader( 560 | dataset_eval, 561 | input_size=data_config['input_size'], 562 | batch_size=args.validation_batch_size or args.batch_size, 563 | is_training=False, 564 | use_prefetcher=args.prefetcher, 565 | interpolation=data_config['interpolation'], 566 | mean=data_config['mean'], 567 | std=data_config['std'], 568 | num_workers=args.workers, 569 | distributed=args.distributed, 570 | crop_pct=data_config['crop_pct'], 571 | pin_memory=args.pin_mem, 572 | ) 573 | 574 | # setup loss function 575 | if args.jsd_loss: 576 | assert num_aug_splits > 1 # JSD only valid with aug splits set 577 | train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing) 578 | elif mixup_active: 579 | # smoothing is handled with mixup target transform which outputs sparse, soft targets 580 | if args.bce_loss: 581 | train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh) 582 | else: 583 | train_loss_fn = SoftTargetCrossEntropy() 584 | elif args.smoothing: 585 | if args.bce_loss: 586 | train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh) 587 | else: 588 | train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing) 589 | else: 590 | train_loss_fn = nn.CrossEntropyLoss() 591 | train_loss_fn = train_loss_fn.cuda() 592 | validate_loss_fn = nn.CrossEntropyLoss().cuda() 593 | 594 | # setup 
checkpoint saver and eval metric tracking 595 | eval_metric = args.eval_metric 596 | best_metric = None 597 | best_epoch = None 598 | saver = None 599 | output_dir = None 600 | if args.rank == 0: 601 | if args.experiment: 602 | exp_name = args.experiment 603 | else: 604 | exp_name = '-'.join([ 605 | datetime.now().strftime("%Y%m%d-%H%M%S"), 606 | safe_model_name(args.model), 607 | str(data_config['input_size'][-1]) 608 | ]) 609 | output_dir = get_outdir(args.output if args.output else './output/train', exp_name) 610 | decreasing = True if eval_metric == 'loss' else False 611 | saver = CheckpointSaver( 612 | model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, 613 | checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) 614 | with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: 615 | f.write(args_text) 616 | 617 | try: 618 | for epoch in range(start_epoch, num_epochs): 619 | if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): 620 | loader_train.sampler.set_epoch(epoch) 621 | 622 | train_metrics = train_one_epoch( 623 | epoch, model, loader_train, optimizer, train_loss_fn, args, 624 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, 625 | amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) 626 | 627 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 628 | if args.local_rank == 0: 629 | _logger.info("Distributing BatchNorm running means and vars") 630 | distribute_bn(model, args.world_size, args.dist_bn == 'reduce') 631 | 632 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast) 633 | 634 | if model_ema is not None and not args.model_ema_force_cpu: 635 | if args.distributed and args.dist_bn in ('broadcast', 'reduce'): 636 | distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce') 637 | ema_eval_metrics = validate( 638 | model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)') 639 | eval_metrics = ema_eval_metrics 640 | 641 | if lr_scheduler is not None: 642 | # step LR for next epoch 643 | lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) 644 | 645 | if output_dir is not None: 646 | update_summary( 647 | epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), 648 | write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb) 649 | 650 | if saver is not None: 651 | # save proper checkpoint with eval metric 652 | save_metric = eval_metrics[eval_metric] 653 | best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric) 654 | 655 | except KeyboardInterrupt: 656 | pass 657 | if best_metric is not None: 658 | _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) 659 | 660 | 661 | def train_one_epoch( 662 | epoch, model, loader, optimizer, loss_fn, args, 663 | lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, 664 | loss_scaler=None, model_ema=None, mixup_fn=None): 665 | 666 | if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: 667 | if args.prefetcher and loader.mixup_enabled: 668 | loader.mixup_enabled = False 669 | elif mixup_fn is not None: 670 | mixup_fn.mixup_enabled = False 671 | 672 | second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 673 | batch_time_m = AverageMeter() 674 | data_time_m = AverageMeter() 675 | losses_m = AverageMeter() 676 | 677 | model.train() 678 | 679 | 
end = time.time() 680 | last_idx = len(loader) - 1 681 | num_updates = epoch * len(loader) 682 | for batch_idx, (input, target) in enumerate(loader): 683 | last_batch = batch_idx == last_idx 684 | data_time_m.update(time.time() - end) 685 | if not args.prefetcher: 686 | input, target = input.cuda(), target.cuda() 687 | if mixup_fn is not None: 688 | input, target = mixup_fn(input, target) # target (B, K), or target is tuple under transmix 689 | 690 | if args.channels_last: 691 | input = input.contiguous(memory_format=torch.channels_last) 692 | 693 | with amp_autocast(): 694 | output = model(input) 695 | if args.transmix: 696 | (output, attn) = output # attention from cls_token to images: (b, hw) 697 | if isinstance(target, tuple): # target is tuple of (target, y1, y2, lam) when switch to cutmix 698 | target = mixup_fn.transmix_label(target, attn, input.shape) 699 | loss = loss_fn(output, target) 700 | 701 | 702 | if not args.distributed: 703 | losses_m.update(loss.item(), input.size(0)) 704 | 705 | optimizer.zero_grad() 706 | if loss_scaler is not None: 707 | loss_scaler( 708 | loss, optimizer, 709 | clip_grad=args.clip_grad, clip_mode=args.clip_mode, 710 | parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), 711 | create_graph=second_order) 712 | else: 713 | loss.backward(create_graph=second_order) 714 | if args.clip_grad is not None: 715 | dispatch_clip_grad( 716 | model_parameters(model, exclude_head='agc' in args.clip_mode), 717 | value=args.clip_grad, mode=args.clip_mode) 718 | optimizer.step() 719 | 720 | if model_ema is not None: 721 | model_ema.update(model) 722 | 723 | torch.cuda.synchronize() 724 | num_updates += 1 725 | batch_time_m.update(time.time() - end) 726 | if last_batch or batch_idx % args.log_interval == 0: 727 | lrl = [param_group['lr'] for param_group in optimizer.param_groups] 728 | lr = sum(lrl) / len(lrl) 729 | 730 | if args.distributed: 731 | reduced_loss = reduce_tensor(loss.data, args.world_size) 732 | losses_m.update(reduced_loss.item(), input.size(0)) 733 | 734 | if args.local_rank == 0: 735 | _logger.info( 736 | 'Train: {} [{:>4d}/{} ({:>3.0f}%)] ' 737 | 'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) ' 738 | 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s ' 739 | '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 740 | 'LR: {lr:.3e} ' 741 | 'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format( 742 | epoch, 743 | batch_idx, len(loader), 744 | 100. 
* batch_idx / last_idx, 745 | loss=losses_m, 746 | batch_time=batch_time_m, 747 | rate=input.size(0) * args.world_size / batch_time_m.val, 748 | rate_avg=input.size(0) * args.world_size / batch_time_m.avg, 749 | lr=lr, 750 | data_time=data_time_m)) 751 | 752 | if args.save_images and output_dir: 753 | torchvision.utils.save_image( 754 | input, 755 | os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx), 756 | padding=0, 757 | normalize=True) 758 | 759 | if saver is not None and args.recovery_interval and ( 760 | last_batch or (batch_idx + 1) % args.recovery_interval == 0): 761 | saver.save_recovery(epoch, batch_idx=batch_idx) 762 | 763 | if lr_scheduler is not None: 764 | lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) 765 | 766 | end = time.time() 767 | # end for 768 | 769 | if hasattr(optimizer, 'sync_lookahead'): 770 | optimizer.sync_lookahead() 771 | 772 | return OrderedDict([('loss', losses_m.avg)]) 773 | 774 | 775 | def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''): 776 | batch_time_m = AverageMeter() 777 | losses_m = AverageMeter() 778 | top1_m = AverageMeter() 779 | top5_m = AverageMeter() 780 | 781 | model.eval() 782 | 783 | end = time.time() 784 | last_idx = len(loader) - 1 785 | with torch.no_grad(): 786 | for batch_idx, (input, target) in enumerate(loader): 787 | last_batch = batch_idx == last_idx 788 | if not args.prefetcher: 789 | input = input.cuda() 790 | target = target.cuda() 791 | if args.channels_last: 792 | input = input.contiguous(memory_format=torch.channels_last) 793 | 794 | with amp_autocast(): 795 | output = model(input) 796 | if isinstance(output, (tuple, list)): 797 | output = output[0] 798 | 799 | # augmentation reduction 800 | reduce_factor = args.tta 801 | if reduce_factor > 1: 802 | output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) 803 | target = target[0:target.size(0):reduce_factor] 804 | 805 | loss = loss_fn(output, target) 806 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 807 | 808 | if args.distributed: 809 | reduced_loss = reduce_tensor(loss.data, args.world_size) 810 | acc1 = reduce_tensor(acc1, args.world_size) 811 | acc5 = reduce_tensor(acc5, args.world_size) 812 | else: 813 | reduced_loss = loss.data 814 | 815 | torch.cuda.synchronize() 816 | 817 | losses_m.update(reduced_loss.item(), input.size(0)) 818 | top1_m.update(acc1.item(), output.size(0)) 819 | top5_m.update(acc5.item(), output.size(0)) 820 | 821 | batch_time_m.update(time.time() - end) 822 | end = time.time() 823 | if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): 824 | log_name = 'Test' + log_suffix 825 | _logger.info( 826 | '{0}: [{1:>4d}/{2}] ' 827 | 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 828 | 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 829 | 'Acc@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) ' 830 | 'Acc@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format( 831 | log_name, batch_idx, last_idx, batch_time=batch_time_m, 832 | loss=losses_m, top1=top1_m, top5=top5_m)) 833 | 834 | metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)]) 835 | 836 | return metrics 837 | 838 | 839 | if __name__ == '__main__': 840 | main() -------------------------------------------------------------------------------- /transmix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from timm.data.mixup import Mixup, cutmix_bbox_and_lam, one_hot 5 | 6 | def 
mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda', return_y1y2=False): 7 | off_value = smoothing / num_classes 8 | on_value = 1. - smoothing + off_value 9 | y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) 10 | y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) 11 | if return_y1y2: 12 | return y1 * lam + y2 * (1. - lam), y1.clone(), y2.clone() 13 | else: 14 | return y1 * lam + y2 * (1. - lam) 15 | 16 | 17 | class Mixup_transmix(Mixup): 18 | """ Acts like Mixup(), but additionally returns the information needed by transmix_label(). 19 | Mixup/Cutmix that applies different params to each element or whole batch, where per-batch is set as default 20 | 21 | Args: 22 | mixup_alpha (float): mixup alpha value, mixup is active if > 0. 23 | cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. 24 | cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. 25 | prob (float): probability of applying mixup or cutmix per batch or element 26 | switch_prob (float): probability of switching to cutmix instead of mixup when both are active 27 | mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), or 'elem' (element)) 28 | correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders 29 | label_smoothing (float): apply label smoothing to the mixed target tensor 30 | num_classes (int): number of classes for target 31 | Note: the TransMix label re-weighting itself is done by transmix_label(), which needs the attention map returned by the model. 32 | """ 33 | def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, 34 | mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): 35 | self.mixup_alpha = mixup_alpha 36 | self.cutmix_alpha = cutmix_alpha 37 | self.cutmix_minmax = cutmix_minmax 38 | if self.cutmix_minmax is not None: 39 | assert len(self.cutmix_minmax) == 2 40 | # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe 41 | self.cutmix_alpha = 1.0 42 | self.mix_prob = prob 43 | self.switch_prob = switch_prob 44 | self.label_smoothing = label_smoothing 45 | self.num_classes = num_classes 46 | self.mode = mode 47 | self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix 48 | self.mixup_enabled = True # set to False to disable mixing (intended to be set by train loop) 49 | 50 | def _mix_batch(self, x): 51 | lam, use_cutmix = self._params_per_batch() 52 | 53 | if lam == 1.: 54 | return 1. 55 | if use_cutmix: 56 | (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( 57 | x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) 58 | x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] # apply cutmix to the input in place 59 | return lam, (yl, yh, xl, xh) # also return the box so transmix_label() can build the patch mask 60 | else: 61 | x_flipped = x.flip(0).mul_(1. - lam) 62 | x.mul_(lam).add_(x_flipped) 63 | 64 | return lam 65 | 66 | 67 | def transmix_label(self, target, attn, input_shape, ratio=0.5): 68 | """Re-weight the mixed label with the attention map of the [cls] token.
69 | args: 70 | attn (torch.tensor): attention of the [cls] token over the image patches from the last Transformer block, shape (N, hw) 71 | target (tuple): (mixed_target, y1, y2, box), as returned by __call__() when cutmix was applied 72 | mixed_target (torch.tensor): target mixed by the area-based ratio, shape (N, K) 73 | y1 (torch.tensor): one-hot label for image A (background image) (N, K) 74 | y2 (torch.tensor): one-hot label for image B (cropped patch) (N, K) 75 | box (tuple): cutmix box (yl, yh, xl, xh) in pixel coordinates 76 | input_shape (tuple): shape of the input batch (N, C, H, W) 77 | returns: 78 | target (torch.tensor): attention re-weighted target with shape (N, K) 79 | """ 80 | # the discarded first element is the area-based mixed target; it is recomputed below using the attention map 81 | (_, y1, y2, box) = target 82 | lam0 = (box[1]-box[0]) * (box[3]-box[2]) / (input_shape[2] * input_shape[3]) # area-based lambda (vanilla cutmix) 83 | mask = torch.zeros((input_shape[2], input_shape[3])).cuda() 84 | mask[box[0]:box[1], box[2]:box[3]] = 1 # binary mask of the pasted region 85 | mask = nn.Upsample(size=int(math.sqrt(attn.shape[1])))(mask.unsqueeze(0).unsqueeze(0)).int() # downsample mask to the patch grid 86 | mask = mask.view(1, -1).repeat(len(attn), 1) # (b, hw) 87 | w1, w2 = torch.sum((1-mask) * attn, dim=1), torch.sum(mask * attn, dim=1) # attention mass outside / inside the box 88 | lam1 = w2 / (w1+w2) # attention-based lambda, shape (b,) 89 | lam = (lam0 + lam1) / 2 # average the area-based and attention-based estimates (ratio fixed at 0.5) 90 | target = y1 * (1. - lam).unsqueeze(1) + y2 * lam.unsqueeze(1) 91 | return target 92 | 93 | def __call__(self, x, target): 94 | assert len(x) % 2 == 0, 'Batch size should be even when using this' 95 | assert self.mode == 'batch', 'Mixup_transmix only supports per-batch mixing' 96 | lam = self._mix_batch(x) # (lam, box) tuple if cutmix was applied, otherwise a scalar lam 97 | if isinstance(lam, tuple): 98 | lam, box = lam # scalar lam and the cutmix box 99 | use_cutmix = True 100 | else: # lam is a scalar, plain mixup (or no mixing) was applied 101 | use_cutmix = False 102 | 103 | mixed_target, y1, y2 = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device, return_y1y2=True) # mixed target plus the unmixed one-hot labels 104 | if use_cutmix: 105 | return x, (mixed_target, y1, y2, box) 106 | else: 107 | return x, mixed_target --------------------------------------------------------------------------------
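
For reference, the label re-weighting performed by transmix_label() can be reproduced in isolation. The following is a minimal sketch on dummy tensors (the labels, attention map, box coordinates, and sizes below are made up for illustration and are not produced by any file in this repository); it mirrors the lam0 / lam1 computation in transmix.py and runs on CPU:

import math
import torch
import torch.nn as nn

B, K = 4, 10                  # batch size, number of classes (dummy values)
H = W = 224                   # input resolution
num_patches = 196             # 14x14 patch grid for a /16 ViT at 224x224

# dummy one-hot labels for image A (background) and image B (pasted patch)
y1 = torch.eye(K)[torch.randint(0, K, (B,))]
y2 = torch.eye(K)[torch.randint(0, K, (B,))]

# dummy [cls]-token attention over the patches, normalized per sample: (B, hw)
attn = torch.softmax(torch.randn(B, num_patches), dim=1)

# a cutmix box (yl, yh, xl, xh) in pixel coordinates
yl, yh, xl, xh = 56, 168, 56, 168

# area-based lambda, as in vanilla CutMix
lam0 = (yh - yl) * (xh - xl) / (H * W)

# binary mask of the pasted region, downsampled to the patch grid
mask = torch.zeros(H, W)
mask[yl:yh, xl:xh] = 1
grid = int(math.sqrt(attn.shape[1]))
mask = nn.Upsample(size=grid)(mask[None, None]).int().view(1, -1).repeat(B, 1)

# attention mass outside vs. inside the box -> per-sample attention-based lambda
w1 = torch.sum((1 - mask) * attn, dim=1)
w2 = torch.sum(mask * attn, dim=1)
lam1 = w2 / (w1 + w2)

# TransMix averages the two estimates and mixes the labels with the result
lam = (lam0 + lam1) / 2
mixed_target = y1 * (1. - lam).unsqueeze(1) + y2 * lam.unsqueeze(1)
print(mixed_target.shape)     # torch.Size([4, 10])

In train.py this logic is driven by the --transmix flag: the model returns (output, attn), and when the mixup target is a tuple (i.e. cutmix was applied to the batch), train_one_epoch() calls mixup_fn.transmix_label(target, attn, input.shape) to re-weight the target before computing the loss.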