├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── convit
│   │   ├── convit-ti_c100_base.yaml
│   │   └── convit-ti_c100_ours.yaml
│   ├── cvt
│   │   ├── cvt-13_c100_base.yaml
│   │   └── cvt-13_c100_ours.yaml
│   ├── deit
│   │   ├── deit-ti_c100_base.yaml
│   │   ├── deit-ti_c100_ours.yaml
│   │   └── deit-ti_c100_ours_offline.yaml
│   ├── pit
│   │   ├── pit-ti_c100_base.yaml
│   │   └── pit-ti_c100_ours.yaml
│   ├── pvt
│   │   ├── pvt-ti_c100_base.yaml
│   │   └── pvt-ti_c100_ours.yaml
│   ├── pvtv2
│   │   ├── pvtv2-b0_c100_base.yaml
│   │   └── pvtv2-b0_c100_ours.yaml
│   ├── resnet
│   │   ├── r-18_c100.yaml
│   │   └── r-56_c100.yaml
│   └── t2t
│       ├── t2t-14_c100_base.yaml
│       ├── t2t-14_c100_ours.yaml
│       ├── t2t-7_c100_base.yaml
│       └── t2t-7_c100_ours.yaml
├── method.png
├── precompute_feature.py
├── pycls
│   ├── __init__.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── benchmark.py
│   │   ├── builders.py
│   │   ├── checkpoint.py
│   │   ├── config.py
│   │   ├── distributed.py
│   │   ├── io.py
│   │   ├── logging.py
│   │   ├── meters.py
│   │   ├── net.py
│   │   ├── optimizer.py
│   │   ├── timer.py
│   │   └── trainer.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── chaoyang.py
│   │   ├── cifar100.py
│   │   ├── flowers.py
│   │   ├── loader.py
│   │   ├── tiny_imagenet.py
│   │   └── transforms.py
│   └── models
│       ├── __init__.py
│       ├── build.py
│       ├── cnns
│       │   ├── __init__.py
│       │   ├── base.py
│       │   ├── blocks.py
│       │   └── resnet.py
│       ├── distill.py
│       └── transformers
│           ├── __init__.py
│           ├── base.py
│           ├── common.py
│           ├── convit.py
│           ├── cvt.py
│           ├── deit.py
│           ├── pit.py
│           ├── pvt.py
│           ├── pvt_v2.py
│           └── t2t_vit.py
├── requirements.txt
└── run_net.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | 
5 | *.so
6 | 
7 | build/
8 | *.egg-info/
9 | *.egg
10 | 
11 | *.swn
12 | *.swo
13 | *.swp
14 | 
15 | .idea/
16 | 
17 | .DS_STORE
18 | 
19 | /.vscode/
20 | /data/
21 | /work_dirs/
22 | /temp/
23 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 | 
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 | 1. Definitions.
8 | 
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 | 
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 | 
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 | 
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Locality Guidance for Improving Vision Transformers on Tiny Datasets (ECCV 2022)
2 | 
3 | [[arXiv paper](https://arxiv.org/pdf/2207.10026.pdf)] [[ECCV paper]()]
4 | 
5 | ![method](method.png)
6 | 
7 | ## Description
8 | 
9 | This is a PyTorch implementation of the paper "Locality Guidance for Improving Vision Transformers on Tiny Datasets", supporting different Transformer models (including DeiT, T2T-ViT, PiT, PVT, PVTv2, ConViT, CvT) and different classification datasets (including CIFAR-100, Oxford Flowers, Tiny ImageNet, Chaoyang).
10 | 
11 | ## Abstract
12 | 
13 | While the Vision Transformer (VT) architecture is becoming trendy in computer vision, pure VT models perform poorly on tiny datasets. To address this issue, this paper proposes locality guidance for improving the performance of VTs on tiny datasets. We first analyze that local information, which is of great importance for understanding images, is hard to learn with limited data due to the high flexibility and intrinsic globality of the self-attention mechanism in VTs. To facilitate the learning of local information, we realize locality guidance for VTs by imitating the features of an already trained convolutional neural network (CNN), inspired by the built-in local-to-global hierarchy of CNNs. Under our dual-task learning paradigm, the locality guidance provided by a lightweight CNN trained on low-resolution images is adequate to accelerate the convergence and improve the performance of VTs to a large extent. Therefore, our locality guidance approach is very simple and efficient, and can serve as a basic performance enhancement method for VTs on tiny datasets. Extensive experiments demonstrate that our method can significantly improve VTs when training from scratch on tiny datasets and is compatible with different kinds of VTs and datasets. For example, our proposed method can boost the performance of various VTs on tiny datasets (e.g., 13.07% for DeiT, 8.98% for T2T and 7.85% for PVT), and enhance the even stronger baseline PVTv2 by 1.86% to 79.30%, showing the potential of VTs on tiny datasets.
14 | 
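The following is an editor's sketch of the dual-task objective described above, added for illustration only. It is **not** the repository's exact code (the real implementation lives in `pycls/models/distill.py` and is driven by the `DISTILLATION.*` options); in particular, `dual_task_loss` is a hypothetical helper, and the token-to-feature-map reshaping and the exact form of the linear projection are assumptions.

```python
import torch.nn.functional as F


def dual_task_loss(logits, labels, student_feats, teacher_feats, projections,
                   inter_weight=2.5):
    """Cross-entropy plus locality guidance from a frozen, pre-trained CNN.

    student_feats: selected VT block outputs reshaped to (N, C, H, W) maps,
                   e.g. DeiT blocks [0, 6, 11] (INTER_STUDENT_INDEX).
    teacher_feats: the matching CNN stage outputs, e.g. ResNet-56 stages
                   [0, 1, 2] (INTER_TEACHER_INDEX).
    projections:   per-pair 1x1 convolutions mapping student channels to
                   teacher channels (INTER_TRANSFORM: linear).
    """
    loss_cls = F.cross_entropy(logits, labels)
    loss_inter = 0.0
    for s, t, proj in zip(student_feats, teacher_feats, projections):
        s = proj(s)  # project student channels to the teacher's width
        # match the teacher's (low-resolution) spatial size before comparing
        s = F.interpolate(s, size=t.shape[-2:], mode='bilinear', align_corners=False)
        loss_inter = loss_inter + F.mse_loss(s, t.detach())  # INTER_LOSS: l2
    return loss_cls + inter_weight * loss_inter  # INTER_WEIGHT: 2.5 in the configs
```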
15 | ## Usage
16 | 
17 | ### Dependencies
18 | 
19 | The base environment we used for experiments is:
20 | 
21 | - python = 3.8.12
22 | - pytorch = 1.8.0
23 | - cudatoolkit = 10.1
24 | 
25 | Other dependencies can be installed by:
26 | 
27 | ```shell
28 | pip install -r requirements.txt
29 | ```
30 | 
31 | ### Data Preparation
32 | 
33 | **Step 1:** download datasets from their official websites:
34 | 
35 | - [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html)
36 | - [Oxford Flowers](https://www.robots.ox.ac.uk/~vgg/data/flowers/102/)
37 | - [Tiny ImageNet](https://www.kaggle.com/c/tiny-imagenet)
38 | - [Chaoyang](https://bupt-ai-cz.github.io/HSA-NRL/)
39 | 
40 | **Step 2:** move or link the datasets to the `data/` directory. The expected layout of `data/` is as follows:
41 | 
42 | ```
43 | data
44 | ├── cifar-100-python
45 | │   ├── meta
46 | │   ├── test
47 | │   └── train
48 | ├── flowers
49 | │   ├── jpg
50 | │   ├── imagelabels.mat
51 | │   └── setid.mat
52 | ├── tiny-imagenet-200
53 | │   ├── train
54 | │   │   ├── n01443537
55 | │   │   └── ...
56 | │   └── val
57 | │       ├── images
58 | │       └── val_annotations.txt
59 | └── chaoyang
60 |     ├── test
61 |     ├── train
62 |     ├── test.json
63 |     └── train.json
64 | ```
65 | 
66 | ### Train from Scratch
67 | 
68 | For example, you can train DeiT-Tiny from scratch using:
69 | 
70 | ```shell
71 | python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml
72 | ```
73 | 
74 | Besides, we provide configurations for different models and different datasets under `configs/`.
75 | 
76 | ### Train with Locality Guidance
77 | 
78 | **Step 1:** train the CNN guidance model (e.g., ResNet-56). This step takes only a little time and needs to be executed only once for each dataset.
79 | 
80 | ```shell
81 | python run_net.py --mode train --cfg configs/resnet/r-56_c100.yaml
82 | ```
83 | 
84 | **Step 2:** train the target VT.
85 | 
86 | ```shell
87 | python run_net.py --mode train --cfg configs/deit/deit-ti_c100_ours.yaml
88 | ```
89 | 
90 | As mentioned in the supplementary materials, the locality guidance can also be executed offline using pre-computed features. To run in this setting, you can use:
91 | 
92 | ```shell
93 | # Pre-compute features
94 | python precompute_feature.py --cfg configs/resnet/r-56_c100.yaml --ckpt work_dirs/r-56_c100/model.pyth
95 | # Train the model
96 | python run_net.py --mode train --cfg configs/deit/deit-ti_c100_ours_offline.yaml
97 | ```
98 | 
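To sanity-check the pre-computed features before training, you can inspect the saved archive. A minimal sketch, assuming the default output path from `precompute_feature.py` (which is also the `DISTILLATION.FEATURE_FILE` used by the offline config):

```python
import numpy as np

# one array per guided CNN stage, stacked over the whole training set
feats = np.load('temp/r-56_c100.npz')
for name in feats.files:  # keys are 'layer_0', 'layer_1', ...
    print(name, feats[name].shape)  # (num_images, C, H, W)
```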
99 | ### Multi-GPU & Mixed Precision Support
100 | 
101 | Only one argument needs to be added for multi-GPU or mixed-precision training, for example:
102 | 
103 | ```shell
104 | # Train DeiT from scratch with 2 GPUs
105 | python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml NUM_GPUS 2
106 | 
107 | # Train DeiT from scratch with 2 GPUs using mixed precision
108 | python run_net.py --mode train --cfg configs/deit/deit-ti_c100_base.yaml NUM_GPUS 2 TRAIN.MIXED_PRECISION True
109 | ```
110 | 
111 | ### Test
112 | 
113 | ```shell
114 | python run_net.py --mode test --cfg configs/deit/deit-ti_c100_base.yaml TEST.WEIGHTS /path/to/model.pyth
115 | ```
116 | 
117 | ## Results
118 | 
119 | | Model | Top-1 Acc. (Base) | Top-1 Acc. (Ours) |
120 | | :---------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
121 | | DeiT-Tiny | 65.08 ( [weights](https://drive.google.com/file/d/1UpnIPvcTWrBZ2FYCYYY4FkTK4LhXazUY/view?usp=sharing) \| [log](https://drive.google.com/file/d/1uAIoYeNPOIE141AO-95JnKUZqKPgtz3C/view?usp=sharing) ) | 78.15 ( [weights](https://drive.google.com/file/d/1vo8jugJkgxmgFtiS4V1tIKAfmg5jdh0D/view?usp=sharing) \| [log](https://drive.google.com/file/d/1agOqk8eIGK3_XqfNnLPKOKwDbKBeqffu/view?usp=sharing) ) |
122 | | T2T-ViT-7 | 69.37 ( [weights](https://drive.google.com/file/d/1walDSuqyy2zfQv55NuG9a8Eq5d3GlRuf/view?usp=sharing) \| [log](https://drive.google.com/file/d/17xsso8wUlt-cf_-oZavTn9i-c-pTMhUW/view?usp=sharing) ) | 78.35 ( [weights](https://drive.google.com/file/d/1wD3wQ13O7otXjRo-4dC9DHg_HdLoUTVT/view?usp=sharing) \| [log](https://drive.google.com/file/d/1SNILqkf18lX-qcKdkg200ZBYB3N-bOue/view?usp=sharing) ) |
123 | | PiT-Tiny | 73.58 ( [weights](https://drive.google.com/file/d/1bTG9W0Kf-xNJSA35xv-Wmiw6G1Bfts3m/view?usp=sharing) \| [log](https://drive.google.com/file/d/1qhRMRp-AqBSFLvspHEsM06ANf8p6STox/view?usp=sharing) ) | 78.48 ( [weights](https://drive.google.com/file/d/14dPs5CzhVKqTwuwK3n75C-SWiWa3IQ6A/view?usp=sharing) \| [log](https://drive.google.com/file/d/1zYK9i9YN2mV9GMM02nbPRMOOGwqvehJg/view?usp=sharing) ) |
124 | | PVT-Tiny | 69.22 ( [weights](https://drive.google.com/file/d/18BbtQ3XF-_tzOB9BNbu04C-KDsHhrqmM/view?usp=sharing) \| [log](https://drive.google.com/file/d/1Qb3sOi0AuXl726hqxXCZSI7i-qH8_1YL/view?usp=sharing) ) | 77.07 ( [weights](https://drive.google.com/file/d/1rDFwcz3s1Irxk3FE4OhHks7qlzmoxM-w/view?usp=sharing) \| [log](https://drive.google.com/file/d/1FJ5ajTGN6zr0Eo12B8gW4XJ2FUIMSNoT/view?usp=sharing) ) |
125 | | PVTv2-B0 | 77.44 ( [weights](https://drive.google.com/file/d/1Aum9nL7IBFFan0Atkc9EKHKnv2LLfAAm/view?usp=sharing) \| [log](https://drive.google.com/file/d/1GNOdB2A2PHcMOsuJ7lTbE7kGCsEZZl3L/view?usp=sharing) ) | 79.30 ( [weights](https://drive.google.com/file/d/1a-ZAaPPDt9F_V4pabTGix0-HixIy1kE7/view?usp=sharing) \| [log](https://drive.google.com/file/d/1v38v0QhadSbrZmCfXH_kDi9W_Z5fjGqF/view?usp=sharing) ) |
126 | | ConViT-Tiny | 75.32 ( [weights](https://drive.google.com/file/d/1uAta933oxj45w9E_OIuxFnXbvrycmpHs/view?usp=sharing) \| [log](https://drive.google.com/file/d/1m79stHRfogaASovSoTZf1w_g6dXgBTQE/view?usp=sharing) ) | 78.95 ( [weights](https://drive.google.com/file/d/1nQHEKMQJDfw2TBT-dZ3mtdI1ozdnUzG2/view?usp=sharing) \| [log](https://drive.google.com/file/d/1wQMBcBL0FouIOD19PXyN0_OJVyqbc5rw/view?usp=sharing) ) |
127 | 
128 | Here we provide the pre-trained models and training logs (which can be viewed via TensorBoard).
129 | 
130 | ## Acknowledgement
131 | 
132 | This repository is built upon [pycls](https://github.com/facebookresearch/pycls) and the official implementations of [DeiT](https://github.com/facebookresearch/deit), [T2T-ViT](https://github.com/yitu-opensource/T2T-ViT), [PiT](https://github.com/naver-ai/pit), [PVTv1/v2](https://github.com/whai362/PVT), [ConViT](https://github.com/facebookresearch/convit) and [CvT](https://github.com/microsoft/CvT). We would like to thank the authors of these open-source repositories.
133 | 134 | ## Citing 135 | 136 | ``` 137 | @article{li2022locality, 138 | title={Locality Guidance for Improving Vision Transformers on Tiny Datasets}, 139 | author={Li, Kehan and Yu, Runyi and Wang, Zhennan and Yuan, Li and Song, Guoli and Chen, Jie}, 140 | journal={arXiv preprint arXiv:2207.10026}, 141 | year={2022} 142 | } 143 | ``` 144 | 145 | -------------------------------------------------------------------------------- /configs/convit/convit-ti_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: ConViT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 192 7 | DEPTH: 12 8 | NUM_HEADS: 4 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | CONVIT: 14 | LOCAL_LAYERS: 10 15 | LOCALITY_STRENGTH: 1.0 16 | OPTIM: 17 | OPTIMIZER: adamw 18 | BASE_LR: 5.0e-4 19 | MIN_LR: 5.0e-6 20 | LR_POLICY: cos 21 | MAX_EPOCH: 300 22 | WEIGHT_DECAY: 0.05 23 | WARMUP_FACTOR: 0.001 24 | WARMUP_EPOCHS: 20 25 | TRAIN: 26 | DATASET: cifar100 27 | SPLIT: train 28 | BATCH_SIZE: 128 29 | TEST: 30 | DATASET: cifar100 31 | SPLIT: test 32 | BATCH_SIZE: 200 33 | NUM_GPUS: 1 34 | DATA_LOADER: 35 | NUM_WORKERS: 4 36 | CUDNN: 37 | BENCHMARK: False 38 | -------------------------------------------------------------------------------- /configs/convit/convit-ti_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: ConViT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 192 7 | DEPTH: 12 8 | NUM_HEADS: 4 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | CONVIT: 14 | LOCAL_LAYERS: 10 15 | LOCALITY_STRENGTH: 1.0 16 | CNN: 17 | DEPTH: 56 18 | RESNET: 19 | TRANS_FUN: basic_transform 20 | DISTILLATION: 21 | ENABLE_INTER: True 22 | INTER_TRANSFORM: linear 23 | INTER_TEACHER_INDEX: [0, 1, 2] 24 | INTER_STUDENT_INDEX: [0, 6, 11] 25 | INTER_WEIGHT: 2.5 26 | TEACHER_MODEL: ResNet 27 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 28 | TEACHER_IMG_SIZE: 32 29 | OPTIM: 30 | OPTIMIZER: adamw 31 | BASE_LR: 5.0e-4 32 | MIN_LR: 5.0e-6 33 | LR_POLICY: cos 34 | MAX_EPOCH: 300 35 | WEIGHT_DECAY: 0.05 36 | WARMUP_FACTOR: 0.001 37 | WARMUP_EPOCHS: 20 38 | TRAIN: 39 | DATASET: cifar100 40 | SPLIT: train 41 | BATCH_SIZE: 128 42 | TEST: 43 | DATASET: cifar100 44 | SPLIT: test 45 | BATCH_SIZE: 200 46 | NUM_GPUS: 1 47 | DATA_LOADER: 48 | NUM_WORKERS: 4 49 | CUDNN: 50 | BENCHMARK: False 51 | -------------------------------------------------------------------------------- /configs/cvt/cvt-13_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: CvT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: [7, 3, 3] 6 | PATCH_STRIDE: [4, 2, 2] 7 | PATCH_PADDING: [2, 1, 1] 8 | HIDDEN_DIM: [64, 192, 384] 9 | DEPTH: [1, 2, 10] 10 | NUM_HEADS: [1, 3, 6] 11 | MLP_RATIO: [4, 4, 4] 12 | LN_EPS: 1.0e-5 13 | DROP_RATE: [0.0, 0.0, 0.0] 14 | DROP_PATH_RATE: [0.0, 0.0, 0.1] 15 | ATTENTION_DROP_RATE: [0.0, 0.0, 0.0] 16 | CVT: 17 | WITH_CLS_TOKEN: [False, False, True] 18 | QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn'] 19 | KERNEL_QKV: [3, 3, 3] 20 | STRIDE_KV: [2, 2, 2] 21 | STRIDE_Q: [1, 1, 1] 22 | PADDING_KV: [1, 1, 1] 23 | PADDING_Q: [1, 1, 1] 24 | OPTIM: 25 | OPTIMIZER: adamw 26 | BASE_LR: 5.0e-4 27 | MIN_LR: 5.0e-6 28 | LR_POLICY: cos 29 | MAX_EPOCH: 300 30 | WEIGHT_DECAY: 0.05 31 | WARMUP_FACTOR: 0.001 32 | WARMUP_EPOCHS: 20 33 | TRAIN: 
34 | DATASET: cifar100 35 | SPLIT: train 36 | BATCH_SIZE: 128 37 | TEST: 38 | DATASET: cifar100 39 | SPLIT: test 40 | BATCH_SIZE: 200 41 | NUM_GPUS: 1 42 | DATA_LOADER: 43 | NUM_WORKERS: 4 44 | CUDNN: 45 | BENCHMARK: False 46 | -------------------------------------------------------------------------------- /configs/cvt/cvt-13_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: CvT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: [7, 3, 3] 6 | PATCH_STRIDE: [4, 2, 2] 7 | PATCH_PADDING: [2, 1, 1] 8 | HIDDEN_DIM: [64, 192, 384] 9 | DEPTH: [1, 2, 10] 10 | NUM_HEADS: [1, 3, 6] 11 | MLP_RATIO: [4, 4, 4] 12 | LN_EPS: 1.0e-5 13 | DROP_RATE: [0.0, 0.0, 0.0] 14 | DROP_PATH_RATE: [0.0, 0.0, 0.1] 15 | ATTENTION_DROP_RATE: [0.0, 0.0, 0.0] 16 | CVT: 17 | WITH_CLS_TOKEN: [False, False, True] 18 | QKV_PROJ_METHOD: ['dw_bn', 'dw_bn', 'dw_bn'] 19 | KERNEL_QKV: [3, 3, 3] 20 | STRIDE_KV: [2, 2, 2] 21 | STRIDE_Q: [1, 1, 1] 22 | PADDING_KV: [1, 1, 1] 23 | PADDING_Q: [1, 1, 1] 24 | CNN: 25 | DEPTH: 56 26 | RESNET: 27 | TRANS_FUN: basic_transform 28 | DISTILLATION: 29 | ENABLE_INTER: True 30 | INTER_TRANSFORM: linear 31 | INTER_TEACHER_INDEX: [0, 1, 2] 32 | INTER_STUDENT_INDEX: [0, 6, 11] 33 | INTER_WEIGHT: 2.5 34 | TEACHER_MODEL: ResNet 35 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 36 | TEACHER_IMG_SIZE: 32 37 | OPTIM: 38 | OPTIMIZER: adamw 39 | BASE_LR: 5.0e-4 40 | MIN_LR: 5.0e-6 41 | LR_POLICY: cos 42 | MAX_EPOCH: 300 43 | WEIGHT_DECAY: 0.05 44 | WARMUP_FACTOR: 0.001 45 | WARMUP_EPOCHS: 20 46 | TRAIN: 47 | DATASET: cifar100 48 | SPLIT: train 49 | BATCH_SIZE: 128 50 | TEST: 51 | DATASET: cifar100 52 | SPLIT: test 53 | BATCH_SIZE: 200 54 | NUM_GPUS: 1 55 | DATA_LOADER: 56 | NUM_WORKERS: 4 57 | CUDNN: 58 | BENCHMARK: False 59 | -------------------------------------------------------------------------------- /configs/deit/deit-ti_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: DeiT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 192 7 | DEPTH: 12 8 | NUM_HEADS: 3 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | OPTIM: 14 | OPTIMIZER: adamw 15 | BASE_LR: 5.0e-4 16 | MIN_LR: 5.0e-6 17 | LR_POLICY: cos 18 | MAX_EPOCH: 300 19 | WEIGHT_DECAY: 0.05 20 | WARMUP_FACTOR: 0.001 21 | WARMUP_EPOCHS: 20 22 | TRAIN: 23 | DATASET: cifar100 24 | SPLIT: train 25 | BATCH_SIZE: 128 26 | TEST: 27 | DATASET: cifar100 28 | SPLIT: test 29 | BATCH_SIZE: 200 30 | NUM_GPUS: 1 31 | DATA_LOADER: 32 | NUM_WORKERS: 4 33 | CUDNN: 34 | BENCHMARK: False 35 | -------------------------------------------------------------------------------- /configs/deit/deit-ti_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: DeiT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 192 7 | DEPTH: 12 8 | NUM_HEADS: 3 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | CNN: 14 | DEPTH: 56 15 | RESNET: 16 | TRANS_FUN: basic_transform 17 | DISTILLATION: 18 | ENABLE_INTER: True 19 | INTER_TRANSFORM: linear 20 | INTER_TEACHER_INDEX: [0, 1, 2] 21 | INTER_STUDENT_INDEX: [0, 6, 11] 22 | INTER_WEIGHT: 2.5 23 | TEACHER_MODEL: ResNet 24 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 25 | TEACHER_IMG_SIZE: 32 26 | OPTIM: 27 | OPTIMIZER: adamw 28 | BASE_LR: 5.0e-4 29 | MIN_LR: 5.0e-6 30 | LR_POLICY: cos 31 | MAX_EPOCH: 300 32 
| WEIGHT_DECAY: 0.05 33 | WARMUP_FACTOR: 0.001 34 | WARMUP_EPOCHS: 20 35 | TRAIN: 36 | DATASET: cifar100 37 | SPLIT: train 38 | BATCH_SIZE: 128 39 | TEST: 40 | DATASET: cifar100 41 | SPLIT: test 42 | BATCH_SIZE: 200 43 | NUM_GPUS: 1 44 | DATA_LOADER: 45 | NUM_WORKERS: 4 46 | CUDNN: 47 | BENCHMARK: False 48 | -------------------------------------------------------------------------------- /configs/deit/deit-ti_c100_ours_offline.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: DeiT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 192 7 | DEPTH: 12 8 | NUM_HEADS: 3 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | CNN: 14 | DEPTH: 56 15 | RESNET: 16 | TRANS_FUN: basic_transform 17 | DISTILLATION: 18 | ENABLE_INTER: True 19 | INTER_TRANSFORM: linear 20 | INTER_TEACHER_INDEX: [0, 1, 2] 21 | INTER_STUDENT_INDEX: [0, 6, 11] 22 | INTER_WEIGHT: 2.5 23 | TEACHER_MODEL: ResNet 24 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 25 | TEACHER_IMG_SIZE: 32 26 | OFFLINE: True 27 | FEATURE_FILE: temp/r-56_c100.npz 28 | OPTIM: 29 | OPTIMIZER: adamw 30 | BASE_LR: 5.0e-4 31 | MIN_LR: 5.0e-6 32 | LR_POLICY: cos 33 | MAX_EPOCH: 300 34 | WEIGHT_DECAY: 0.05 35 | WARMUP_FACTOR: 0.001 36 | WARMUP_EPOCHS: 20 37 | TRAIN: 38 | DATASET: cifar100 39 | SPLIT: train 40 | BATCH_SIZE: 128 41 | TEST: 42 | DATASET: cifar100 43 | SPLIT: test 44 | BATCH_SIZE: 200 45 | NUM_GPUS: 1 46 | DATA_LOADER: 47 | NUM_WORKERS: 4 48 | CUDNN: 49 | BENCHMARK: False 50 | -------------------------------------------------------------------------------- /configs/pit/pit-ti_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: PiT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: [64, 128, 256] 7 | DEPTH: [2, 6, 4] 8 | NUM_HEADS: [2, 4, 8] 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | PIT: 14 | STRIDE: 8 15 | OPTIM: 16 | OPTIMIZER: adamw 17 | BASE_LR: 5.0e-4 18 | MIN_LR: 5.0e-6 19 | LR_POLICY: cos 20 | MAX_EPOCH: 300 21 | WEIGHT_DECAY: 0.05 22 | WARMUP_FACTOR: 0.001 23 | WARMUP_EPOCHS: 20 24 | TRAIN: 25 | DATASET: cifar100 26 | SPLIT: train 27 | BATCH_SIZE: 128 28 | TEST: 29 | DATASET: cifar100 30 | SPLIT: test 31 | BATCH_SIZE: 200 32 | NUM_GPUS: 1 33 | DATA_LOADER: 34 | NUM_WORKERS: 4 35 | CUDNN: 36 | BENCHMARK: False 37 | -------------------------------------------------------------------------------- /configs/pit/pit-ti_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: PiT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: [64, 128, 256] 7 | DEPTH: [2, 6, 4] 8 | NUM_HEADS: [2, 4, 8] 9 | MLP_RATIO: 4 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | PIT: 14 | STRIDE: 8 15 | CNN: 16 | DEPTH: 56 17 | RESNET: 18 | TRANS_FUN: basic_transform 19 | DISTILLATION: 20 | ENABLE_INTER: True 21 | INTER_TRANSFORM: linear 22 | INTER_TEACHER_INDEX: [0, 1, 2] 23 | INTER_STUDENT_INDEX: [0, 6, 11] 24 | INTER_WEIGHT: 2.5 25 | TEACHER_MODEL: ResNet 26 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 27 | TEACHER_IMG_SIZE: 32 28 | OPTIM: 29 | OPTIMIZER: adamw 30 | BASE_LR: 5.0e-4 31 | MIN_LR: 5.0e-6 32 | LR_POLICY: cos 33 | MAX_EPOCH: 300 34 | WEIGHT_DECAY: 0.05 35 | WARMUP_FACTOR: 0.001 36 | WARMUP_EPOCHS: 20 37 | TRAIN: 38 | DATASET: cifar100 39 | SPLIT: train 40 | BATCH_SIZE: 128 
41 | TEST: 42 | DATASET: cifar100 43 | SPLIT: test 44 | BATCH_SIZE: 200 45 | NUM_GPUS: 1 46 | DATA_LOADER: 47 | NUM_WORKERS: 4 48 | CUDNN: 49 | BENCHMARK: False 50 | -------------------------------------------------------------------------------- /configs/pvt/pvt-ti_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: PVT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: [4, 2, 2, 2] 6 | HIDDEN_DIM: [64, 128, 320, 512] 7 | DEPTH: [2, 2, 2, 2] 8 | NUM_HEADS: [1, 2, 5, 8] 9 | MLP_RATIO: [8, 8, 4, 4] 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | PVT: 14 | SR_RATIO: [8, 4, 2, 1] 15 | OPTIM: 16 | OPTIMIZER: adamw 17 | BASE_LR: 5.0e-4 18 | MIN_LR: 5.0e-6 19 | LR_POLICY: cos 20 | MAX_EPOCH: 300 21 | WEIGHT_DECAY: 0.05 22 | WARMUP_FACTOR: 0.001 23 | WARMUP_EPOCHS: 20 24 | TRAIN: 25 | DATASET: cifar100 26 | SPLIT: train 27 | BATCH_SIZE: 128 28 | TEST: 29 | DATASET: cifar100 30 | SPLIT: test 31 | BATCH_SIZE: 200 32 | NUM_GPUS: 1 33 | DATA_LOADER: 34 | NUM_WORKERS: 4 35 | CUDNN: 36 | BENCHMARK: False 37 | -------------------------------------------------------------------------------- /configs/pvt/pvt-ti_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: PVT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: [4, 2, 2, 2] 6 | HIDDEN_DIM: [64, 128, 320, 512] 7 | DEPTH: [2, 2, 2, 2] 8 | NUM_HEADS: [1, 2, 5, 8] 9 | MLP_RATIO: [8, 8, 4, 4] 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | PVT: 14 | SR_RATIO: [8, 4, 2, 1] 15 | CNN: 16 | DEPTH: 56 17 | RESNET: 18 | TRANS_FUN: basic_transform 19 | DISTILLATION: 20 | ENABLE_INTER: True 21 | INTER_TRANSFORM: linear 22 | INTER_TEACHER_INDEX: [0, 1, 2] 23 | INTER_STUDENT_INDEX: [0, 3, 7] 24 | INTER_WEIGHT: 2.5 25 | TEACHER_MODEL: ResNet 26 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 27 | TEACHER_IMG_SIZE: 32 28 | OPTIM: 29 | OPTIMIZER: adamw 30 | BASE_LR: 5.0e-4 31 | MIN_LR: 5.0e-6 32 | LR_POLICY: cos 33 | MAX_EPOCH: 300 34 | WEIGHT_DECAY: 0.05 35 | WARMUP_FACTOR: 0.001 36 | WARMUP_EPOCHS: 20 37 | TRAIN: 38 | DATASET: cifar100 39 | SPLIT: train 40 | BATCH_SIZE: 128 41 | TEST: 42 | DATASET: cifar100 43 | SPLIT: test 44 | BATCH_SIZE: 200 45 | NUM_GPUS: 1 46 | DATA_LOADER: 47 | NUM_WORKERS: 4 48 | CUDNN: 49 | BENCHMARK: False 50 | -------------------------------------------------------------------------------- /configs/pvtv2/pvtv2-b0_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: PVTv2 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: [7, 3, 3, 3] 6 | PATCH_STRIDE: [4, 2, 2, 2] 7 | HIDDEN_DIM: [32, 64, 160, 256] 8 | DEPTH: [2, 2, 2, 2] 9 | NUM_HEADS: [1, 2, 5, 8] 10 | MLP_RATIO: [8, 8, 4, 4] 11 | DROP_RATE: 0.0 12 | DROP_PATH_RATE: 0.1 13 | ATTENTION_DROP_RATE: 0.0 14 | PVT: 15 | SR_RATIO: [8, 4, 2, 1] 16 | OPTIM: 17 | OPTIMIZER: adamw 18 | BASE_LR: 5.0e-4 19 | MIN_LR: 5.0e-6 20 | LR_POLICY: cos 21 | MAX_EPOCH: 300 22 | WEIGHT_DECAY: 0.05 23 | WARMUP_FACTOR: 0.001 24 | WARMUP_EPOCHS: 20 25 | TRAIN: 26 | DATASET: cifar100 27 | SPLIT: train 28 | BATCH_SIZE: 128 29 | TEST: 30 | DATASET: cifar100 31 | SPLIT: test 32 | BATCH_SIZE: 200 33 | NUM_GPUS: 1 34 | DATA_LOADER: 35 | NUM_WORKERS: 4 36 | CUDNN: 37 | BENCHMARK: False 38 | -------------------------------------------------------------------------------- /configs/pvtv2/pvtv2-b0_c100_ours.yaml: 
-------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: PVTv2 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: [7, 3, 3, 3] 6 | PATCH_STRIDE: [4, 2, 2, 2] 7 | HIDDEN_DIM: [32, 64, 160, 256] 8 | DEPTH: [2, 2, 2, 2] 9 | NUM_HEADS: [1, 2, 5, 8] 10 | MLP_RATIO: [8, 8, 4, 4] 11 | DROP_RATE: 0.0 12 | DROP_PATH_RATE: 0.1 13 | ATTENTION_DROP_RATE: 0.0 14 | PVT: 15 | SR_RATIO: [8, 4, 2, 1] 16 | CNN: 17 | DEPTH: 56 18 | RESNET: 19 | TRANS_FUN: basic_transform 20 | DISTILLATION: 21 | ENABLE_INTER: True 22 | INTER_TRANSFORM: linear 23 | INTER_TEACHER_INDEX: [0, 1, 2] 24 | INTER_STUDENT_INDEX: [0, 3, 7] 25 | INTER_WEIGHT: 2.5 26 | TEACHER_MODEL: ResNet 27 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 28 | TEACHER_IMG_SIZE: 32 29 | OPTIM: 30 | OPTIMIZER: adamw 31 | BASE_LR: 5.0e-4 32 | MIN_LR: 5.0e-6 33 | LR_POLICY: cos 34 | MAX_EPOCH: 300 35 | WEIGHT_DECAY: 0.05 36 | WARMUP_FACTOR: 0.001 37 | WARMUP_EPOCHS: 20 38 | TRAIN: 39 | DATASET: cifar100 40 | SPLIT: train 41 | BATCH_SIZE: 128 42 | TEST: 43 | DATASET: cifar100 44 | SPLIT: test 45 | BATCH_SIZE: 200 46 | NUM_GPUS: 1 47 | DATA_LOADER: 48 | NUM_WORKERS: 4 49 | CUDNN: 50 | BENCHMARK: False 51 | -------------------------------------------------------------------------------- /configs/resnet/r-18_c100.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: ResNet 3 | IMG_SIZE: 224 4 | NUM_CLASSES: 100 5 | CNN: 6 | DEPTH: 18 7 | RESNET: 8 | TRANS_FUN: basic_transform 9 | OPTIM: 10 | BASE_LR: 0.1 11 | LR_POLICY: cos 12 | MAX_EPOCH: 300 13 | MOMENTUM: 0.9 14 | NESTEROV: True 15 | WEIGHT_DECAY: 0.0005 16 | TRAIN: 17 | DATASET: cifar100 18 | SPLIT: train 19 | BATCH_SIZE: 128 20 | TEST: 21 | DATASET: cifar100 22 | SPLIT: test 23 | BATCH_SIZE: 200 24 | NUM_GPUS: 1 25 | DATA_LOADER: 26 | NUM_WORKERS: 4 27 | CUDNN: 28 | BENCHMARK: False 29 | -------------------------------------------------------------------------------- /configs/resnet/r-56_c100.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: ResNet 3 | IMG_SIZE: 32 4 | NUM_CLASSES: 100 5 | CNN: 6 | DEPTH: 56 7 | RESNET: 8 | TRANS_FUN: basic_transform 9 | OPTIM: 10 | BASE_LR: 0.1 11 | LR_POLICY: cos 12 | MAX_EPOCH: 300 13 | MOMENTUM: 0.9 14 | NESTEROV: True 15 | WEIGHT_DECAY: 0.0005 16 | TRAIN: 17 | DATASET: cifar100 18 | SPLIT: train 19 | BATCH_SIZE: 128 20 | TEST: 21 | DATASET: cifar100 22 | SPLIT: test 23 | BATCH_SIZE: 200 24 | NUM_GPUS: 1 25 | DATA_LOADER: 26 | NUM_WORKERS: 4 27 | CUDNN: 28 | BENCHMARK: False 29 | -------------------------------------------------------------------------------- /configs/t2t/t2t-14_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: T2TViT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 384 7 | DEPTH: 14 8 | NUM_HEADS: 6 9 | MLP_RATIO: 3 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | T2T: 14 | TOKEN_DIM: 64 15 | KERNEL_SIZE: [7, 3, 3] 16 | STRIDE: [4, 2, 2] 17 | PADDING: [2, 1, 1] 18 | OPTIM: 19 | OPTIMIZER: adamw 20 | BASE_LR: 5.0e-4 21 | MIN_LR: 5.0e-6 22 | LR_POLICY: cos 23 | MAX_EPOCH: 300 24 | WEIGHT_DECAY: 0.05 25 | WARMUP_FACTOR: 0.001 26 | WARMUP_EPOCHS: 20 27 | TRAIN: 28 | DATASET: cifar100 29 | SPLIT: train 30 | BATCH_SIZE: 128 31 | TEST: 32 | DATASET: cifar100 33 | SPLIT: test 34 | BATCH_SIZE: 200 35 | NUM_GPUS: 1 36 | DATA_LOADER: 37 | NUM_WORKERS: 4 38 | CUDNN: 39 | BENCHMARK: 
False 40 | -------------------------------------------------------------------------------- /configs/t2t/t2t-14_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: T2TViT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 384 7 | DEPTH: 14 8 | NUM_HEADS: 6 9 | MLP_RATIO: 3 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | T2T: 14 | TOKEN_DIM: 64 15 | KERNEL_SIZE: [7, 3, 3] 16 | STRIDE: [4, 2, 2] 17 | PADDING: [2, 1, 1] 18 | CNN: 19 | DEPTH: 56 20 | RESNET: 21 | TRANS_FUN: basic_transform 22 | DISTILLATION: 23 | ENABLE_INTER: True 24 | INTER_TRANSFORM: linear 25 | INTER_TEACHER_INDEX: [0, 1, 2] 26 | INTER_STUDENT_INDEX: [0, 7, 13] 27 | INTER_WEIGHT: 2.5 28 | TEACHER_MODEL: ResNet 29 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 30 | TEACHER_IMG_SIZE: 32 31 | OPTIM: 32 | OPTIMIZER: adamw 33 | BASE_LR: 5.0e-4 34 | MIN_LR: 5.0e-6 35 | LR_POLICY: cos 36 | MAX_EPOCH: 300 37 | WEIGHT_DECAY: 0.05 38 | WARMUP_FACTOR: 0.001 39 | WARMUP_EPOCHS: 20 40 | TRAIN: 41 | DATASET: cifar100 42 | SPLIT: train 43 | BATCH_SIZE: 128 44 | TEST: 45 | DATASET: cifar100 46 | SPLIT: test 47 | BATCH_SIZE: 200 48 | NUM_GPUS: 1 49 | DATA_LOADER: 50 | NUM_WORKERS: 4 51 | CUDNN: 52 | BENCHMARK: False 53 | -------------------------------------------------------------------------------- /configs/t2t/t2t-7_c100_base.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: T2TViT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 256 7 | DEPTH: 7 8 | NUM_HEADS: 4 9 | MLP_RATIO: 2 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | T2T: 14 | TOKEN_DIM: 64 15 | KERNEL_SIZE: [7, 3, 3] 16 | STRIDE: [4, 2, 2] 17 | PADDING: [2, 1, 1] 18 | OPTIM: 19 | OPTIMIZER: adamw 20 | BASE_LR: 5.0e-4 21 | MIN_LR: 5.0e-6 22 | LR_POLICY: cos 23 | MAX_EPOCH: 300 24 | WEIGHT_DECAY: 0.05 25 | WARMUP_FACTOR: 0.001 26 | WARMUP_EPOCHS: 20 27 | TRAIN: 28 | DATASET: cifar100 29 | SPLIT: train 30 | BATCH_SIZE: 128 31 | TEST: 32 | DATASET: cifar100 33 | SPLIT: test 34 | BATCH_SIZE: 200 35 | NUM_GPUS: 1 36 | DATA_LOADER: 37 | NUM_WORKERS: 4 38 | CUDNN: 39 | BENCHMARK: False 40 | -------------------------------------------------------------------------------- /configs/t2t/t2t-7_c100_ours.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | TYPE: T2TViT 3 | NUM_CLASSES: 100 4 | TRANSFORMER: 5 | PATCH_SIZE: 16 6 | HIDDEN_DIM: 256 7 | DEPTH: 7 8 | NUM_HEADS: 4 9 | MLP_RATIO: 2 10 | DROP_RATE: 0.0 11 | DROP_PATH_RATE: 0.1 12 | ATTENTION_DROP_RATE: 0.0 13 | T2T: 14 | TOKEN_DIM: 64 15 | KERNEL_SIZE: [7, 3, 3] 16 | STRIDE: [4, 2, 2] 17 | PADDING: [2, 1, 1] 18 | CNN: 19 | DEPTH: 56 20 | RESNET: 21 | TRANS_FUN: basic_transform 22 | DISTILLATION: 23 | ENABLE_INTER: True 24 | INTER_TRANSFORM: linear 25 | INTER_TEACHER_INDEX: [0, 1, 2] 26 | INTER_STUDENT_INDEX: [0, 3, 6] 27 | INTER_WEIGHT: 2.5 28 | TEACHER_MODEL: ResNet 29 | TEACHER_WEIGHTS: work_dirs/r-56_c100/model.pyth 30 | TEACHER_IMG_SIZE: 32 31 | OPTIM: 32 | OPTIMIZER: adamw 33 | BASE_LR: 5.0e-4 34 | MIN_LR: 5.0e-6 35 | LR_POLICY: cos 36 | MAX_EPOCH: 300 37 | WEIGHT_DECAY: 0.05 38 | WARMUP_FACTOR: 0.001 39 | WARMUP_EPOCHS: 20 40 | TRAIN: 41 | DATASET: cifar100 42 | SPLIT: train 43 | BATCH_SIZE: 128 44 | TEST: 45 | DATASET: cifar100 46 | SPLIT: test 47 | BATCH_SIZE: 200 48 | NUM_GPUS: 1 49 | DATA_LOADER: 50 | NUM_WORKERS: 4 51 | CUDNN: 52 | 
BENCHMARK: False
53 | 
--------------------------------------------------------------------------------
/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lkhl/tiny-transformers/d2165f74049c906b0afc9f957491960fb3c0cc8b/method.png
--------------------------------------------------------------------------------
/precompute_feature.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pycls.core.config as config
3 | import pycls.core.builders as builders
4 | from pycls.datasets.transforms import create_test_transform
5 | 
6 | import os
7 | import numpy as np
8 | from tqdm import tqdm
9 | from collections import defaultdict
10 | 
11 | import torch
12 | from torch.utils.data import DataLoader
13 | from torchvision.datasets import CIFAR100
14 | from torchvision.transforms import Compose
15 | 
16 | 
17 | def main():
18 |     parser = argparse.ArgumentParser()
19 |     parser.add_argument('--cfg', type=str, default='configs/resnet/r-56_c100.yaml')
20 |     parser.add_argument('--ckpt', type=str, default='work_dirs/r-56_c100/model.pyth')
21 |     args = parser.parse_args()
22 | 
23 |     save_dir = 'temp'
24 |     os.makedirs(save_dir, exist_ok=True)
25 | 
26 |     config.load_cfg(args.cfg)
27 | 
28 |     transforms = create_test_transform()
29 |     transform = Compose(transforms)
30 |     dataset = CIFAR100(root='data', train=True, transform=transform, download=True)
31 |     loader = DataLoader(dataset, batch_size=200)
32 | 
33 |     model = builders.build_model()
34 |     model.load_state_dict(torch.load(args.ckpt, map_location='cpu')['model_state'])
35 |     model.cuda().eval()  # eval mode: keep BN running statistics frozen during feature extraction
36 | 
37 |     feat_dict = defaultdict(list)
38 |     for img, _ in tqdm(loader):
39 |         img = img.cuda()
40 |         with torch.no_grad():
41 |             model(img)
42 | 
43 |         # the CNN caches its stage outputs in model.features during the forward pass
44 |         for i, feat in enumerate(model.features):
45 |             feat = feat.cpu()
46 |             feat_dict[f'layer_{i}'].append(feat)
47 | 
48 |     for k in feat_dict:
49 |         feat_dict[k] = torch.cat(feat_dict[k], dim=0).numpy()
50 | 
51 |     cfg_name = os.path.splitext(os.path.basename(args.cfg))[0]
52 |     np.savez(os.path.join(save_dir, f'{cfg_name}.npz'), **feat_dict)
53 | 
54 | 
55 | if __name__ == '__main__':
56 |     main()
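# Editor's usage note (a sketch; the paths are just the argparse defaults above):
#
#     python precompute_feature.py --cfg configs/resnet/r-56_c100.yaml \
#         --ckpt work_dirs/r-56_c100/model.pyth
#
# writes temp/r-56_c100.npz with one array per cached stage ('layer_0', ...),
# which matches DISTILLATION.FEATURE_FILE in
# configs/deit/deit-ti_c100_ours_offline.yaml.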
57 | 
--------------------------------------------------------------------------------
/pycls/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lkhl/tiny-transformers/d2165f74049c906b0afc9f957491960fb3c0cc8b/pycls/__init__.py
--------------------------------------------------------------------------------
/pycls/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lkhl/tiny-transformers/d2165f74049c906b0afc9f957491960fb3c0cc8b/pycls/core/__init__.py
--------------------------------------------------------------------------------
/pycls/core/benchmark.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | """
9 | Modified from pycls.
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/benchmark.py
11 | """
12 | 
13 | import numpy as np
14 | import pycls.core.logging as logging
15 | import pycls.core.net as net
16 | import pycls.datasets.loader as loader
17 | import torch
18 | import torch.cuda.amp as amp
19 | from pycls.core.config import cfg
20 | from pycls.core.timer import Timer
21 | 
22 | 
23 | logger = logging.get_logger(__name__)
24 | 
25 | 
26 | @torch.no_grad()
27 | def compute_time_eval(model):
28 |     model.eval()
29 |     im_size, batch_size = cfg.MODEL.IMG_SIZE, int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS)
30 |     inputs = torch.zeros(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
31 |     timer = Timer()
32 |     total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
33 |     for cur_iter in range(total_iter):
34 |         if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
35 |             timer.reset()
36 |         timer.tic()
37 |         model(inputs)
38 |         torch.cuda.synchronize()
39 |         timer.toc()
40 |     return timer.average_time
41 | 
42 | 
43 | def compute_time_train(model, loss_fun):
44 |     model.train()
45 |     im_size, batch_size = cfg.MODEL.IMG_SIZE, int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS)
46 |     inputs = torch.rand(batch_size, 3, im_size, im_size).cuda(non_blocking=False)
47 |     labels = torch.zeros(batch_size, dtype=torch.int64).cuda(non_blocking=False)
48 |     labels_one_hot = net.smooth_one_hot_labels(labels)
49 |     offline_features = None
50 |     if hasattr(net.unwrap_model(model), 'guidance_loss') and cfg.DISTILLATION.OFFLINE:
51 |         kd_data = np.load(cfg.DISTILLATION.FEATURE_FILE)
52 |         offline_features = []
53 |         for i in range(len(kd_data.files)):
54 |             feat = torch.from_numpy(kd_data[f'layer_{i}'][0]).cuda(non_blocking=False)
55 |             offline_features.append(feat.unsqueeze(0).repeat(batch_size, 1, 1, 1))
56 |     bns = [m for m in model.modules() if isinstance(m, torch.nn.BatchNorm2d)]
57 |     bn_stats = [[bn.running_mean.clone(), bn.running_var.clone()] for bn in bns]
58 |     scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION)
59 |     fw_timer, bw_timer = Timer(), Timer()
60 |     total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
61 |     for cur_iter in range(total_iter):
62 |         if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
63 |             fw_timer.reset()
64 |             bw_timer.reset()
65 |         fw_timer.tic()
66 |         with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION):
67 |             preds = model(inputs)
68 |             loss_cls = loss_fun(preds, labels_one_hot)
69 |             loss, loss_inter, loss_logit = loss_cls, inputs.new_tensor(0.0), inputs.new_tensor(0.0)
70 |             if hasattr(net.unwrap_model(model), 'guidance_loss'):
71 |                 loss_inter, loss_logit = net.unwrap_model(model).guidance_loss(inputs, offline_features)
72 |                 if cfg.DISTILLATION.ENABLE_LOGIT:
73 |                     loss_cls = loss_cls * (1 - cfg.DISTILLATION.LOGIT_WEIGHT)
74 |                     loss_logit = loss_logit * cfg.DISTILLATION.LOGIT_WEIGHT
75 |                     loss = loss_cls + loss_logit
76 |                 if cfg.DISTILLATION.ENABLE_INTER:
77 |                     loss_inter = loss_inter * cfg.DISTILLATION.INTER_WEIGHT
78 |                     loss = loss + loss_inter  # add to the running total so the logit term survives when both are enabled
79 |         torch.cuda.synchronize()
80 |         fw_timer.toc()
81 |         bw_timer.tic()
82 |         scaler.scale(loss).backward()
83 |         torch.cuda.synchronize()
84 |         bw_timer.toc()
85 |     for bn, (mean, var) in zip(bns, bn_stats):
86 |         bn.running_mean, bn.running_var = mean, var
87 |     return fw_timer.average_time, bw_timer.average_time
88 | 
89 | 
90 | def compute_time_loader(data_loader):
91 |     timer = Timer()
92 |     loader.shuffle(data_loader, 0)
93 |     data_loader_iterator = iter(data_loader)
94 |     total_iter = cfg.PREC_TIME.NUM_ITER + cfg.PREC_TIME.WARMUP_ITER
95 |     total_iter = min(total_iter, len(data_loader))
96 |     for cur_iter in range(total_iter):
97 |         if cur_iter == cfg.PREC_TIME.WARMUP_ITER:
98 |             timer.reset()
99 |         timer.tic()
100 |         next(data_loader_iterator)
101 |         timer.toc()
102 |     return timer.average_time
103 | 
104 | 
105 | def compute_time_model(model, loss_fun):
106 |     logger.info("Computing model timings only...")
107 |     test_fw_time = compute_time_eval(model)
108 |     train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
109 |     train_fw_bw_time = train_fw_time + train_bw_time
110 |     iter_times = {
111 |         "test_fw_time": test_fw_time,
112 |         "train_fw_time": train_fw_time,
113 |         "train_bw_time": train_bw_time,
114 |         "train_fw_bw_time": train_fw_bw_time,
115 |     }
116 |     logger.info(logging.dump_log_data(iter_times, "iter_times"))
117 | 
118 | 
119 | def compute_time_full(model, loss_fun, train_loader, test_loader):
120 |     logger.info("Computing model and loader timings...")
121 |     test_fw_time = compute_time_eval(model)
122 |     train_fw_time, train_bw_time = compute_time_train(model, loss_fun)
123 |     train_fw_bw_time = train_fw_time + train_bw_time
124 |     train_loader_time = compute_time_loader(train_loader)
125 |     iter_times = {
126 |         "test_fw_time": test_fw_time,
127 |         "train_fw_time": train_fw_time,
128 |         "train_bw_time": train_bw_time,
129 |         "train_fw_bw_time": train_fw_bw_time,
130 |         "train_loader_time": train_loader_time,
131 |     }
132 |     logger.info(logging.dump_log_data(iter_times, "iter_times"))
133 |     epoch_times = {
134 |         "test_fw_time": test_fw_time * len(test_loader),
135 |         "train_fw_time": train_fw_time * len(train_loader),
136 |         "train_bw_time": train_bw_time * len(train_loader),
137 |         "train_fw_bw_time": train_fw_bw_time * len(train_loader),
138 |         "train_loader_time": train_loader_time * len(train_loader),
139 |     }
140 |     logger.info(logging.dump_log_data(epoch_times, "epoch_times"))
141 |     overhead = max(0, train_loader_time - train_fw_bw_time) / train_fw_bw_time
142 |     logger.info("Overhead of data loader is {:.2f}%".format(overhead * 100))
143 | 
--------------------------------------------------------------------------------
/pycls/core/builders.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | """
9 | Modified from pycls.
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/builders.py
11 | """
12 | 
13 | from pycls.core.config import cfg
14 | from pycls.core.net import SoftCrossEntropyLoss
15 | from pycls.models.build import build_model
16 | 
17 | 
18 | _loss_funs = {"cross_entropy": SoftCrossEntropyLoss}
19 | 
20 | 
21 | def get_loss_fun():
22 |     err_str = "Loss function type '{}' not supported"
23 |     assert cfg.MODEL.LOSS_FUN in _loss_funs.keys(), err_str.format(cfg.MODEL.LOSS_FUN)
24 |     return _loss_funs[cfg.MODEL.LOSS_FUN]
25 | 
26 | 
27 | def build_loss_fun():
28 |     return get_loss_fun()()
29 | 
30 | 
31 | def register_loss_fun(name, ctor):
32 |     _loss_funs[name] = ctor
33 | 
--------------------------------------------------------------------------------
/pycls/core/checkpoint.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | #
5 | # This source code is licensed under the MIT license found in the
6 | # LICENSE file in the root directory of this source tree.
7 | 
8 | """
9 | Modified from pycls.
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/checkpoint.py 11 | """ 12 | 13 | import os 14 | 15 | import pycls.core.distributed as dist 16 | import torch 17 | from pycls.core.config import cfg 18 | from pycls.core.io import pathmgr 19 | from pycls.core.net import unwrap_model 20 | import pycls.core.logging as logging 21 | 22 | 23 | logger = logging.get_logger(__name__) 24 | 25 | 26 | _NAME_PREFIX = "model_epoch_" 27 | 28 | _DIR_NAME = "checkpoints" 29 | 30 | 31 | def get_checkpoint_dir(): 32 | return os.path.join(cfg.OUT_DIR, _DIR_NAME) 33 | 34 | 35 | def get_checkpoint(epoch): 36 | name = "{}{:04d}.pyth".format(_NAME_PREFIX, epoch) 37 | return os.path.join(get_checkpoint_dir(), name) 38 | 39 | 40 | def get_checkpoint_best(): 41 | return os.path.join(cfg.OUT_DIR, "model.pyth") 42 | 43 | 44 | def get_last_checkpoint(): 45 | checkpoint_dir = get_checkpoint_dir() 46 | checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f] 47 | last_checkpoint_name = sorted(checkpoints)[-1] 48 | return os.path.join(checkpoint_dir, last_checkpoint_name) 49 | 50 | 51 | def has_checkpoint(): 52 | checkpoint_dir = get_checkpoint_dir() 53 | if not pathmgr.exists(checkpoint_dir): 54 | return False 55 | return any(_NAME_PREFIX in f for f in pathmgr.ls(checkpoint_dir)) 56 | 57 | 58 | def save_checkpoint(model, model_ema, optimizer, epoch, test_err, ema_err): 59 | if not dist.is_main_proc(): 60 | return 61 | pathmgr.mkdirs(get_checkpoint_dir()) 62 | checkpoint = { 63 | "epoch": epoch, 64 | "test_err": test_err, 65 | "ema_err": ema_err, 66 | "model_state": unwrap_model(model).state_dict(), 67 | "ema_state": unwrap_model(model_ema).state_dict(), 68 | "optimizer_state": optimizer.state_dict(), 69 | "cfg": cfg.dump(), 70 | } 71 | checkpoint_file = get_checkpoint(epoch + 1) 72 | with pathmgr.open(checkpoint_file, "wb") as f: 73 | torch.save(checkpoint, f) 74 | if not pathmgr.exists(get_checkpoint_best()): 75 | pathmgr.copy(checkpoint_file, get_checkpoint_best()) 76 | else: 77 | with pathmgr.open(get_checkpoint_best(), "rb") as f: 78 | best = torch.load(f, map_location="cpu") 79 | if test_err < best["test_err"] or ema_err < best["ema_err"]: 80 | if test_err < best["test_err"]: 81 | best["model_state"] = checkpoint["model_state"] 82 | best["test_err"] = test_err 83 | if ema_err < best["ema_err"]: 84 | best["ema_state"] = checkpoint["ema_state"] 85 | best["ema_err"] = ema_err 86 | with pathmgr.open(get_checkpoint_best(), "wb") as f: 87 | torch.save(best, f) 88 | return checkpoint_file 89 | 90 | 91 | def load_checkpoint(checkpoint_file, model, model_ema=None, optimizer=None): 92 | err_str = "Checkpoint '{}' not found" 93 | assert pathmgr.exists(checkpoint_file), err_str.format(checkpoint_file) 94 | with pathmgr.open(checkpoint_file, "rb") as f: 95 | checkpoint = torch.load(f, map_location="cpu") 96 | test_err = checkpoint["test_err"] if "test_err" in checkpoint else 100 97 | ema_err = checkpoint["ema_err"] if "ema_err" in checkpoint else 100 98 | ema_state = "ema_state" if "ema_state" in checkpoint else "model_state" 99 | if model_ema: 100 | logger.info(unwrap_model(model).load_state_dict(checkpoint["model_state"], strict=False)) 101 | unwrap_model(model_ema).load_state_dict(checkpoint[ema_state], strict=False) 102 | else: 103 | best_state = "model_state" if test_err <= ema_err else ema_state 104 | logger.info(unwrap_model(model).load_state_dict(checkpoint[best_state], strict=False)) 105 | if optimizer: 106 | optimizer.load_state_dict(checkpoint["optimizer_state"]) 107 | return 
checkpoint["epoch"], test_err, ema_err 108 | 109 | 110 | def delete_checkpoints(checkpoint_dir=None, keep="all"): 111 | assert keep in ["all", "last", "none"], "Invalid keep setting: {}".format(keep) 112 | checkpoint_dir = checkpoint_dir if checkpoint_dir else get_checkpoint_dir() 113 | if keep == "all" or not pathmgr.exists(checkpoint_dir): 114 | return 0 115 | checkpoints = [f for f in pathmgr.ls(checkpoint_dir) if _NAME_PREFIX in f] 116 | checkpoints = sorted(checkpoints)[:-1] if keep == "last" else checkpoints 117 | for checkpoint in checkpoints: 118 | pathmgr.rm(os.path.join(checkpoint_dir, checkpoint)) 119 | return len(checkpoints) 120 | -------------------------------------------------------------------------------- /pycls/core/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/config.py 11 | """ 12 | 13 | import os 14 | 15 | from pycls.core.io import pathmgr 16 | from yacs.config import CfgNode 17 | 18 | 19 | # Global config object (example usage: from core.config import cfg) 20 | _C = CfgNode() 21 | cfg = _C 22 | 23 | 24 | # -------------------------- Knowledge distillation options -------------------------- # 25 | _C.DISTILLATION = CfgNode() 26 | 27 | # Intermediate layers distillation options 28 | _C.DISTILLATION.ENABLE_INTER = False 29 | _C.DISTILLATION.INTER_TRANSFORM = "linear" 30 | _C.DISTILLATION.INTER_LOSS = "l2" 31 | _C.DISTILLATION.INTER_TEACHER_INDEX = [] 32 | _C.DISTILLATION.INTER_STUDENT_INDEX = [] 33 | _C.DISTILLATION.INTER_WEIGHT = 2.5 34 | 35 | # Logits distillation options 36 | _C.DISTILLATION.ENABLE_LOGIT = False 37 | _C.DISTILLATION.LOGIT_LOSS = "soft" 38 | _C.DISTILLATION.LOGIT_TEMP = 1.0 39 | _C.DISTILLATION.LOGIT_WEIGHT = 0.5 40 | 41 | # Teacher model 42 | _C.DISTILLATION.TEACHER_MODEL = "ResNet" 43 | _C.DISTILLATION.TEACHER_WEIGHTS = None 44 | _C.DISTILLATION.TEACHER_IMG_SIZE = 32 45 | 46 | # Offline settings 47 | _C.DISTILLATION.OFFLINE = False 48 | _C.DISTILLATION.FEATURE_FILE = None 49 | 50 | 51 | # ------------------------------- Common model options ------------------------------- # 52 | _C.MODEL = CfgNode() 53 | 54 | _C.MODEL.TYPE = "ResNet" 55 | _C.MODEL.IMG_SIZE = 224 56 | _C.MODEL.IN_CHANNELS = 3 57 | _C.MODEL.NUM_CLASSES = 100 58 | _C.MODEL.LOSS_FUN = "cross_entropy" 59 | 60 | 61 | # ------------------------------------ CNN options ----------------------------------- # 62 | _C.CNN = CfgNode() 63 | 64 | _C.CNN.DEPTH = 56 65 | _C.CNN.ACTIVATION_FUN = "relu" 66 | _C.CNN.ACTIVATION_INPLACE = True 67 | _C.CNN.BN_EPS = 1e-5 68 | _C.CNN.BN_MOMENTUM = 0.1 69 | _C.CNN.ZERO_INIT_FINAL_BN_GAMMA = False 70 | 71 | 72 | _C.RESNET = CfgNode() 73 | 74 | _C.RESNET.TRANS_FUN = "basic_transform" 75 | _C.RESNET.NUM_GROUPS = 1 76 | _C.RESNET.WIDTH_PER_GROUP = 64 77 | _C.RESNET.STRIDE_1X1 = True 78 | 79 | 80 | # -------------------------------- Transformer options ------------------------------- # 81 | _C.TRANSFORMER = CfgNode() 82 | 83 | _C.TRANSFORMER.PATCH_SIZE = None 84 | _C.TRANSFORMER.PATCH_STRIDE = None 85 | _C.TRANSFORMER.PATCH_PADDING = None 86 | _C.TRANSFORMER.HIDDEN_DIM = None 87 | _C.TRANSFORMER.DEPTH = None 88 | _C.TRANSFORMER.NUM_HEADS = None 89 | _C.TRANSFORMER.MLP_RATIO = None 90 | 91 | 
_C.TRANSFORMER.LN_EPS = 1e-6
92 | _C.TRANSFORMER.DROP_RATE = None
93 | _C.TRANSFORMER.DROP_PATH_RATE = None
94 | _C.TRANSFORMER.ATTENTION_DROP_RATE = None
95 | 
96 | 
97 | _C.T2T = CfgNode()
98 | 
99 | _C.T2T.TOKEN_DIM = 64
100 | _C.T2T.KERNEL_SIZE = (7, 3, 3)
101 | _C.T2T.STRIDE = (4, 2, 2)
102 | _C.T2T.PADDING = (2, 1, 1)
103 | 
104 | 
105 | _C.PIT = CfgNode()
106 | 
107 | _C.PIT.STRIDE = 8
108 | 
109 | 
110 | _C.PVT = CfgNode()
111 | 
112 | _C.PVT.SR_RATIO = [8, 4, 2, 1]
113 | 
114 | 
115 | _C.CONVIT = CfgNode()
116 | 
117 | _C.CONVIT.LOCAL_LAYERS = 10
118 | _C.CONVIT.LOCALITY_STRENGTH = 1.0
119 | 
120 | 
121 | _C.CVT = CfgNode()
122 | 
123 | _C.CVT.WITH_CLS_TOKEN = [False, False, True]
124 | _C.CVT.QKV_PROJ_METHOD = ['dw_bn', 'dw_bn', 'dw_bn']
125 | _C.CVT.KERNEL_QKV = [3, 3, 3]
126 | _C.CVT.STRIDE_KV = [2, 2, 2]
127 | _C.CVT.STRIDE_Q = [1, 1, 1]
128 | _C.CVT.PADDING_KV = [1, 1, 1]
129 | _C.CVT.PADDING_Q = [1, 1, 1]
130 | 
131 | 
132 | # -------------------------------- Optimizer options --------------------------------- #
133 | _C.OPTIM = CfgNode()
134 | 
135 | # Type of optimizer, selected from {'sgd', 'adam', 'adamw'}
136 | _C.OPTIM.OPTIMIZER = "sgd"
137 | 
138 | # Learning rate of body ranges from BASE_LR to MIN_LR according to the LR_POLICY
139 | _C.OPTIM.BASE_LR = 0.1
140 | _C.OPTIM.MIN_LR = 0.0
141 | 
142 | # Base learning rate of head is HEAD_LR_RATIO * BASE_LR
143 | _C.OPTIM.HEAD_LR_RATIO = 1.0
144 | 
145 | # Learning rate policy, selected from {'cos', 'exp', 'lin', 'steps'}
146 | _C.OPTIM.LR_POLICY = "cos"
147 | 
148 | # Steps for 'steps' policy (in epochs)
149 | _C.OPTIM.STEPS = []
150 | 
151 | # Learning rate multiplier for 'steps' policy
152 | _C.OPTIM.LR_MULT = 0.1
153 | 
154 | # Maximum number of epochs
155 | _C.OPTIM.MAX_EPOCH = 200
156 | 
157 | # Momentum
158 | _C.OPTIM.MOMENTUM = 0.9
159 | 
160 | # Momentum dampening
161 | _C.OPTIM.DAMPENING = 0.0
162 | 
163 | # Nesterov momentum
164 | _C.OPTIM.NESTEROV = True
165 | 
166 | # Betas (for Adam/AdamW optimizer)
167 | _C.OPTIM.BETA1 = 0.9
168 | _C.OPTIM.BETA2 = 0.999
169 | 
170 | # L2 regularization
171 | _C.OPTIM.WEIGHT_DECAY = 5e-4
172 | 
173 | # Start the warm up from OPTIM.BASE_LR * OPTIM.WARMUP_FACTOR
174 | _C.OPTIM.WARMUP_FACTOR = 0.1
175 | 
176 | # Gradually warm up the OPTIM.BASE_LR over this number of epochs
177 | _C.OPTIM.WARMUP_EPOCHS = 0
178 | 
179 | # Exponential Moving Average (EMA) update value
180 | _C.OPTIM.EMA_ALPHA = 1e-5
181 | 
182 | # Iteration frequency with which to update EMA weights (0 disables EMA updates)
183 | _C.OPTIM.EMA_UPDATE_PERIOD = 0
184 | 
185 | 
186 | # --------------------------------- Training options --------------------------------- #
187 | _C.TRAIN = CfgNode()
188 | 
189 | # Dataset and split
190 | _C.TRAIN.DATASET = ""
191 | _C.TRAIN.SPLIT = "train"
192 | 
193 | # Total mini-batch size
194 | _C.TRAIN.BATCH_SIZE = 128
195 | 
196 | # Resume training from the latest checkpoint in the output directory
197 | _C.TRAIN.AUTO_RESUME = True
198 | 
199 | # Weights to start training from
200 | _C.TRAIN.WEIGHTS = ""
201 | 
202 | # If True train using mixed precision
203 | _C.TRAIN.MIXED_PRECISION = False
204 | 
205 | # Label smoothing value in 0 to 1 (0 gives no smoothing)
206 | _C.TRAIN.LABEL_SMOOTHING = 0.0
207 | 
208 | # Batch mixup regularization value in 0 to 1 (0 gives no mixup)
209 | _C.TRAIN.MIXUP_ALPHA = 0.0
210 | 
211 | # Batch cutmix regularization value in 0 to 1 (0 gives no cutmix)
212 | _C.TRAIN.CUTMIX_ALPHA = 0.0
213 | 
214 | _C.TRAIN.STRONG_AUGMENTATION = True
215 | 
216 | 
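The LABEL_SMOOTHING, MIXUP_ALPHA, and CUTMIX_ALPHA options above are consumed by smooth_one_hot_labels() and mixup() in pycls/core/net.py, and the resulting soft targets by SoftCrossEntropyLoss. A minimal, self-contained sketch of the label-smoothing arithmetic (the function name smooth_one_hot and the inlined arguments are illustrative; the real code reads these values from cfg):

```python
import torch

def smooth_one_hot(labels, n_classes=100, label_smooth=0.1):
    # Every class receives label_smooth / n_classes probability mass;
    # the true class keeps the rest, so each row still sums to 1.
    neg_val = label_smooth / n_classes
    pos_val = 1.0 - label_smooth + neg_val
    out = torch.full((labels.shape[0], n_classes), neg_val)
    out.scatter_(1, labels.long().view(-1, 1), pos_val)
    return out

targets = smooth_one_hot(torch.tensor([3, 7]))
print(targets[0, 3].item())  # 0.901 = 1 - 0.1 + 0.1 / 100
print(targets.sum(1))        # each row sums to 1.0
```
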
217 | # --------------------------------- Testing options ---------------------------------- #
218 | _C.TEST = CfgNode()
219 | 
220 | # Dataset and split
221 | _C.TEST.DATASET = ""
222 | _C.TEST.SPLIT = "val"
223 | 
224 | # Total mini-batch size
225 | _C.TEST.BATCH_SIZE = 200
226 | 
227 | # Weights to use for testing
228 | _C.TEST.WEIGHTS = ""
229 | 
230 | 
231 | # ------------------------------- Data loader options -------------------------------- #
232 | _C.DATA_LOADER = CfgNode()
233 | 
234 | # Number of data loader workers per process
235 | _C.DATA_LOADER.NUM_WORKERS = 8
236 | 
237 | # Load data to pinned host memory
238 | _C.DATA_LOADER.PIN_MEMORY = True
239 | 
240 | 
241 | # ---------------------------------- CUDNN options ----------------------------------- #
242 | _C.CUDNN = CfgNode()
243 | 
244 | # Perform benchmarking to select fastest CUDNN algorithms (best for fixed input sizes)
245 | _C.CUDNN.BENCHMARK = True
246 | 
247 | 
248 | # ------------------------------- Precise time options ------------------------------- #
249 | _C.PREC_TIME = CfgNode()
250 | 
251 | # Number of iterations to warm up the caches
252 | _C.PREC_TIME.WARMUP_ITER = 3
253 | 
254 | # Number of iterations to compute avg time
255 | _C.PREC_TIME.NUM_ITER = 30
256 | 
257 | 
258 | # ---------------------------------- Launch options ---------------------------------- #
259 | _C.LAUNCH = CfgNode()
260 | 
261 | # The launch mode, may be 'local' or 'slurm' (or 'submitit_local' for debugging)
262 | # The 'local' mode uses a multi-GPU setup via torch.multiprocessing.start_processes.
263 | # The 'slurm' mode uses submitit to launch a job on a SLURM cluster and provides
264 | # support for MULTI-NODE jobs (and is the only way to launch MULTI-NODE jobs).
265 | # In 'slurm' mode, the LAUNCH options below can be used to control the SLURM options.
266 | # Note that NUM_GPUS (not part of LAUNCH options) determines total GPUs requested.
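# For example (following the node math in multi_proc_run in pycls/core/distributed.py):
# NUM_GPUS=16 with MAX_GPUS_PER_NODE=8 yields a 2-node SLURM request with 8 GPUs and
# 8 tasks per node, while NUM_GPUS<=8 stays on a single node.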
267 | _C.LAUNCH.MODE = "local"
268 | 
269 | # Launch options that are only used if LAUNCH.MODE is 'slurm'
270 | _C.LAUNCH.MAX_RETRY = 3
271 | _C.LAUNCH.NAME = "pycls_job"
272 | _C.LAUNCH.COMMENT = ""
273 | _C.LAUNCH.CPUS_PER_GPU = 10
274 | _C.LAUNCH.MEM_PER_GPU = 60
275 | _C.LAUNCH.PARTITION = "devlab"
276 | _C.LAUNCH.GPU_TYPE = "volta"
277 | _C.LAUNCH.TIME_LIMIT = 4200
278 | _C.LAUNCH.EMAIL = ""
279 | 
280 | 
281 | # ----------------------------------- Misc options ----------------------------------- #
282 | # Optional description of a config
283 | _C.DESC = ""
284 | 
285 | # If True output additional info to log
286 | _C.VERBOSE = True
287 | 
288 | # Number of GPUs to use (applies to both training and testing)
289 | _C.NUM_GPUS = 1
290 | 
291 | # Maximum number of GPUs available per node (unlikely to need to be changed)
292 | _C.MAX_GPUS_PER_NODE = 8
293 | 
294 | # Output directory
295 | _C.OUT_DIR = None
296 | 
297 | # Config destination (in OUT_DIR)
298 | _C.CFG_DEST = "config.yaml"
299 | 
300 | # Note that non-determinism may still be present due to non-deterministic GPU ops
301 | _C.RNG_SEED = 1
302 | 
303 | # Log destination ('stdout' or 'file')
304 | _C.LOG_DEST = "stdout"
305 | 
306 | # Log period in iters
307 | _C.LOG_PERIOD = 10
308 | 
309 | # Distributed backend
310 | _C.DIST_BACKEND = "nccl"
311 | 
312 | # Hostname and port range for multi-process groups (actual port selected randomly)
313 | _C.HOST = "localhost"
314 | _C.PORT_RANGE = [10000, 65000]
315 | 
316 | # Model weights referred to by URL are downloaded to this local cache
317 | _C.DOWNLOAD_CACHE = "/tmp/pycls-download-cache"
318 | 
319 | 
320 | # ---------------------------------- Default config ---------------------------------- #
321 | _CFG_DEFAULT = _C.clone()
322 | _CFG_DEFAULT.freeze()
323 | 
324 | 
325 | def assert_cfg():
326 |     """Checks config invariants."""
327 |     err_str = "The first lr step must start at 0"
328 |     assert not _C.OPTIM.STEPS or _C.OPTIM.STEPS[0] == 0, err_str
329 |     data_splits = ["train", "val", "test"]
330 |     err_str = "Data split '{}' not supported"
331 |     assert _C.TRAIN.SPLIT in data_splits, err_str.format(_C.TRAIN.SPLIT)
332 |     assert _C.TEST.SPLIT in data_splits, err_str.format(_C.TEST.SPLIT)
333 |     err_str = "Mini-batch size should be a multiple of NUM_GPUS."
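    # TRAIN.BATCH_SIZE and TEST.BATCH_SIZE are global sizes; the loaders give each
    # GPU BATCH_SIZE / NUM_GPUS samples (see pycls/datasets/loader.py), hence the
    # divisibility checks below.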
334 | assert _C.TRAIN.BATCH_SIZE % _C.NUM_GPUS == 0, err_str 335 | assert _C.TEST.BATCH_SIZE % _C.NUM_GPUS == 0, err_str 336 | err_str = "Log destination '{}' not supported" 337 | assert _C.LOG_DEST in ["stdout", "file"], err_str.format(_C.LOG_DEST) 338 | err_str = "NUM_GPUS must be divisible by or less than MAX_GPUS_PER_NODE" 339 | num_gpus, max_gpus_per_node = _C.NUM_GPUS, _C.MAX_GPUS_PER_NODE 340 | assert num_gpus <= max_gpus_per_node or num_gpus % max_gpus_per_node == 0, err_str 341 | err_str = "Invalid mode {}".format(_C.LAUNCH.MODE) 342 | assert _C.LAUNCH.MODE in ["local", "submitit_local", "slurm"], err_str 343 | 344 | 345 | def dump_cfg(): 346 | """Dumps the config to the output directory.""" 347 | cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST) 348 | with pathmgr.open(cfg_file, "w") as f: 349 | _C.dump(stream=f) 350 | return cfg_file 351 | 352 | 353 | def load_cfg(cfg_file): 354 | """Loads config from specified file.""" 355 | with pathmgr.open(cfg_file, "r") as f: 356 | _C.merge_from_other_cfg(_C.load_cfg(f)) 357 | 358 | 359 | def reset_cfg(): 360 | """Reset config to initial state.""" 361 | _C.merge_from_other_cfg(_CFG_DEFAULT) 362 | -------------------------------------------------------------------------------- /pycls/core/distributed.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/distributed.py 11 | """ 12 | 13 | import os 14 | import random 15 | 16 | import submitit 17 | import torch 18 | from pycls.core.config import cfg 19 | 20 | 21 | os.environ["MKL_THREADING_LAYER"] = "GNU" 22 | 23 | 24 | class SubmititRunner(submitit.helpers.Checkpointable): 25 | 26 | def __init__(self, port, fun, cfg_state): 27 | self.cfg_state = cfg_state 28 | self.port = port 29 | self.fun = fun 30 | 31 | def __call__(self): 32 | job_env = submitit.JobEnvironment() 33 | os.environ["MASTER_ADDR"] = job_env.hostnames[0] 34 | os.environ["MASTER_PORT"] = str(self.port) 35 | os.environ["RANK"] = str(job_env.global_rank) 36 | os.environ["LOCAL_RANK"] = str(job_env.local_rank) 37 | os.environ["WORLD_SIZE"] = str(job_env.num_tasks) 38 | setup_distributed(self.cfg_state) 39 | self.fun() 40 | 41 | 42 | def is_main_proc(local=False): 43 | m = cfg.MAX_GPUS_PER_NODE if local else cfg.NUM_GPUS 44 | return cfg.NUM_GPUS == 1 or torch.distributed.get_rank() % m == 0 45 | 46 | 47 | def scaled_all_reduce(tensors): 48 | if cfg.NUM_GPUS == 1: 49 | return tensors 50 | reductions = [] 51 | for tensor in tensors: 52 | reduction = torch.distributed.all_reduce(tensor, async_op=True) 53 | reductions.append(reduction) 54 | for reduction in reductions: 55 | reduction.wait() 56 | for tensor in tensors: 57 | tensor.mul_(1.0 / cfg.NUM_GPUS) 58 | return tensors 59 | 60 | 61 | def setup_distributed(cfg_state): 62 | cfg.defrost() 63 | cfg.update(**cfg_state) 64 | cfg.freeze() 65 | local_rank = int(os.environ["LOCAL_RANK"]) 66 | torch.distributed.init_process_group(backend=cfg.DIST_BACKEND) 67 | torch.cuda.set_device(local_rank) 68 | 69 | 70 | def single_proc_run(local_rank, fun, main_port, cfg_state, world_size): 71 | os.environ["MASTER_ADDR"] = "localhost" 72 | os.environ["MASTER_PORT"] = str(main_port) 73 | os.environ["RANK"] = str(local_rank) 74 | os.environ["LOCAL_RANK"] 
= str(local_rank) 75 | os.environ["WORLD_SIZE"] = str(world_size) 76 | setup_distributed(cfg_state) 77 | fun() 78 | 79 | 80 | def multi_proc_run(num_proc, fun): 81 | launch = cfg.LAUNCH 82 | if launch.MODE in ["submitit_local", "slurm"]: 83 | use_slurm = launch.MODE == "slurm" 84 | executor = submitit.AutoExecutor if use_slurm else submitit.LocalExecutor 85 | kwargs = {"slurm_max_num_timeout": launch.MAX_RETRY} if use_slurm else {} 86 | executor = executor(folder=cfg.OUT_DIR, **kwargs) 87 | num_gpus_per_node = min(cfg.NUM_GPUS, cfg.MAX_GPUS_PER_NODE) 88 | executor.update_parameters( 89 | mem_gb=launch.MEM_PER_GPU * num_gpus_per_node, 90 | gpus_per_node=num_gpus_per_node, 91 | tasks_per_node=num_gpus_per_node, 92 | cpus_per_task=launch.CPUS_PER_GPU, 93 | nodes=max(1, cfg.NUM_GPUS // cfg.MAX_GPUS_PER_NODE), 94 | timeout_min=launch.TIME_LIMIT, 95 | name=launch.NAME, 96 | slurm_partition=launch.PARTITION, 97 | slurm_comment=launch.COMMENT, 98 | slurm_constraint=launch.GPU_TYPE, 99 | slurm_additional_parameters={"mail-user": launch.EMAIL, "mail-type": "END"}, 100 | ) 101 | main_port = random.randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1]) 102 | job = executor.submit(SubmititRunner(main_port, fun, cfg)) 103 | print("Submitted job_id {} with out_dir: {}".format(job.job_id, cfg.OUT_DIR)) 104 | if not use_slurm: 105 | job.wait() 106 | elif num_proc > 1: 107 | main_port = random.randint(cfg.PORT_RANGE[0], cfg.PORT_RANGE[1]) 108 | mp_runner = torch.multiprocessing.start_processes 109 | args = (fun, main_port, cfg, num_proc) 110 | mp_runner(single_proc_run, args=args, nprocs=num_proc, start_method="fork") 111 | else: 112 | fun() 113 | -------------------------------------------------------------------------------- /pycls/core/io.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/io.py 11 | """ 12 | 13 | import logging 14 | import os 15 | import re 16 | import sys 17 | from urllib import request as urlrequest 18 | 19 | from iopath.common.file_io import PathManagerFactory 20 | 21 | 22 | pathmgr = PathManagerFactory.get() 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | _PYCLS_BASE_URL = "" 27 | 28 | 29 | def cache_url(url_or_file, cache_dir, base_url=_PYCLS_BASE_URL, download=True): 30 | is_url = re.match(r"^(?:http)s?://", url_or_file, re.IGNORECASE) is not None 31 | if not is_url: 32 | return url_or_file 33 | url = url_or_file 34 | assert url.startswith(base_url), "url must start with: {}".format(base_url) 35 | cache_file_path = url.replace(base_url, cache_dir) 36 | if pathmgr.exists(cache_file_path): 37 | return cache_file_path 38 | cache_file_dir = os.path.dirname(cache_file_path) 39 | if not pathmgr.exists(cache_file_dir): 40 | pathmgr.mkdirs(cache_file_dir) 41 | if download: 42 | logger.info("Downloading remote file {} to {}".format(url, cache_file_path)) 43 | download_url(url, cache_file_path) 44 | return cache_file_path 45 | 46 | 47 | def _progress_bar(count, total): 48 | bar_len = 60 49 | filled_len = int(round(bar_len * count / float(total))) 50 | percents = round(100.0 * count / float(total), 1) 51 | bar = "=" * filled_len + "-" * (bar_len - filled_len) 52 | sys.stdout.write( 53 | " [{}] {}% of {:.1f}MB file \r".format(bar, percents, total / 1024 / 1024) 54 | ) 55 | sys.stdout.flush() 56 | if count >= total: 57 | sys.stdout.write("\n") 58 | 59 | 60 | def download_url(url, dst_file_path, chunk_size=8192, progress_hook=_progress_bar): 61 | req = urlrequest.Request(url) 62 | response = urlrequest.urlopen(req) 63 | total_size = response.info().get("Content-Length").strip() 64 | total_size = int(total_size) 65 | bytes_so_far = 0 66 | with pathmgr.open(dst_file_path, "wb") as f: 67 | while 1: 68 | chunk = response.read(chunk_size) 69 | bytes_so_far += len(chunk) 70 | if not chunk: 71 | break 72 | if progress_hook: 73 | progress_hook(bytes_so_far, total_size) 74 | f.write(chunk) 75 | return bytes_so_far 76 | -------------------------------------------------------------------------------- /pycls/core/logging.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/logging.py 11 | """ 12 | 13 | from torch.utils.tensorboard import SummaryWriter 14 | import builtins 15 | import decimal 16 | import logging 17 | import os 18 | import sys 19 | from logging import FileHandler, INFO 20 | 21 | import pycls.core.distributed as dist 22 | import simplejson 23 | from pycls.core.config import cfg 24 | from pycls.core.io import pathmgr 25 | 26 | 27 | _FORMAT = "[%(filename)s: %(lineno)3d]: %(message)s" 28 | 29 | _LOG_FILE = "stdout.log" 30 | 31 | _TAG = "json_stats: " 32 | 33 | _TYPE = "_type" 34 | 35 | 36 | def _suppress_print(): 37 | def ignore(*_objects, _sep=" ", _end="\n", _file=sys.stdout, _flush=False): 38 | pass 39 | 40 | builtins.print = ignore 41 | 42 | 43 | def setup_logging(): 44 | if dist.is_main_proc(): 45 | logging.root.handlers = [] 46 | logging_config = {"level": logging.INFO, "format": _FORMAT} 47 | if cfg.LOG_DEST == "stdout": 48 | logging_config["stream"] = sys.stdout 49 | else: 50 | logging_config["filename"] = os.path.join(cfg.OUT_DIR, _LOG_FILE) 51 | logging.basicConfig(**logging_config) 52 | else: 53 | _suppress_print() 54 | 55 | 56 | def get_logger(name): 57 | return logging.getLogger(name) 58 | 59 | 60 | def dump_log_data(data, data_type, prec=4): 61 | data[_TYPE] = data_type 62 | data = float_to_decimal(data, prec) 63 | data_json = simplejson.dumps(data, sort_keys=True, use_decimal=True) 64 | return "{:s}{:s}".format(_TAG, data_json) 65 | 66 | 67 | def float_to_decimal(data, prec=4): 68 | if prec and isinstance(data, dict): 69 | return {k: float_to_decimal(v, prec) for k, v in data.items()} 70 | if prec and isinstance(data, float): 71 | return decimal.Decimal(("{:." + str(prec) + "f}").format(data)) 72 | else: 73 | return data 74 | 75 | 76 | def get_log_files(log_dir, name_filter="", log_file=_LOG_FILE): 77 | names = [n for n in sorted(pathmgr.ls(log_dir)) if name_filter in n] 78 | files = [os.path.join(log_dir, n, log_file) for n in names] 79 | f_n_ps = [(f, n) for (f, n) in zip(files, names) if pathmgr.exists(f)] 80 | files, names = zip(*f_n_ps) if f_n_ps else ([], []) 81 | return files, names 82 | 83 | 84 | def load_log_data(log_file, data_types_to_skip=()): 85 | assert pathmgr.exists(log_file), "Log file not found: {}".format(log_file) 86 | with pathmgr.open(log_file, "r") as f: 87 | lines = f.readlines() 88 | lines = [l[l.find(_TAG) + len(_TAG) :] for l in lines if _TAG in l] 89 | lines = [simplejson.loads(l) for l in lines] 90 | lines = [l for l in lines if _TYPE in l and not l[_TYPE] in data_types_to_skip] 91 | data_types = [l[_TYPE] for l in lines] 92 | data = {t: [] for t in data_types} 93 | for t, line in zip(data_types, lines): 94 | del line[_TYPE] 95 | data[t].append(line) 96 | for t in data: 97 | metrics = sorted(data[t][0].keys()) 98 | err_str = "Inconsistent metrics in log for _type={}: {}".format(t, metrics) 99 | assert all(sorted(d.keys()) == metrics for d in data[t]), err_str 100 | data[t] = {m: [d[m] for d in data[t]] for m in metrics} 101 | return data 102 | 103 | 104 | def sort_log_data(data): 105 | for t in data: 106 | if "epoch" in data[t]: 107 | assert "epoch_ind" not in data[t] and "epoch_max" not in data[t] 108 | data[t]["epoch_ind"] = [int(e.split("/")[0]) for e in data[t]["epoch"]] 109 | data[t]["epoch_max"] = [int(e.split("/")[1]) for e in data[t]["epoch"]] 110 | epoch = data[t]["epoch_ind"] 111 | if "iter" in data[t]: 112 | assert "iter_ind" not in data[t] and "iter_max" not in data[t] 113 | data[t]["iter_ind"] = [int(i.split("/")[0]) for i 
in data[t]["iter"]] 114 | data[t]["iter_max"] = [int(i.split("/")[1]) for i in data[t]["iter"]] 115 | itr = zip(epoch, data[t]["iter_ind"], data[t]["iter_max"]) 116 | epoch = [e + (i_ind - 1) / i_max for e, i_ind, i_max in itr] 117 | for m in data[t]: 118 | data[t][m] = [v for _, v in sorted(zip(epoch, data[t][m]))] 119 | else: 120 | data[t] = {m: d[0] for m, d in data[t].items()} 121 | return data 122 | 123 | 124 | class TFLogger(object): 125 | 126 | def __init__(self): 127 | self.writer = None 128 | 129 | def initialize(self, log_dir): 130 | self.writer = SummaryWriter(log_dir=log_dir) 131 | 132 | def log_iter_stats(self, stats): 133 | if self.writer is None: 134 | return 135 | total_iter = stats['total_iter'] 136 | for k in stats: 137 | if 'loss' in k: 138 | self.writer.add_scalar(f'iteration/{k}', stats[k], total_iter) 139 | self.writer.add_scalar('iteration/top1_acc', 100 - stats['top1_err'], total_iter) 140 | self.writer.add_scalar('iteration/top5_acc', 100 - stats['top5_err'], total_iter) 141 | 142 | def log_epoch_stats(self, stats): 143 | if self.writer is None: 144 | return 145 | epoch = int(stats['epoch'].split('/')[0]) 146 | self.writer.add_scalar('epoch/loss', stats['loss'], epoch) 147 | self.writer.add_scalar('epoch/learning_rate', stats['lr'], epoch) 148 | self.writer.add_scalar('epoch/top1_acc', 100 - stats['top1_err'], epoch) 149 | self.writer.add_scalar('epoch/top5_acc', 100 - stats['top5_err'], epoch) 150 | 151 | def log_test_stats(self, stats): 152 | if self.writer is None: 153 | return 154 | epoch = int(stats['epoch'].split('/')[0]) 155 | self.writer.add_scalar('test/top1_acc', 100 - stats['top1_err'], epoch) 156 | self.writer.add_scalar('test/top5_acc', 100 - stats['top5_err'], epoch) 157 | self.writer.add_scalar('test/max_top1_acc', 100 - stats['min_top1_err'], epoch) 158 | self.writer.add_scalar('test/max_top5_acc', 100 - stats['min_top5_err'], epoch) 159 | 160 | 161 | _tf_logger = TFLogger() 162 | 163 | 164 | def get_tflogger(): 165 | return _tf_logger 166 | -------------------------------------------------------------------------------- /pycls/core/meters.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/meters.py 11 | """ 12 | 13 | from collections import deque 14 | 15 | import os 16 | import numpy as np 17 | import pycls.core.logging as logging 18 | import torch 19 | from pycls.core.config import cfg 20 | from pycls.core.timer import Timer 21 | 22 | 23 | def time_string(seconds): 24 | days, rem = divmod(int(seconds), 24 * 3600) 25 | hrs, rem = divmod(rem, 3600) 26 | mins, secs = divmod(rem, 60) 27 | return "{0:02},{1:02}:{2:02}:{3:02}".format(days, hrs, mins, secs) 28 | 29 | 30 | def topk_errors(preds, labels, ks): 31 | if isinstance(preds, list): 32 | preds = preds[0] + preds[1] 33 | return _topk_errors(preds, labels, ks) 34 | 35 | 36 | def _topk_errors(preds, labels, ks): 37 | err_str = "Batch dim of predictions and labels must match" 38 | assert preds.size(0) == labels.size(0), err_str 39 | ks = [min(k, preds.size(-1)) for k in ks] 40 | _top_max_k_vals, top_max_k_inds = torch.topk( 41 | preds, max(ks), dim=1, largest=True, sorted=True 42 | ) 43 | top_max_k_inds = top_max_k_inds.t() 44 | rep_max_k_labels = labels.view(1, -1).expand_as(top_max_k_inds) 45 | top_max_k_correct = top_max_k_inds.eq(rep_max_k_labels) 46 | topks_correct = [top_max_k_correct[:k, :].reshape(-1).float().sum() for k in ks] 47 | return [(1.0 - x / preds.size(0)) * 100.0 for x in topks_correct] 48 | 49 | 50 | def gpu_mem_usage(): 51 | mem_usage_bytes = torch.cuda.max_memory_allocated() 52 | return mem_usage_bytes / 1024 / 1024 53 | 54 | 55 | class ScalarMeter(object): 56 | 57 | def __init__(self, window_size): 58 | self.deque = deque(maxlen=window_size) 59 | self.total = 0.0 60 | self.count = 0 61 | 62 | def reset(self): 63 | self.deque.clear() 64 | self.total = 0.0 65 | self.count = 0 66 | 67 | def add_value(self, value): 68 | self.deque.append(value) 69 | self.count += 1 70 | self.total += value 71 | 72 | def get_win_median(self): 73 | return np.median(self.deque) 74 | 75 | def get_win_avg(self): 76 | return np.mean(self.deque) 77 | 78 | def get_global_avg(self): 79 | return self.total / self.count 80 | 81 | 82 | class TrainMeter(object): 83 | 84 | def __init__(self, epoch_iters, phase="train"): 85 | self.logger = logging.get_logger(__name__) 86 | log_file = os.path.join(cfg.OUT_DIR, 'log.txt') 87 | file_handler = logging.FileHandler(log_file, 'w') 88 | file_handler.setLevel(logging.INFO) 89 | self.logger.addHandler(file_handler) 90 | 91 | self.tf_logger = logging.get_tflogger() 92 | self.tf_logger.initialize(cfg.OUT_DIR) 93 | 94 | self.epoch_iters = epoch_iters 95 | self.max_iter = cfg.OPTIM.MAX_EPOCH * epoch_iters 96 | self.phase = phase 97 | self.iter_timer = Timer() 98 | self.cls_loss = ScalarMeter(cfg.LOG_PERIOD) 99 | self.inter_loss = ScalarMeter(cfg.LOG_PERIOD) 100 | self.logit_loss = ScalarMeter(cfg.LOG_PERIOD) 101 | self.loss = ScalarMeter(cfg.LOG_PERIOD) 102 | self.loss_total = 0.0 103 | self.lr = None 104 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 105 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) 106 | self.num_top1_mis = 0 107 | self.num_top5_mis = 0 108 | self.num_samples = 0 109 | 110 | def reset(self, timer=False): 111 | if timer: 112 | self.iter_timer.reset() 113 | self.cls_loss.reset() 114 | self.inter_loss.reset() 115 | self.logit_loss.reset() 116 | self.loss.reset() 117 | self.loss_total = 0.0 118 | self.lr = None 119 | self.mb_top1_err.reset() 120 | self.mb_top5_err.reset() 121 | self.num_top1_mis = 0 122 | self.num_top5_mis = 0 123 | self.num_samples = 0 124 | 125 | def iter_tic(self): 126 | self.iter_timer.tic() 127 
| 128 | def iter_toc(self): 129 | self.iter_timer.toc() 130 | 131 | def update_stats(self, top1_err, top5_err, cls_loss, inter_loss, logit_loss, loss, lr, mb_size): 132 | self.mb_top1_err.add_value(top1_err) 133 | self.mb_top5_err.add_value(top5_err) 134 | self.cls_loss.add_value(cls_loss) 135 | self.inter_loss.add_value(inter_loss) 136 | self.logit_loss.add_value(logit_loss) 137 | self.loss.add_value(loss) 138 | self.lr = lr 139 | self.num_top1_mis += top1_err * mb_size 140 | self.num_top5_mis += top5_err * mb_size 141 | self.loss_total += loss * mb_size 142 | self.num_samples += mb_size 143 | 144 | def get_iter_stats(self, cur_epoch, cur_iter): 145 | cur_iter_total = cur_epoch * self.epoch_iters + cur_iter + 1 146 | eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total) 147 | mem_usage = gpu_mem_usage() 148 | stats = { 149 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 150 | "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), 151 | "total_iter": cur_iter_total, 152 | "time_avg": self.iter_timer.average_time, 153 | "time_diff": self.iter_timer.diff, 154 | "eta": time_string(eta_sec), 155 | "top1_err": self.mb_top1_err.get_win_median(), 156 | "top5_err": self.mb_top5_err.get_win_median(), 157 | "cls_loss": self.cls_loss.get_win_median(), 158 | "loss": self.loss.get_win_median(), 159 | "lr": self.lr, 160 | "mem": int(np.ceil(mem_usage)), 161 | } 162 | if cfg.DISTILLATION.ENABLE_INTER: 163 | stats["inter_distill_loss"] = self.inter_loss.get_win_median() 164 | if cfg.DISTILLATION.ENABLE_LOGIT: 165 | stats["logit_distill_loss"] = self.logit_loss.get_win_median() 166 | return stats 167 | 168 | def log_iter_stats(self, cur_epoch, cur_iter): 169 | if (cur_iter + 1) % cfg.LOG_PERIOD == 0: 170 | stats = self.get_iter_stats(cur_epoch, cur_iter) 171 | self.logger.info(logging.dump_log_data(stats, self.phase + "_iter")) 172 | self.tf_logger.log_iter_stats(stats) 173 | 174 | def get_epoch_stats(self, cur_epoch): 175 | cur_iter_total = (cur_epoch + 1) * self.epoch_iters 176 | eta_sec = self.iter_timer.average_time * (self.max_iter - cur_iter_total) 177 | mem_usage = gpu_mem_usage() 178 | top1_err = self.num_top1_mis / self.num_samples 179 | top5_err = self.num_top5_mis / self.num_samples 180 | avg_loss = self.loss_total / self.num_samples 181 | stats = { 182 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 183 | "time_avg": self.iter_timer.average_time, 184 | "time_epoch": self.iter_timer.average_time * self.epoch_iters, 185 | "eta": time_string(eta_sec), 186 | "top1_err": top1_err, 187 | "top5_err": top5_err, 188 | "loss": avg_loss, 189 | "lr": self.lr, 190 | "mem": int(np.ceil(mem_usage)), 191 | } 192 | return stats 193 | 194 | def log_epoch_stats(self, cur_epoch): 195 | stats = self.get_epoch_stats(cur_epoch) 196 | self.logger.info(logging.dump_log_data(stats, self.phase + "_epoch")) 197 | self.tf_logger.log_epoch_stats(stats) 198 | 199 | 200 | class TestMeter(object): 201 | 202 | def __init__(self, epoch_iters, phase="test"): 203 | self.logger = logging.get_logger(__name__) 204 | self.tf_logger = logging.get_tflogger() 205 | self.epoch_iters = epoch_iters 206 | self.phase = phase 207 | self.iter_timer = Timer() 208 | self.mb_top1_err = ScalarMeter(cfg.LOG_PERIOD) 209 | self.mb_top5_err = ScalarMeter(cfg.LOG_PERIOD) 210 | self.min_top1_err = 100.0 211 | self.min_top5_err = 100.0 212 | self.num_top1_mis = 0 213 | self.num_top5_mis = 0 214 | self.num_samples = 0 215 | 216 | def reset(self, min_errs=False): 217 | if min_errs: 218 | self.min_top1_err = 
100.0 219 | self.min_top5_err = 100.0 220 | self.iter_timer.reset() 221 | self.mb_top1_err.reset() 222 | self.mb_top5_err.reset() 223 | self.num_top1_mis = 0 224 | self.num_top5_mis = 0 225 | self.num_samples = 0 226 | 227 | def iter_tic(self): 228 | self.iter_timer.tic() 229 | 230 | def iter_toc(self): 231 | self.iter_timer.toc() 232 | 233 | def update_stats(self, top1_err, top5_err, mb_size): 234 | self.mb_top1_err.add_value(top1_err) 235 | self.mb_top5_err.add_value(top5_err) 236 | self.num_top1_mis += top1_err * mb_size 237 | self.num_top5_mis += top5_err * mb_size 238 | self.num_samples += mb_size 239 | 240 | def get_iter_stats(self, cur_epoch, cur_iter): 241 | mem_usage = gpu_mem_usage() 242 | iter_stats = { 243 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 244 | "iter": "{}/{}".format(cur_iter + 1, self.epoch_iters), 245 | "time_avg": self.iter_timer.average_time, 246 | "time_diff": self.iter_timer.diff, 247 | "top1_err": self.mb_top1_err.get_win_median(), 248 | "top5_err": self.mb_top5_err.get_win_median(), 249 | "mem": int(np.ceil(mem_usage)), 250 | } 251 | return iter_stats 252 | 253 | def log_iter_stats(self, cur_epoch, cur_iter): 254 | if (cur_iter + 1) % cfg.LOG_PERIOD == 0: 255 | stats = self.get_iter_stats(cur_epoch, cur_iter) 256 | self.logger.info(logging.dump_log_data(stats, self.phase + "_iter")) 257 | 258 | def get_epoch_stats(self, cur_epoch): 259 | top1_err = self.num_top1_mis / self.num_samples 260 | top5_err = self.num_top5_mis / self.num_samples 261 | self.min_top1_err = min(self.min_top1_err, top1_err) 262 | self.min_top5_err = min(self.min_top5_err, top5_err) 263 | mem_usage = gpu_mem_usage() 264 | stats = { 265 | "epoch": "{}/{}".format(cur_epoch + 1, cfg.OPTIM.MAX_EPOCH), 266 | "time_avg": self.iter_timer.average_time, 267 | "time_epoch": self.iter_timer.average_time * self.epoch_iters, 268 | "top1_err": top1_err, 269 | "top5_err": top5_err, 270 | "min_top1_err": self.min_top1_err, 271 | "min_top5_err": self.min_top5_err, 272 | "mem": int(np.ceil(mem_usage)), 273 | } 274 | return stats 275 | 276 | def log_epoch_stats(self, cur_epoch): 277 | stats = self.get_epoch_stats(cur_epoch) 278 | self.logger.info(logging.dump_log_data(stats, self.phase + "_epoch")) 279 | if self.phase == 'test': 280 | self.tf_logger.log_test_stats(stats) 281 | -------------------------------------------------------------------------------- /pycls/core/net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/net.py
11 | """
12 | 
13 | import itertools
14 | 
15 | import numpy as np
16 | import pycls.core.distributed as dist
17 | import torch
18 | from pycls.core.config import cfg
19 | 
20 | 
21 | def unwrap_model(model):
22 |     wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel)
23 |     return model.module if wrapped else model
24 | 
25 | 
26 | def smooth_one_hot_labels(labels):
27 |     n_classes, label_smooth = cfg.MODEL.NUM_CLASSES, cfg.TRAIN.LABEL_SMOOTHING
28 |     err_str = "Invalid input to smooth_one_hot_labels()"
29 |     assert labels.ndim == 1 and labels.max() < n_classes, err_str
30 |     shape = (labels.shape[0], n_classes)
31 |     neg_val = label_smooth / n_classes
32 |     pos_val = 1.0 - label_smooth + neg_val
33 |     labels_one_hot = torch.full(shape, neg_val, dtype=torch.float, device=labels.device)
34 |     labels_one_hot.scatter_(1, labels.long().view(-1, 1), pos_val)
35 |     return labels_one_hot
36 | 
37 | 
38 | class SoftCrossEntropyLoss(torch.nn.Module):
39 | 
40 |     def __init__(self):
41 |         super(SoftCrossEntropyLoss, self).__init__()
42 | 
43 |     def _cross_entropy(self, x, y):
44 |         loss = -y * torch.nn.functional.log_softmax(x, -1)
45 |         return torch.sum(loss) / x.shape[0]
46 | 
47 |     def forward(self, x, y):
48 |         if isinstance(x, list):
49 |             losses = [self._cross_entropy(pred, y) / len(x) for pred in x]
50 |             return sum(losses)
51 |         return self._cross_entropy(x, y)
52 | 
53 | 
54 | def mixup(inputs, labels):
55 |     assert labels.shape[1] == cfg.MODEL.NUM_CLASSES, "mixup labels must be one-hot"
56 |     mixup_alpha, cutmix_alpha = cfg.TRAIN.MIXUP_ALPHA, cfg.TRAIN.CUTMIX_ALPHA
57 |     mixup_alpha = mixup_alpha if (cutmix_alpha == 0 or np.random.rand() < 0.5) else 0
58 |     if mixup_alpha > 0:
59 |         m = np.random.beta(mixup_alpha, mixup_alpha)
60 |         permutation = torch.randperm(labels.shape[0])
61 |         inputs = m * inputs + (1.0 - m) * inputs[permutation, :]
62 |         labels = m * labels + (1.0 - m) * labels[permutation, :]
63 |     elif cutmix_alpha > 0:
64 |         m = np.random.beta(cutmix_alpha, cutmix_alpha)
65 |         permutation = torch.randperm(labels.shape[0])
66 |         h, w = inputs.shape[2], inputs.shape[3]
67 |         w_b, h_b = int(w * np.sqrt(1.0 - m)), int(h * np.sqrt(1.0 - m))
68 |         x_c, y_c = np.random.randint(w), np.random.randint(h)
69 |         x_0, y_0 = np.clip(x_c - w_b // 2, 0, w), np.clip(y_c - h_b // 2, 0, h)
70 |         x_1, y_1 = np.clip(x_c + w_b // 2, 0, w), np.clip(y_c + h_b // 2, 0, h)
71 |         m = 1.0 - ((x_1 - x_0) * (y_1 - y_0) / (h * w))
72 |         inputs[:, :, y_0:y_1, x_0:x_1] = inputs[permutation, :, y_0:y_1, x_0:x_1]
73 |         labels = m * labels + (1.0 - m) * labels[permutation, :]
74 |     return inputs, labels, labels.argmax(1)
75 | 
76 | 
77 | def update_model_ema(model, model_ema, cur_epoch, cur_iter):
78 |     update_period = cfg.OPTIM.EMA_UPDATE_PERIOD
79 |     if update_period == 0 or cur_iter % update_period != 0:
80 |         return
81 |     adjust = cfg.TRAIN.BATCH_SIZE / cfg.OPTIM.MAX_EPOCH * update_period
82 |     alpha = min(1.0, cfg.OPTIM.EMA_ALPHA * adjust)
83 |     alpha = 1.0 if cur_epoch < cfg.OPTIM.WARMUP_EPOCHS else alpha
84 |     params = unwrap_model(model).state_dict()
85 |     for name, param in unwrap_model(model_ema).state_dict().items():
86 |         param.copy_(param * (1.0 - alpha) + params[name] * alpha)
-------------------------------------------------------------------------------- /pycls/core/optimizer.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | # Copyright (c) Facebook, Inc. and its affiliates.
4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/optimizer.py 11 | """ 12 | 13 | import torch 14 | from timm.scheduler.cosine_lr import CosineLRScheduler 15 | 16 | from pycls.core.config import cfg 17 | from pycls.core.net import unwrap_model 18 | 19 | 20 | def reset_lr_weight_decay(model): 21 | body_lr = cfg.OPTIM.BASE_LR 22 | head_lr = cfg.OPTIM.HEAD_LR_RATIO * body_lr 23 | skip_list = ['cls_token', 'pos_embed', 'distill_token'] 24 | head_decay = [] 25 | head_no_decay = [] 26 | body_decay = [] 27 | body_no_decay = [] 28 | 29 | head_decay_name = [] 30 | head_no_decay_name = [] 31 | body_decay_name = [] 32 | body_no_decay_name = [] 33 | 34 | for name, param in unwrap_model(model).named_parameters(): 35 | if not param.requires_grad: 36 | continue # frozen weights 37 | if name.startswith("head."): 38 | if len(param.shape) == 1 or name.endswith(".bias"): 39 | head_no_decay.append(param) 40 | head_no_decay_name.append(name) 41 | else: 42 | head_decay.append(param) 43 | head_decay_name.append(name) 44 | else: 45 | skip = any([k in name for k in skip_list]) 46 | if len(param.shape) == 1 or name.endswith(".bias") or skip: 47 | body_no_decay.append(param) 48 | body_no_decay_name.append(name) 49 | else: 50 | body_decay.append(param) 51 | body_decay_name.append(name) 52 | return [ 53 | {'params': head_no_decay, 'lr': head_lr, 'weight_decay': 0.}, 54 | {'params': head_decay, 'lr': head_lr, 'weight_decay': cfg.OPTIM.WEIGHT_DECAY}, 55 | {'params': body_no_decay, 'lr': body_lr, 'weight_decay': 0.}, 56 | {'params': body_decay, 'lr': body_lr, 'weight_decay': cfg.OPTIM.WEIGHT_DECAY}] 57 | 58 | 59 | def construct_optimizer(model): 60 | optim = cfg.OPTIM 61 | param_wds = reset_lr_weight_decay(model) 62 | if optim.OPTIMIZER == "sgd": 63 | optimizer = torch.optim.SGD( 64 | param_wds, 65 | lr=optim.BASE_LR, 66 | momentum=optim.MOMENTUM, 67 | weight_decay=optim.WEIGHT_DECAY, 68 | dampening=optim.DAMPENING, 69 | nesterov=optim.NESTEROV, 70 | ) 71 | elif optim.OPTIMIZER == "adam": 72 | optimizer = torch.optim.Adam( 73 | param_wds, 74 | lr=optim.BASE_LR, 75 | betas=(optim.BETA1, optim.BETA2), 76 | weight_decay=optim.WEIGHT_DECAY, 77 | ) 78 | elif optim.OPTIMIZER == "adamw": 79 | optimizer = torch.optim.AdamW( 80 | param_wds, 81 | lr=optim.BASE_LR, 82 | betas=(optim.BETA1, optim.BETA2), 83 | weight_decay=optim.WEIGHT_DECAY, 84 | ) 85 | else: 86 | raise NotImplementedError 87 | return optimizer 88 | 89 | 90 | def construct_scheduler(optimizer): 91 | warmup_lr = cfg.OPTIM.WARMUP_FACTOR * cfg.OPTIM.BASE_LR 92 | if cfg.OPTIM.LR_POLICY == 'cos': 93 | scheduler = CosineLRScheduler( 94 | optimizer, 95 | t_initial=cfg.OPTIM.MAX_EPOCH, 96 | lr_min=cfg.OPTIM.MIN_LR, 97 | warmup_t=cfg.OPTIM.WARMUP_EPOCHS, 98 | warmup_lr_init=warmup_lr) 99 | else: 100 | raise NotImplementedError 101 | return scheduler 102 | 103 | 104 | def get_current_lr(optimizer): 105 | return optimizer.param_groups[0]['lr'] 106 | -------------------------------------------------------------------------------- /pycls/core/timer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 
10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/timer.py 11 | """ 12 | 13 | import time 14 | 15 | 16 | class Timer(object): 17 | 18 | def __init__(self): 19 | self.total_time = None 20 | self.calls = None 21 | self.start_time = None 22 | self.diff = None 23 | self.average_time = None 24 | self.reset() 25 | 26 | def tic(self): 27 | self.start_time = time.time() 28 | 29 | def toc(self): 30 | self.diff = time.time() - self.start_time 31 | self.total_time += self.diff 32 | self.calls += 1 33 | self.average_time = self.total_time / self.calls 34 | 35 | def reset(self): 36 | self.total_time = 0.0 37 | self.calls = 0 38 | self.start_time = 0.0 39 | self.diff = 0.0 40 | self.average_time = 0.0 41 | -------------------------------------------------------------------------------- /pycls/core/trainer.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/core/trainer.py 11 | """ 12 | 13 | import os 14 | import random 15 | import warnings 16 | 17 | import numpy as np 18 | import pycls.core.benchmark as benchmark 19 | import pycls.core.builders as builders 20 | import pycls.core.checkpoint as cp 21 | import pycls.core.config as config 22 | import pycls.core.distributed as dist 23 | import pycls.core.logging as logging 24 | import pycls.core.meters as meters 25 | import pycls.core.net as net 26 | import pycls.core.optimizer as optim 27 | import pycls.datasets.loader as data_loader 28 | import torch 29 | import torch.cuda.amp as amp 30 | from pycls.core.config import cfg 31 | from pycls.core.io import cache_url, pathmgr 32 | 33 | 34 | logger = logging.get_logger(__name__) 35 | 36 | 37 | def setup_env(): 38 | if dist.is_main_proc(): 39 | pathmgr.mkdirs(cfg.OUT_DIR) 40 | config.dump_cfg() 41 | logging.setup_logging() 42 | version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()] 43 | logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version)) 44 | env = "".join([f"{key}: {value}\n" for key, value in sorted(os.environ.items())]) 45 | logger.info(f"os.environ:\n{env}") 46 | logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else () 47 | logger.info(logging.dump_log_data(cfg, "cfg", None)) 48 | np.random.seed(cfg.RNG_SEED) 49 | torch.manual_seed(cfg.RNG_SEED) 50 | random.seed(cfg.RNG_SEED) 51 | torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK 52 | 53 | 54 | def setup_model(setup_ema=True): 55 | model = builders.build_model() 56 | logger.info("Model:\n{}".format(model)) if cfg.VERBOSE else () 57 | logger.info(logging.dump_log_data(net.unwrap_model(model).complexity(), "complexity")) 58 | cur_device = torch.cuda.current_device() 59 | model = model.cuda(device=cur_device) 60 | model_state = model.state_dict() 61 | if cfg.NUM_GPUS > 1: 62 | ddp = torch.nn.parallel.DistributedDataParallel 63 | model = ddp(module=model, device_ids=[cur_device], output_device=cur_device) 64 | if not setup_ema: 65 | return model 66 | else: 67 | ema = builders.build_model() 68 | ema = ema.cuda(device=cur_device) 69 | ema.load_state_dict(model_state) 70 | if cfg.NUM_GPUS > 1: 71 | ddp = torch.nn.parallel.DistributedDataParallel 72 | ema = ddp(module=ema, device_ids=[cur_device], output_device=cur_device) 73 | return 
model, ema 74 | 75 | 76 | def get_weights_file(weights_file): 77 | download = dist.is_main_proc(local=True) 78 | weights_file = cache_url(weights_file, cfg.DOWNLOAD_CACHE, download=download) 79 | if cfg.NUM_GPUS > 1: 80 | torch.distributed.barrier() 81 | return weights_file 82 | 83 | 84 | def train_epoch(loader, model, ema, loss_fun, optimizer, scheduler, scaler, meter, cur_epoch): 85 | data_loader.shuffle(loader, cur_epoch) 86 | lr = optim.get_current_lr(optimizer) 87 | model.train() 88 | ema.train() 89 | meter.reset() 90 | meter.iter_tic() 91 | for cur_iter, (inputs, labels, offline_features) in enumerate(loader): 92 | inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) 93 | offline_features = [f.cuda() for f in offline_features] 94 | labels_one_hot = net.smooth_one_hot_labels(labels) 95 | inputs, labels_one_hot, labels = net.mixup(inputs, labels_one_hot) 96 | with amp.autocast(enabled=cfg.TRAIN.MIXED_PRECISION): 97 | preds = model(inputs) 98 | loss_cls = loss_fun(preds, labels_one_hot) 99 | loss, loss_inter, loss_logit = loss_cls, inputs.new_tensor(0.0), inputs.new_tensor(0.0) 100 | if hasattr(net.unwrap_model(model), 'guidance_loss'): 101 | loss_inter, loss_logit = net.unwrap_model(model).guidance_loss(inputs, offline_features) 102 | if cfg.DISTILLATION.ENABLE_LOGIT: 103 | loss_cls = loss_cls * (1 - cfg.DISTILLATION.LOGIT_WEIGHT) 104 | loss_logit = loss_logit * cfg.DISTILLATION.LOGIT_WEIGHT 105 | loss = loss_cls + loss_logit 106 | if cfg.DISTILLATION.ENABLE_INTER: 107 | loss_inter = loss_inter * cfg.DISTILLATION.INTER_WEIGHT 108 | loss = loss_cls + loss_inter 109 | optimizer.zero_grad() 110 | scaler.scale(loss).backward() 111 | scaler.step(optimizer) 112 | scaler.update() 113 | net.update_model_ema(model, ema, cur_epoch, cur_iter) 114 | top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5]) 115 | loss_cls, loss_inter, loss_logit, loss, top1_err, top5_err = dist.scaled_all_reduce([loss_cls, loss_inter, loss_logit, loss, top1_err, top5_err]) 116 | loss_cls, loss_inter, loss_logit, loss, top1_err, top5_err = loss_cls.item(), loss_inter.item(), loss_logit.item(), loss.item(), top1_err.item(), top5_err.item() 117 | meter.iter_toc() 118 | mb_size = inputs.size(0) * cfg.NUM_GPUS 119 | meter.update_stats(top1_err, top5_err, loss_cls, loss_inter, loss_logit, loss, lr, mb_size) 120 | meter.log_iter_stats(cur_epoch, cur_iter) 121 | meter.iter_tic() 122 | meter.log_epoch_stats(cur_epoch) 123 | scheduler.step(cur_epoch + 1) 124 | 125 | 126 | @torch.no_grad() 127 | def test_epoch(loader, model, meter, cur_epoch): 128 | model.eval() 129 | meter.reset() 130 | meter.iter_tic() 131 | for cur_iter, (inputs, labels, _) in enumerate(loader): 132 | inputs, labels = inputs.cuda(), labels.cuda(non_blocking=True) 133 | preds = model(inputs) 134 | top1_err, top5_err = meters.topk_errors(preds, labels, [1, 5]) 135 | top1_err, top5_err = dist.scaled_all_reduce([top1_err, top5_err]) 136 | top1_err, top5_err = top1_err.item(), top5_err.item() 137 | meter.iter_toc() 138 | meter.update_stats(top1_err, top5_err, inputs.size(0) * cfg.NUM_GPUS) 139 | meter.log_iter_stats(cur_epoch, cur_iter) 140 | meter.iter_tic() 141 | meter.log_epoch_stats(cur_epoch) 142 | 143 | 144 | def train_model(): 145 | setup_env() 146 | model, ema = setup_model() 147 | loss_fun = builders.build_loss_fun().cuda() 148 | optimizer = optim.construct_optimizer(model) 149 | scheduler = optim.construct_scheduler(optimizer) 150 | start_epoch = 0 151 | if cfg.TRAIN.AUTO_RESUME and cp.has_checkpoint(): 152 | if 
cfg.DISTILLATION.ENABLE_INTER and cfg.DISTILLATION.INTER_TRANSFORM == 'linear': 153 | warnings.warn('Linear transform is not supported for resuming. This will cause the linear transformation to be trained from scratch.') 154 | file = cp.get_last_checkpoint() 155 | logger.info("Loaded checkpoint from: {}".format(file)) 156 | epoch = cp.load_checkpoint(file, model, ema, optimizer)[0] 157 | start_epoch = epoch + 1 158 | elif cfg.TRAIN.WEIGHTS: 159 | train_weights = get_weights_file(cfg.TRAIN.WEIGHTS) 160 | logger.info("Loaded initial weights from: {}".format(train_weights)) 161 | cp.load_checkpoint(train_weights, model, ema) 162 | train_loader = data_loader.construct_train_loader() 163 | test_loader = data_loader.construct_test_loader() 164 | train_meter = meters.TrainMeter(len(train_loader)) 165 | test_meter = meters.TestMeter(len(test_loader)) 166 | ema_meter = meters.TestMeter(len(test_loader), "test_ema") 167 | scaler = amp.GradScaler(enabled=cfg.TRAIN.MIXED_PRECISION) 168 | if start_epoch == 0 and cfg.PREC_TIME.NUM_ITER > 0: 169 | benchmark.compute_time_full(model, loss_fun, train_loader, test_loader) 170 | logger.info("Start epoch: {}".format(start_epoch + 1)) 171 | for cur_epoch in range(start_epoch, cfg.OPTIM.MAX_EPOCH): 172 | params = (train_loader, model, ema, loss_fun, optimizer, scheduler, scaler, train_meter) 173 | train_epoch(*params, cur_epoch) 174 | test_epoch(test_loader, model, test_meter, cur_epoch) 175 | test_err = test_meter.get_epoch_stats(cur_epoch)["top1_err"] 176 | ema_err = 100.0 177 | if cfg.OPTIM.EMA_UPDATE_PERIOD > 0: 178 | test_epoch(test_loader, ema, ema_meter, cur_epoch) 179 | ema_err = ema_meter.get_epoch_stats(cur_epoch)["top1_err"] 180 | file = cp.save_checkpoint(model, ema, optimizer, cur_epoch, test_err, ema_err) 181 | logger.info("Wrote checkpoint to: {}".format(file)) 182 | 183 | 184 | def test_model(): 185 | setup_env() 186 | model = setup_model(setup_ema=False) 187 | test_weights = get_weights_file(cfg.TEST.WEIGHTS) 188 | cp.load_checkpoint(test_weights, model) 189 | logger.info("Loaded model weights from: {}".format(test_weights)) 190 | test_loader = data_loader.construct_test_loader() 191 | test_meter = meters.TestMeter(len(test_loader)) 192 | test_epoch(test_loader, model, test_meter, 0) 193 | 194 | 195 | def time_model(): 196 | setup_env() 197 | model = setup_model(setup_ema=False) 198 | loss_fun = builders.build_loss_fun().cuda() 199 | benchmark.compute_time_model(model, loss_fun) 200 | 201 | 202 | def time_model_and_loader(): 203 | setup_env() 204 | model = setup_model(setup_ema=False) 205 | loss_fun = builders.build_loss_fun().cuda() 206 | train_loader = data_loader.construct_train_loader() 207 | test_loader = data_loader.construct_test_loader() 208 | benchmark.compute_time_full(model, loss_fun, train_loader, test_loader) 209 | -------------------------------------------------------------------------------- /pycls/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lkhl/tiny-transformers/d2165f74049c906b0afc9f957491960fb3c0cc8b/pycls/datasets/__init__.py -------------------------------------------------------------------------------- /pycls/datasets/base.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import torch 5 | import numpy as np 6 | from pycls.core.config import cfg 7 | from .transforms import create_train_transform, 
create_test_transform 8 | 9 | 10 | class BaseDataset(Dataset, metaclass=ABCMeta): 11 | """ 12 | Base class for dataset with support for offline distillation. 13 | """ 14 | 15 | def __init__(self, split): 16 | if split == 'train': 17 | transforms = create_train_transform() 18 | else: 19 | transforms = create_test_transform() 20 | self.primary_tfl, self.secondary_tfl, self.final_tfl = transforms 21 | 22 | self.features = None 23 | if cfg.DISTILLATION.OFFLINE and split == 'train': 24 | features = [] 25 | kd_data = np.load(cfg.DISTILLATION.FEATURE_FILE) 26 | for i in range(len(kd_data.files)): 27 | features.append(kd_data[f'layer_{i}']) 28 | self.features = features 29 | 30 | @abstractmethod 31 | def _get_data(self, index): 32 | """ 33 | Returns the image and its label at index. 34 | """ 35 | pass 36 | 37 | def __getitem__(self, index): 38 | img, label = self._get_data(index) 39 | if self.features: 40 | features = [torch.from_numpy(f[index].copy()) for f in self.features] 41 | for t in self.primary_tfl: 42 | img, features = t(img, features) 43 | else: 44 | img = self.primary_tfl(img) 45 | features = [] 46 | 47 | img = self.secondary_tfl(img) 48 | img = self.final_tfl(img) 49 | 50 | return img, label, features 51 | -------------------------------------------------------------------------------- /pycls/datasets/chaoyang.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | 5 | from .base import BaseDataset 6 | from pycls.core.io import pathmgr 7 | 8 | 9 | class Chaoyang(BaseDataset): 10 | 11 | def __init__(self, data_path, split): 12 | super(Chaoyang, self).__init__(split) 13 | assert pathmgr.exists(data_path), "Data path '{}' not found".format(data_path) 14 | splits = ["train", "test"] 15 | assert split in splits, "Split '{}' not supported for Chaoyang".format(split) 16 | self.data_path = data_path 17 | with open(os.path.join(data_path, f'{split}.json'), 'r') as f: 18 | anns = json.load(f) 19 | self.data = anns 20 | 21 | def __len__(self): 22 | return len(self.data) 23 | 24 | def _get_data(self, index): 25 | ann = self.data[index] 26 | img = Image.open(os.path.join(self.data_path, ann['name'])) 27 | return img, ann['label'] 28 | -------------------------------------------------------------------------------- /pycls/datasets/cifar100.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import CIFAR100 2 | 3 | from .base import BaseDataset 4 | from pycls.core.io import pathmgr 5 | 6 | 7 | class Cifar100(BaseDataset): 8 | 9 | def __init__(self, data_path, split): 10 | super(Cifar100, self).__init__(split) 11 | assert pathmgr.exists(data_path), "Data path '{}' not found".format(data_path) 12 | splits = ["train", "test"] 13 | assert split in splits, "Split '{}' not supported for cifar".format(split) 14 | self.database = CIFAR100(root=data_path, train=split=='train', download=True) 15 | 16 | def __len__(self): 17 | return len(self.database) 18 | 19 | def _get_data(self, index): 20 | return self.database[index] 21 | -------------------------------------------------------------------------------- /pycls/datasets/flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | from scipy.io import loadmat 5 | 6 | from .base import BaseDataset 7 | from pycls.core.io import pathmgr 8 | 9 | 10 | class Flowers(BaseDataset): 11 | 12 | def __init__(self, data_path, 
split): 13 | super(Flowers, self).__init__(split) 14 | assert pathmgr.exists(data_path), "Data path '{}' not found".format(data_path) 15 | splits = ["train", "test"] 16 | assert split in splits, "Split '{}' not supported for Flowers".format(split) 17 | self.data_path = data_path 18 | self.labels = loadmat(os.path.join(data_path, 'imagelabels.mat'))['labels'][0] - 1 19 | all_files = loadmat(os.path.join(data_path, 'setid.mat')) 20 | if split == 'train': 21 | self.ids = np.concatenate([all_files['trnid'][0], all_files['valid'][0]]) 22 | else: 23 | self.ids = all_files['tstid'][0] 24 | 25 | def __len__(self): 26 | return len(self.ids) 27 | 28 | def _get_data(self, idx): 29 | label = self.labels[self.ids[idx] - 1] 30 | fname = 'image_%05d.jpg'%self.ids[idx] 31 | img = Image.open(os.path.join(self.data_path, 'jpg', fname)) 32 | return img, label 33 | -------------------------------------------------------------------------------- /pycls/datasets/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/datasets/loader.py 11 | """ 12 | 13 | import os 14 | 15 | import torch 16 | from pycls.core.config import cfg 17 | from pycls.datasets.cifar100 import Cifar100 18 | from pycls.datasets.flowers import Flowers 19 | from pycls.datasets.chaoyang import Chaoyang 20 | from pycls.datasets.tiny_imagenet import TinyImageNet 21 | from torch.utils.data.distributed import DistributedSampler 22 | from torch.utils.data.sampler import RandomSampler 23 | 24 | 25 | _DATASETS = {'cifar100': Cifar100, 'flowers': Flowers, "chaoyang": Chaoyang, "tiny_imagenet": TinyImageNet} 26 | 27 | _DATA_DIR = os.path.join('.', "data") 28 | if not os.path.exists(_DATA_DIR): 29 | os.makedirs(_DATA_DIR) 30 | 31 | _PATHS = {"cifar100": "", 'flowers': "flowers", "chaoyang": "chaoyang", "tiny_imagenet": "tiny-imagenet-200"} 32 | 33 | 34 | def _construct_loader(dataset_name, split, batch_size, shuffle, drop_last): 35 | err_str = "Dataset '{}' not supported".format(dataset_name) 36 | assert dataset_name in _DATASETS and dataset_name in _PATHS, err_str 37 | data_path = os.path.join(_DATA_DIR, _PATHS[dataset_name]) 38 | dataset = _DATASETS[dataset_name](data_path, split) 39 | sampler = DistributedSampler(dataset) if cfg.NUM_GPUS > 1 else None 40 | loader = torch.utils.data.DataLoader( 41 | dataset, 42 | batch_size=batch_size, 43 | shuffle=(False if sampler else shuffle), 44 | sampler=sampler, 45 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 46 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 47 | drop_last=drop_last, 48 | ) 49 | return loader 50 | 51 | 52 | def construct_train_loader(): 53 | return _construct_loader( 54 | dataset_name=cfg.TRAIN.DATASET, 55 | split=cfg.TRAIN.SPLIT, 56 | batch_size=int(cfg.TRAIN.BATCH_SIZE / cfg.NUM_GPUS), 57 | shuffle=True, 58 | drop_last=True, 59 | ) 60 | 61 | 62 | def construct_test_loader(): 63 | return _construct_loader( 64 | dataset_name=cfg.TEST.DATASET, 65 | split=cfg.TEST.SPLIT, 66 | batch_size=int(cfg.TEST.BATCH_SIZE / cfg.NUM_GPUS), 67 | shuffle=False, 68 | drop_last=False, 69 | ) 70 | 71 | 72 | def shuffle(loader, cur_epoch): 73 | err_str = "Sampler type '{}' not supported".format(type(loader.sampler)) 74 | assert isinstance(loader.sampler, (RandomSampler, 
DistributedSampler)), err_str
75 | if isinstance(loader.sampler, DistributedSampler):
76 | loader.sampler.set_epoch(cur_epoch)
77 |
-------------------------------------------------------------------------------- /pycls/datasets/tiny_imagenet.py: --------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 | import pycls.core.logging as logging
8 |
9 |
10 | logger = logging.get_logger(__name__)
11 |
12 |
13 | class TinyImageNet(Dataset):
14 |
15 | def __init__(self, data_path, split):
16 | super(TinyImageNet, self).__init__()  # torch's base Dataset takes no arguments; passing `split` here raised a TypeError
17 | assert os.path.exists(data_path), "Data path '{}' not found".format(data_path)
18 | splits = ["train", "val"]
19 | assert split in splits, "Split '{}' not supported for Tiny ImageNet".format(split)
20 | logger.info("Constructing Tiny ImageNet {}...".format(split))
21 | self._data_path, self._split = data_path, split
22 | self._construct_imdb()
23 |
24 | def _construct_imdb(self):
25 | split_path = os.path.join(self._data_path, self._split)
26 | logger.info("{} data path: {}".format(self._split, split_path))
27 |
28 | if self._split == 'train':
29 | split_files = os.listdir(split_path)
30 | self._class_ids = sorted(f for f in split_files if re.match(r"^n[0-9]+$", f))
31 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
32 | self._imdb = []
33 | for class_id in self._class_ids:
34 | cont_id = self._class_id_cont_id[class_id]
35 | im_dir = os.path.join(split_path, class_id, 'images')
36 | for im_name in os.listdir(im_dir):
37 | im_path = os.path.join(im_dir, im_name)
38 | self._imdb.append({"im_path": im_path, "class": cont_id})
39 | else:
40 | class_ids = set()
41 | with open(os.path.join(split_path, 'val_annotations.txt')) as f:
42 | for line in f.readlines():
43 | class_ids.add(line.split()[1])
44 | self._class_ids = sorted(f for f in class_ids if re.match(r"^n[0-9]+$", f))
45 | self._class_id_cont_id = {v: i for i, v in enumerate(self._class_ids)}
46 | self._imdb = []
47 | im_dir = os.path.join(split_path, 'images')
48 | with open(os.path.join(split_path, 'val_annotations.txt')) as f:
49 | for line in f.readlines():
50 | im_name = line.split()[0]
51 | class_id = line.split()[1]
52 | im_path = os.path.join(im_dir, im_name)
53 | cont_id = self._class_id_cont_id[class_id]
54 | self._imdb.append({"im_path": im_path, "class": cont_id})
55 |
56 | logger.info("Number of images: {}".format(len(self._imdb)))
57 | logger.info("Number of classes: {}".format(len(self._class_ids)))
58 |
59 | def __len__(self):
60 | return len(self._imdb)
61 |
62 | def _get_data(self, index):
63 | img = Image.open(self._imdb[index]["im_path"]).convert('RGB')
64 | label = self._imdb[index]["class"]
65 | return img, label
66 | -------------------------------------------------------------------------------- /pycls/datasets/transforms.py: --------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | from pycls.core.config import cfg
4 |
5 | import torchvision.transforms.functional as F
6 | import torchvision.transforms as transforms
7 |
8 | from timm.data import create_transform
9 | from timm.data.transforms import RandomResizedCropAndInterpolation as _RandomResizedCropAndInterpolation
10 |
11 |
12 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
13 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
14 |
15 |
16 | class RandomResizedCropAndInterpolation(_RandomResizedCropAndInterpolation):
17 |
18 | def __call__(self, img, features):
19 | i, j, h, w = self.get_params(img, self.scale, self.ratio)
20 | if isinstance(self.interpolation, (tuple, list)):
21 | interpolation = random.choice(self.interpolation)
22 | else:
23 | interpolation = self.interpolation
24 |
25 | out_img = F.resized_crop(img, i, j, h, w, self.size, interpolation)
26 |
27 | i, j, h, w = i / img.size[1], j / img.size[0], h / img.size[1], w / img.size[0]
28 | out_feats = []
29 | for feat in features:
30 | feat_h, feat_w = feat.shape[-2:]
31 | feat = F.resized_crop(feat, int(i*feat_h), int(j*feat_w), int(h*feat_h), int(w*feat_w), size=(feat_h, feat_w))
32 | out_feats.append(feat)
33 |
34 | return out_img, out_feats
35 |
36 |
37 | class RandomHorizontalFlip(transforms.RandomHorizontalFlip):
38 |
39 | def forward(self, img, features):
40 | if torch.rand(1) < self.p:
41 | out_img = F.hflip(img)
42 | out_feats = []
43 | for feat in features:
44 | out_feats.append(F.hflip(feat))
45 | return out_img, out_feats
46 | return img, features
47 |
48 |
49 | def create_train_transform(mean=None, std=None):
50 | mean = IMAGENET_DEFAULT_MEAN if mean is None else mean
51 | std = IMAGENET_DEFAULT_STD if std is None else std
52 |
53 | size = (cfg.MODEL.IMG_SIZE, cfg.MODEL.IMG_SIZE)
54 | transform = create_transform(
55 | input_size=size,
56 | is_training=True,
57 | color_jitter=0.4,
58 | auto_augment='rand-m9-mstd0.5-inc1',
59 | re_prob=0.25,
60 | re_mode='pixel',
61 | re_count=1,
62 | interpolation='bicubic',
63 | separate=True,
64 | mean=mean,
65 | std=std)
66 | primary_tfl, secondary_tfl, final_tfl = transform
67 |
68 | if cfg.DISTILLATION.OFFLINE:
69 | primary_tfl = [
70 | RandomResizedCropAndInterpolation(size, interpolation='bicubic'),
71 | RandomHorizontalFlip(p=0.5)]
72 | if not cfg.TRAIN.STRONG_AUGMENTATION:
73 | primary_tfl = transforms.Compose([
74 | transforms.Resize(size),
75 | transforms.RandomCrop(size, padding=cfg.MODEL.IMG_SIZE//8),
76 | transforms.RandomHorizontalFlip(p=0.5)])
77 | secondary_tfl = transforms.Compose([])
78 | final_tfl = transforms.Compose([
79 | transforms.ToTensor(),
80 | transforms.Normalize(
81 | mean=torch.tensor(mean),
82 | std=torch.tensor(std))])
83 |
84 | return primary_tfl, secondary_tfl, final_tfl
85 |
86 |
87 | def create_test_transform(mean=None, std=None):
88 | mean = IMAGENET_DEFAULT_MEAN if mean is None else mean
89 | std = IMAGENET_DEFAULT_STD if std is None else std
90 |
91 | primary_tfl = transforms.Resize((cfg.MODEL.IMG_SIZE, cfg.MODEL.IMG_SIZE))
92 | secondary_tfl = transforms.Compose([])
93 | final_tfl = transforms.Compose([
94 | transforms.ToTensor(),
95 | transforms.Normalize(
96 | mean=torch.tensor(mean),
97 | std=torch.tensor(std))])
98 | return primary_tfl, secondary_tfl, final_tfl
99 | -------------------------------------------------------------------------------- /pycls/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .
import cnns, transformers 2 | -------------------------------------------------------------------------------- /pycls/models/build.py: -------------------------------------------------------------------------------- 1 | from fvcore.common.registry import Registry 2 | 3 | from pycls.core.config import cfg 4 | from .distill import DistillationWrapper 5 | 6 | 7 | MODEL = Registry('MODEL') 8 | 9 | 10 | def build_model(): 11 | model = MODEL.get(cfg.MODEL.TYPE)() 12 | if cfg.DISTILLATION.ENABLE_INTER or cfg.DISTILLATION.ENABLE_LOGIT: 13 | teacher_mode = MODEL.get(cfg.DISTILLATION.TEACHER_MODEL)() 14 | model = DistillationWrapper(model, teacher_mode) 15 | return model 16 | -------------------------------------------------------------------------------- /pycls/models/cnns/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import ResNet 2 | -------------------------------------------------------------------------------- /pycls/models/cnns/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from pycls.core.config import cfg 5 | 6 | 7 | class BaseConvModel(nn.Module, metaclass=ABCMeta): 8 | """ 9 | Base class for conv models. 10 | 11 | Attributes: 12 | - self.features (List[Tensor]): the features in each stage. 13 | - self.feature_dims (List[int]): the dimension of features in each stage. 14 | """ 15 | 16 | def __init__(self): 17 | super(BaseConvModel, self).__init__() 18 | self.depth = cfg.CNN.DEPTH 19 | self.img_size = cfg.MODEL.IMG_SIZE 20 | self.in_channels = cfg.MODEL.IN_CHANNELS 21 | self.num_classes = cfg.MODEL.NUM_CLASSES 22 | self.features = list() 23 | self.feature_dims = None 24 | 25 | def initialize_hooks(self, layers, feature_dims): 26 | """ 27 | Initialize hooks for the given layers. 28 | """ 29 | for layer in layers: 30 | layer.register_forward_hook(self._feature_hook) 31 | self.feature_dims = feature_dims 32 | self.register_forward_pre_hook(lambda module, inp: self.features.clear()) 33 | 34 | @abstractmethod 35 | def _feature_hook(self, module, inputs, outputs): 36 | pass 37 | 38 | def complexity(self): 39 | params = sum(p.numel() for p in self.parameters()) 40 | return {'params': f'{round(params/1e6, 2)}M'} 41 | -------------------------------------------------------------------------------- /pycls/models/cnns/blocks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/models/blocks.py 11 | """ 12 | 13 | import numpy as np 14 | import torch 15 | import torch.nn as nn 16 | from pycls.core.config import cfg 17 | from torch.nn import Module 18 | 19 | 20 | def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False): 21 | assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." 
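The assertion above encodes the usual "same"-convolution rule: with an odd kernel size and padding of (k - 1) // 2, a stride-1 convolution preserves spatial size exactly. A quick standalone check of that rule (my own sketch, not repo code):

```python
import torch
import torch.nn as nn

# With odd k and padding (k - 1) // 2, stride-1 convs keep H and W unchanged,
# which is what the odd-kernel assertions in blocks.py rely on.
for k in (1, 3, 5, 7):
    conv = nn.Conv2d(8, 8, k, stride=1, padding=(k - 1) // 2, bias=False)
    out = conv(torch.randn(1, 8, 32, 32))
    assert out.shape[-2:] == (32, 32)
```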
22 | s, p, g, b = stride, (k - 1) // 2, groups, bias 23 | return nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b) 24 | 25 | 26 | def norm2d(w_in): 27 | return nn.BatchNorm2d(num_features=w_in, eps=cfg.CNN.BN_EPS, momentum=cfg.CNN.BN_MOMENTUM) 28 | 29 | 30 | def pool2d(_w_in, k, *, stride=1): 31 | assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." 32 | return nn.MaxPool2d(k, stride=stride, padding=(k - 1) // 2) 33 | 34 | 35 | def gap2d(_w_in): 36 | return nn.AdaptiveAvgPool2d((1, 1)) 37 | 38 | 39 | def linear(w_in, w_out, *, bias=False): 40 | return nn.Linear(w_in, w_out, bias=bias) 41 | 42 | 43 | def activation(activation_fun=None): 44 | activation_fun = (activation_fun or cfg.CNN.ACTIVATION_FUN).lower() 45 | if activation_fun == "relu": 46 | return nn.ReLU(inplace=cfg.CNN.ACTIVATION_INPLACE) 47 | elif activation_fun == "silu" or activation_fun == "swish": 48 | try: 49 | return torch.nn.SiLU() 50 | except AttributeError: 51 | return SiLU() 52 | elif activation_fun == "gelu": 53 | return torch.nn.GELU() 54 | else: 55 | raise AssertionError("Unknown MODEL.ACTIVATION_FUN: " + activation_fun) 56 | 57 | 58 | def conv2d_cx(cx, w_in, w_out, k, *, stride=1, groups=1, bias=False): 59 | assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." 60 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 61 | h, w = (h - 1) // stride + 1, (w - 1) // stride + 1 62 | flops += k * k * w_in * w_out * h * w // groups + (w_out * h * w if bias else 0) 63 | params += k * k * w_in * w_out // groups + (w_out if bias else 0) 64 | acts += w_out * h * w 65 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 66 | 67 | 68 | def norm2d_cx(cx, w_in): 69 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 70 | params += 2 * w_in 71 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 72 | 73 | 74 | def pool2d_cx(cx, w_in, k, *, stride=1): 75 | assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." 
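The `*_cx` helpers above and below thread a running complexity accumulator, a dict with keys "h", "w", "flops", "params", and "acts", through a network one layer at a time. A hedged usage sketch, assuming the repo is on PYTHONPATH so `pycls` is importable (the numbers shown are for a single 3x3, 3-to-16-channel conv plus BN on a 32x32 input):

```python
from pycls.models.cnns.blocks import conv2d_cx, norm2d_cx

# Start from the input resolution with zeroed counters.
cx = {"h": 32, "w": 32, "flops": 0, "params": 0, "acts": 0}
cx = conv2d_cx(cx, 3, 16, 3)  # 3x3 conv, 3 -> 16 channels, stride 1
cx = norm2d_cx(cx, 16)        # BN adds 2 * w_in params; no flops counted
print(cx)  # {'h': 32, 'w': 32, 'flops': 442368, 'params': 464, 'acts': 16384}
```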
76 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 77 | h, w = (h - 1) // stride + 1, (w - 1) // stride + 1 78 | acts += w_in * h * w 79 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 80 | 81 | 82 | def gap2d_cx(cx, _w_in): 83 | flops, params, acts = cx["flops"], cx["params"], cx["acts"] 84 | return {"h": 1, "w": 1, "flops": flops, "params": params, "acts": acts} 85 | 86 | 87 | def layernorm_cx(cx, w_in): 88 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 89 | params += 2 * w_in 90 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 91 | 92 | 93 | def linear_cx(cx, w_in, w_out, *, bias=False, num_locations=1): 94 | h, w, flops, params, acts = cx["h"], cx["w"], cx["flops"], cx["params"], cx["acts"] 95 | flops += w_in * w_out * num_locations + (w_out * num_locations if bias else 0) 96 | params += w_in * w_out + (w_out if bias else 0) 97 | acts += w_out * num_locations 98 | return {"h": h, "w": w, "flops": flops, "params": params, "acts": acts} 99 | 100 | 101 | class SiLU(Module): 102 | 103 | def __init__(self): 104 | super(SiLU, self).__init__() 105 | 106 | def forward(self, x): 107 | return x * torch.sigmoid(x) 108 | 109 | 110 | class SE(Module): 111 | 112 | def __init__(self, w_in, w_se): 113 | super(SE, self).__init__() 114 | self.avg_pool = gap2d(w_in) 115 | self.f_ex = nn.Sequential( 116 | conv2d(w_in, w_se, 1, bias=True), 117 | activation(), 118 | conv2d(w_se, w_in, 1, bias=True), 119 | nn.Sigmoid(), 120 | ) 121 | 122 | def forward(self, x): 123 | return x * self.f_ex(self.avg_pool(x)) 124 | 125 | @staticmethod 126 | def complexity(cx, w_in, w_se): 127 | h, w = cx["h"], cx["w"] 128 | cx = gap2d_cx(cx, w_in) 129 | cx = conv2d_cx(cx, w_in, w_se, 1, bias=True) 130 | cx = conv2d_cx(cx, w_se, w_in, 1, bias=True) 131 | cx["h"], cx["w"] = h, w 132 | return cx 133 | 134 | 135 | def adjust_block_compatibility(ws, bs, gs): 136 | assert len(ws) == len(bs) == len(gs) 137 | assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs)) 138 | assert all(b < 1 or b % 1 == 0 for b in bs) 139 | vs = [int(max(1, w * b)) for w, b in zip(ws, bs)] 140 | gs = [int(min(g, v)) for g, v in zip(gs, vs)] 141 | ms = [np.lcm(g, int(b)) if b > 1 else g for g, b in zip(gs, bs)] 142 | vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)] 143 | ws = [int(v / b) for v, b in zip(vs, bs)] 144 | assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs)) 145 | return ws, bs, gs 146 | 147 | 148 | def init_weights(m): 149 | if isinstance(m, nn.Conv2d): 150 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 151 | m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out)) 152 | elif isinstance(m, nn.BatchNorm2d): 153 | zero_init_gamma = cfg.CNN.ZERO_INIT_FINAL_BN_GAMMA 154 | zero_init_gamma = hasattr(m, "final_bn") and m.final_bn and zero_init_gamma 155 | m.weight.data.fill_(0.0 if zero_init_gamma else 1.0) 156 | m.bias.data.zero_() 157 | elif isinstance(m, nn.Linear): 158 | m.weight.data.normal_(mean=0.0, std=0.01) 159 | m.bias.data.zero_() 160 | 161 | 162 | def drop_connect(x, drop_ratio): 163 | keep_ratio = 1.0 - drop_ratio 164 | mask = torch.empty([x.shape[0], 1, 1, 1], dtype=x.dtype, device=x.device) 165 | mask.bernoulli_(keep_ratio) 166 | x.div_(keep_ratio) 167 | x.mul_(mask) 168 | return x 169 | -------------------------------------------------------------------------------- /pycls/models/cnns/resnet.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/pycls/models/resnet.py 11 | """ 12 | 13 | from torch.nn import Module 14 | 15 | from ..build import MODEL 16 | from .base import BaseConvModel 17 | from pycls.core.config import cfg 18 | from .blocks import ( 19 | activation, 20 | conv2d, 21 | gap2d, 22 | init_weights, 23 | linear, 24 | norm2d, 25 | pool2d, 26 | ) 27 | 28 | 29 | _IN_STAGE_DS = { 30 | 18: (2, 2, 2, 2), 34: (3, 4, 6, 3), 31 | 50: (3, 4, 6, 3), 101: (3, 4, 23, 3), 152: (3, 8, 36, 3)} 32 | 33 | 34 | def get_trans_fun(name): 35 | trans_funs = { 36 | "basic_transform": BasicTransform, 37 | "bottleneck_transform": BottleneckTransform, 38 | } 39 | err_str = "Transformation function '{}' not supported" 40 | assert name in trans_funs.keys(), err_str.format(name) 41 | return trans_funs[name] 42 | 43 | 44 | class ResHead(Module): 45 | 46 | def __init__(self, w_in, num_classes): 47 | super(ResHead, self).__init__() 48 | self.avg_pool = gap2d(w_in) 49 | self.fc = linear(w_in, num_classes, bias=True) 50 | 51 | def forward(self, x): 52 | x = self.avg_pool(x) 53 | x = x.view(x.size(0), -1) 54 | x = self.fc(x) 55 | return x 56 | 57 | 58 | class BasicTransform(Module): 59 | 60 | def __init__(self, w_in, w_out, stride, w_b=None, groups=1): 61 | err_str = "Basic transform does not support w_b and groups options" 62 | assert w_b is None and groups == 1, err_str 63 | super(BasicTransform, self).__init__() 64 | self.a = conv2d(w_in, w_out, 3, stride=stride) 65 | self.a_bn = norm2d(w_out) 66 | self.a_af = activation() 67 | self.b = conv2d(w_out, w_out, 3) 68 | self.b_bn = norm2d(w_out) 69 | self.b_bn.final_bn = True 70 | 71 | def forward(self, x): 72 | for layer in self.children(): 73 | x = layer(x) 74 | return x 75 | 76 | 77 | class BottleneckTransform(Module): 78 | 79 | def __init__(self, w_in, w_out, stride, w_b, groups): 80 | super(BottleneckTransform, self).__init__() 81 | (s1, s3) = (stride, 1) if cfg.RESNET.STRIDE_1X1 else (1, stride) 82 | self.a = conv2d(w_in, w_b, 1, stride=s1) 83 | self.a_bn = norm2d(w_b) 84 | self.a_af = activation() 85 | self.b = conv2d(w_b, w_b, 3, stride=s3, groups=groups) 86 | self.b_bn = norm2d(w_b) 87 | self.b_af = activation() 88 | self.c = conv2d(w_b, w_out, 1) 89 | self.c_bn = norm2d(w_out) 90 | self.c_bn.final_bn = True 91 | 92 | def forward(self, x): 93 | for layer in self.children(): 94 | x = layer(x) 95 | return x 96 | 97 | 98 | class ResBlock(Module): 99 | 100 | def __init__(self, w_in, w_out, stride, trans_fun, w_b=None, groups=1): 101 | super(ResBlock, self).__init__() 102 | self.proj, self.bn = None, None 103 | if (w_in != w_out) or (stride != 1): 104 | self.proj = conv2d(w_in, w_out, 1, stride=stride) 105 | self.bn = norm2d(w_out) 106 | self.f = trans_fun(w_in, w_out, stride, w_b, groups) 107 | self.af = activation() 108 | 109 | def forward(self, x): 110 | x_p = self.bn(self.proj(x)) if self.proj else x 111 | return self.af(x_p + self.f(x)) 112 | 113 | 114 | class ResStage(Module): 115 | 116 | def __init__(self, w_in, w_out, stride, d, w_b=None, groups=1): 117 | super(ResStage, self).__init__() 118 | for i in range(d): 119 | b_stride = stride if i == 0 else 1 120 | b_w_in = w_in if i == 0 else w_out 121 | 
trans_fun = get_trans_fun(cfg.RESNET.TRANS_FUN) 122 | res_block = ResBlock(b_w_in, w_out, b_stride, trans_fun, w_b, groups) 123 | self.add_module("b{}".format(i + 1), res_block) 124 | self.out_channels = w_out 125 | 126 | def forward(self, x): 127 | for block in self.children(): 128 | x = block(x) 129 | return x 130 | 131 | 132 | class ResStemCifar(Module): 133 | 134 | def __init__(self, w_in, w_out): 135 | super(ResStemCifar, self).__init__() 136 | self.conv = conv2d(w_in, w_out, 3) 137 | self.bn = norm2d(w_out) 138 | self.af = activation() 139 | 140 | def forward(self, x): 141 | for layer in self.children(): 142 | x = layer(x) 143 | return x 144 | 145 | 146 | class ResStemIN(Module): 147 | 148 | def __init__(self, w_in, w_out): 149 | super(ResStemIN, self).__init__() 150 | self.conv = conv2d(w_in, w_out, 7, stride=2) 151 | self.bn = norm2d(w_out) 152 | self.af = activation() 153 | self.pool = pool2d(w_out, 3, stride=2) 154 | 155 | def forward(self, x): 156 | for layer in self.children(): 157 | x = layer(x) 158 | return x 159 | 160 | 161 | @MODEL.register() 162 | class ResNet(BaseConvModel): 163 | 164 | def __init__(self): 165 | super(ResNet, self).__init__() 166 | if self.depth in _IN_STAGE_DS: 167 | self._construct_imagenet() 168 | else: 169 | self._construct_cifar() 170 | layers = [m for m in self.modules() if isinstance(m, ResStage)] 171 | feature_dims = [m.out_channels for m in layers] 172 | self.initialize_hooks(layers, feature_dims) 173 | self.apply(init_weights) 174 | 175 | def _feature_hook(self, module, inputs, outputs): 176 | self.features.append(outputs) 177 | 178 | def _construct_cifar(self): 179 | err_str = "Model depth should be of the format 6n + 2 for cifar" 180 | assert (self.depth - 2) % 6 == 0, err_str 181 | d = int((self.depth - 2) / 6) 182 | self.stem = ResStemCifar(self.in_channels, 16) 183 | self.s1 = ResStage(16, 16, stride=1, d=d) 184 | self.s2 = ResStage(16, 32, stride=2, d=d) 185 | self.s3 = ResStage(32, 64, stride=2, d=d) 186 | self.head = ResHead(64, self.num_classes) 187 | 188 | def _construct_imagenet(self): 189 | g, gw = cfg.RESNET.NUM_GROUPS, cfg.RESNET.WIDTH_PER_GROUP 190 | (d1, d2, d3, d4) = _IN_STAGE_DS[self.depth] 191 | w_b = gw * g 192 | self.stem = ResStemIN(self.in_channels, 64) 193 | if cfg.RESNET.TRANS_FUN == 'bottleneck_transform': 194 | self.s1 = ResStage(64, 256, stride=1, d=d1, w_b=w_b, groups=g) 195 | self.s2 = ResStage(256, 512, stride=2, d=d2, w_b=w_b * 2, groups=g) 196 | self.s3 = ResStage(512, 1024, stride=2, d=d3, w_b=w_b * 4, groups=g) 197 | self.s4 = ResStage(1024, 2048, stride=2, d=d4, w_b=w_b * 8, groups=g) 198 | self.head = ResHead(2048, self.num_classes) 199 | else: 200 | self.s1 = ResStage(64, 64, stride=1, d=d1) 201 | self.s2 = ResStage(64, 128, stride=2, d=d2) 202 | self.s3 = ResStage(128, 256, stride=2, d=d3) 203 | self.s4 = ResStage(256, 512, stride=2, d=d4) 204 | self.head = ResHead(512, self.num_classes) 205 | 206 | def forward(self, x): 207 | for module in self.children(): 208 | x = module(x) 209 | return x 210 | -------------------------------------------------------------------------------- /pycls/models/distill.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from pycls.core.config import cfg 6 | import pycls.core.logging as logging 7 | 8 | 9 | logger = logging.get_logger(__name__) 10 | 11 | 12 | def attention_transform(feat): 13 | return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1)) 14 | 
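attention_transform above reduces a feature map to a channel-pooled, L2-normalized spatial attention vector, which is what lets teacher and student features with different channel counts be compared directly. A standalone numeric sketch (illustrative tensors, not repo code):

```python
import torch
import torch.nn.functional as F

def attention_transform(feat):
    # square, average over channels, flatten, then L2-normalize per sample
    return F.normalize(feat.pow(2).mean(1).view(feat.size(0), -1))

feat_t = torch.randn(4, 256, 8, 8)  # teacher features: 256 channels
feat_s = torch.randn(4, 64, 8, 8)   # student features: 64 channels
a_t, a_s = attention_transform(feat_t), attention_transform(feat_s)
print(a_t.shape, a_s.shape)         # both (4, 64): the channel mismatch is gone
loss = (a_t - a_s).pow(2).mean()    # the inter-feature distillation term
```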
15 |
16 | def similarity_transform(feat):
17 | feat = feat.view(feat.size(0), -1)
18 | gram = feat @ feat.t()
19 | return F.normalize(gram)
20 |
21 |
22 | _TRANS_FUNC = {"attention": attention_transform, "similarity": similarity_transform, "linear": lambda x: x}
23 |
24 |
25 | def inter_distill_loss(feat_t, feat_s, transform_type):
26 | assert transform_type in _TRANS_FUNC, f"Transformation function {transform_type} is not supported."
27 | trans_func = _TRANS_FUNC[transform_type]
28 | feat_t = trans_func(feat_t)
29 | feat_s = trans_func(feat_s)
30 | return (feat_t - feat_s).pow(2).mean()
31 |
32 |
33 | def logit_distill_loss(logits_t, logits_s, loss_type, temperature=1.0):  # default added: the call in DistillationWrapper passes no temperature
34 | if loss_type == "soft":
35 | distillation_loss = F.kl_div(
36 | F.log_softmax(logits_s / temperature, dim=1),
37 | F.log_softmax(logits_t / temperature, dim=1),
38 | reduction='sum',
39 | log_target=True
40 | ) * (temperature * temperature) / logits_s.numel()
41 | elif loss_type == "hard":
42 | distillation_loss = F.cross_entropy(logits_s, logits_t.argmax(dim=1))
43 | else:
44 | raise NotImplementedError
45 |
46 | return distillation_loss
47 |
48 |
49 | class DistillationWrapper(nn.Module):
50 |
51 | def __init__(self, student_model, teacher_mode):
52 | super(DistillationWrapper, self).__init__()
53 | self.enable_inter = cfg.DISTILLATION.ENABLE_INTER
54 | self.inter_transform_type = cfg.DISTILLATION.INTER_TRANSFORM
55 | self.student_idx = cfg.DISTILLATION.INTER_STUDENT_INDEX
56 | self.teacher_idx = cfg.DISTILLATION.INTER_TEACHER_INDEX
57 | self.enable_logit = cfg.DISTILLATION.ENABLE_LOGIT
58 | self.logit_loss_type = cfg.DISTILLATION.LOGIT_LOSS
59 | self.teacher_img_size = cfg.DISTILLATION.TEACHER_IMG_SIZE
60 | self.offline = cfg.DISTILLATION.OFFLINE
61 | assert not self.offline or not self.enable_logit, 'Logit distillation is not supported when offline is enabled.'
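The "soft" branch of logit_distill_loss above is the classic Hinton-style KD objective: KL divergence between temperature-softened teacher and student distributions, rescaled by T^2 so gradient magnitudes stay comparable across temperatures. A self-contained sketch with illustrative shapes and an illustrative temperature:

```python
import torch
import torch.nn.functional as F

logits_t = torch.randn(8, 100)  # teacher logits (batch of 8, 100 classes)
logits_s = torch.randn(8, 100)  # student logits
T = 4.0                         # illustrative temperature

loss = F.kl_div(
    F.log_softmax(logits_s / T, dim=1),
    F.log_softmax(logits_t / T, dim=1),
    reduction='sum',
    log_target=True,
) * (T * T) / logits_s.numel()  # normalized by batch * classes, as above
print(float(loss) >= 0)         # True: KL divergence is non-negative
```

Note the division by `logits_s.numel()` (batch times classes) rather than the batch size alone; this matches the DeiT-style normalization used in the function above.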
62 | 63 | self.student_model = student_model 64 | 65 | self.teacher_model = teacher_mode 66 | for p in self.teacher_model.parameters(): 67 | p.requires_grad = False 68 | logger.info("Build teacher model {}".format(type(self.teacher_model))) 69 | 70 | teacher_weights = cfg.DISTILLATION.TEACHER_WEIGHTS 71 | if teacher_weights: 72 | checkpoint = torch.load(teacher_weights)["model_state"] 73 | logger.info("Loaded initial weights of teacher model from: {}".format(teacher_weights)) 74 | self.teacher_model.load_state_dict(checkpoint) 75 | 76 | if self.inter_transform_type == 'linear': 77 | self.feature_transforms = nn.ModuleList() 78 | for i, j in zip(self.student_idx, self.teacher_idx): 79 | self.feature_transforms.append(nn.Conv2d(self.student_model.feature_dims[i], self.teacher_model.feature_dims[j], 1)) 80 | 81 | def load_state_dict(self, state_dict, strict=True): 82 | return self.student_model.load_state_dict(state_dict, strict) 83 | 84 | def state_dict(self, destination=None, prefix='', keep_vars=False): 85 | return self.student_model.state_dict(destination, prefix, keep_vars) 86 | 87 | def forward(self, x): 88 | return self.student_model(x) 89 | 90 | def complexity(self): 91 | complexity = dict() 92 | student_complexity = self.student_model.complexity() 93 | teacher_complexity = self.teacher_model.complexity() 94 | complexity["student"] = student_complexity 95 | complexity["teacher"] = teacher_complexity 96 | return complexity 97 | 98 | def guidance_loss(self, x, offline_feats): 99 | logits_s = self.student_model.distill_logits 100 | feats_s = self.student_model.features 101 | 102 | if self.offline: 103 | logits_t = None 104 | feats_t = offline_feats 105 | else: 106 | x = F.interpolate(x, size=(self.teacher_img_size, self.teacher_img_size), mode='bilinear', align_corners=False) 107 | with torch.no_grad(): 108 | logits_t = self.teacher_model(x) 109 | feats_t = self.teacher_model.features 110 | 111 | loss_inter = x.new_tensor(0.0) 112 | if self.enable_inter: 113 | for i, (idx_t, idx_s) in enumerate(zip(self.teacher_idx, self.student_idx)): 114 | feat_t = feats_t[idx_t] 115 | feat_s = feats_s[idx_s] 116 | 117 | if self.inter_transform_type == 'linear': 118 | feat_s = self.feature_transforms[i](feat_s) 119 | 120 | dsize = (max(feat_t.size(-2), feat_s.size(-2)), max(feat_t.size(-1), feat_s.size(-1))) 121 | feat_t = F.interpolate(feat_t, dsize, mode='bilinear', align_corners=False) 122 | feat_s = F.interpolate(feat_s, dsize, mode='bilinear', align_corners=False) 123 | loss_inter = loss_inter + inter_distill_loss(feat_t, feat_s, self.inter_transform_type) 124 | 125 | loss_logit = logit_distill_loss(logits_t, logits_s, self.logit_loss_type) if self.enable_logit else x.new_tensor(0.0) 126 | 127 | return loss_inter, loss_logit 128 | -------------------------------------------------------------------------------- /pycls/models/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .t2t_vit import T2TViT 2 | from .deit import DeiT 3 | from .pvt_v2 import PVTv2 4 | from .pvt import PVT 5 | from .pit import PiT 6 | from .convit import ConViT 7 | from .cvt import CvT 8 | -------------------------------------------------------------------------------- /pycls/models/transformers/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from pycls.core.config import cfg 5 | 6 | 7 | class BaseTransformerModel(nn.Module, metaclass=ABCMeta): 
8 | """ 9 | Base class for Transformer models. 10 | 11 | Attributes: 12 | - self.features (List[Tensor]): the features in each block. 13 | - self.feature_dims (List[int]): the dimension of features in each block. 14 | - self.distill_logits (Tensor|None): the logits of the distillation token, only for DeiT. 15 | """ 16 | 17 | def __init__(self): 18 | super(BaseTransformerModel, self).__init__() 19 | # Base configs for Transformers 20 | self.img_size = cfg.MODEL.IMG_SIZE 21 | self.patch_size = cfg.TRANSFORMER.PATCH_SIZE 22 | self.patch_stride = cfg.TRANSFORMER.PATCH_STRIDE 23 | self.patch_padding = cfg.TRANSFORMER.PATCH_PADDING 24 | self.in_channels = cfg.MODEL.IN_CHANNELS 25 | self.num_classes = cfg.MODEL.NUM_CLASSES 26 | self.hidden_dim = cfg.TRANSFORMER.HIDDEN_DIM 27 | self.depth = cfg.TRANSFORMER.DEPTH 28 | self.num_heads = cfg.TRANSFORMER.NUM_HEADS 29 | self.mlp_ratio = cfg.TRANSFORMER.MLP_RATIO 30 | self.drop_rate = cfg.TRANSFORMER.DROP_RATE 31 | self.drop_path_rate = cfg.TRANSFORMER.DROP_PATH_RATE 32 | self.attn_drop_rate = cfg.TRANSFORMER.ATTENTION_DROP_RATE 33 | 34 | # Calculate the dimension of features in each block 35 | if isinstance(self.hidden_dim, int): 36 | assert isinstance(self.depth, int) 37 | self.feature_dims = [self.hidden_dim] * self.depth 38 | elif isinstance(self.hidden_dim, (list, tuple)): 39 | assert isinstance(self.depth, (list, tuple)) 40 | assert len(self.hidden_dim) == len(self.depth) 41 | self.feature_dims = sum([[self.hidden_dim[i]] * d for i, d in enumerate(self.depth)], []) 42 | else: 43 | raise ValueError 44 | self.features = list() 45 | self.distill_logits = None 46 | 47 | def initialize_hooks(self, layers): 48 | """ 49 | Initialize hooks for the given layers. 50 | """ 51 | for layer in layers: 52 | layer.register_forward_hook(self._feature_hook) 53 | self.register_forward_pre_hook(lambda module, inp: self.features.clear()) 54 | 55 | @abstractmethod 56 | def _feature_hook(self, module, inputs, outputs): 57 | pass 58 | 59 | def complexity(self): 60 | params = sum(p.numel() for p in self.parameters() if p.requires_grad) 61 | return {'params': f'{round(params/1e6, 2)}M'} 62 | -------------------------------------------------------------------------------- /pycls/models/transformers/common.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from timm.models.layers import DropPath 4 | 5 | from pycls.core.config import cfg 6 | 7 | 8 | def layernorm(w_in): 9 | return nn.LayerNorm(w_in, eps=cfg.TRANSFORMER.LN_EPS) 10 | 11 | 12 | class MultiheadAttention(nn.Module): 13 | 14 | def __init__(self, 15 | in_channels, 16 | out_channels, 17 | num_heads, 18 | qkv_bias=False, 19 | attn_drop_rate=0., 20 | proj_drop_rate=0., 21 | qk_scale=None): 22 | super(MultiheadAttention, self).__init__() 23 | assert out_channels % num_heads == 0 24 | self.in_channels = in_channels 25 | self.out_channels = out_channels 26 | self.num_heads = num_heads 27 | 28 | self.norm_factor = qk_scale if qk_scale else (out_channels // num_heads) ** -0.5 29 | self.qkv_transform = nn.Linear(in_channels, out_channels * 3, bias=qkv_bias) 30 | self.projection = nn.Linear(out_channels, out_channels) 31 | self.attention_dropout = nn.Dropout(attn_drop_rate) 32 | self.projection_dropout = nn.Dropout(proj_drop_rate) 33 | 34 | def forward(self, x): 35 | N, L, _ = x.shape 36 | x = self.qkv_transform(x).view(N, L, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 37 | query, key, value = x[0], x[1], x[2] 38 | 39 | qk = 
query @ key.transpose(-1, -2) * self.norm_factor 40 | qk = F.softmax(qk, dim=-1) 41 | qk = self.attention_dropout(qk) 42 | 43 | out = qk @ value 44 | out = out.transpose(1, 2).contiguous().view(N, L, self.out_channels) 45 | out = self.projection(out) 46 | out = self.projection_dropout(out) 47 | 48 | if self.in_channels != self.out_channels: 49 | out = out + value.squeeze(1) 50 | 51 | return out 52 | 53 | 54 | class MLP(nn.Module): 55 | 56 | def __init__(self, 57 | in_channels, 58 | out_channels, 59 | drop_rate=0., 60 | hidden_ratio=1.): 61 | super(MLP, self).__init__() 62 | self.in_channels = in_channels 63 | self.out_channels = out_channels 64 | self.hidden_channels = int(in_channels * hidden_ratio) 65 | self.fc1 = nn.Linear(in_channels, self.hidden_channels) 66 | self.fc2 = nn.Linear(self.hidden_channels, out_channels) 67 | self.drop = nn.Dropout(drop_rate) 68 | 69 | def forward(self, x): 70 | x = F.gelu(self.fc1(x)) 71 | x = self.drop(x) 72 | x = self.fc2(x) 73 | x = self.drop(x) 74 | return x 75 | 76 | 77 | class TransformerLayer(nn.Module): 78 | 79 | def __init__(self, 80 | in_channels, 81 | num_heads, 82 | qkv_bias=False, 83 | out_channels=None, 84 | mlp_ratio=1., 85 | drop_rate=0., 86 | attn_drop_rate=0., 87 | drop_path_rate=0., 88 | qk_scale=None): 89 | super(TransformerLayer, self).__init__() 90 | if out_channels is None: 91 | out_channels = in_channels 92 | self.in_channels = in_channels 93 | self.out_channels = out_channels 94 | 95 | self.norm1 = layernorm(in_channels) 96 | self.attn = MultiheadAttention( 97 | in_channels=in_channels, 98 | out_channels=out_channels, 99 | num_heads=num_heads, 100 | qkv_bias=qkv_bias, 101 | attn_drop_rate=attn_drop_rate, 102 | proj_drop_rate=drop_rate, 103 | qk_scale=qk_scale) 104 | self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() 105 | self.norm2 = layernorm(out_channels) 106 | self.mlp = MLP( 107 | in_channels=out_channels, 108 | out_channels=out_channels, 109 | drop_rate=drop_rate, 110 | hidden_ratio=mlp_ratio) 111 | 112 | def forward(self, x): 113 | if self.in_channels == self.out_channels: 114 | x = x + self.drop_path(self.attn(self.norm1(x))) 115 | else: 116 | x = self.attn(self.norm1(x)) 117 | x = x + self.drop_path(self.mlp(self.norm2(x))) 118 | return x 119 | 120 | 121 | class PatchEmbedding(nn.Module): 122 | 123 | def __init__(self, 124 | img_size=224, 125 | patch_size=16, 126 | in_channels=3, 127 | out_channels=768): 128 | super(PatchEmbedding, self).__init__() 129 | self.img_size = img_size 130 | self.patch_size = patch_size 131 | self.num_patches = (img_size // patch_size) ** 2 132 | self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, stride=patch_size) 133 | 134 | def forward(self, x): 135 | _, _, H, W = x.shape 136 | assert H == self.img_size and W == self.img_size 137 | x = self.projection(x) 138 | x = x.flatten(2).transpose(1, 2) 139 | return x 140 | -------------------------------------------------------------------------------- /pycls/models/transformers/convit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from the official implementation of PiT. 3 | https://github.com/facebookresearch/convit/blob/main/convit.py 4 | """ 5 | 6 | # Copyright (c) 2015-present, Facebook, Inc. 7 | # All rights reserved. 8 | # 9 | # This source code is licensed under the CC-by-NC license found in the 10 | # LICENSE file in the root directory of this source tree. 
11 | # 12 | 13 | import torch 14 | import torch.nn as nn 15 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 16 | 17 | from ..build import MODEL 18 | from .common import layernorm 19 | from pycls.core.config import cfg 20 | from .base import BaseTransformerModel 21 | 22 | 23 | class Mlp(nn.Module): 24 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 25 | super().__init__() 26 | out_features = out_features or in_features 27 | hidden_features = hidden_features or in_features 28 | self.fc1 = nn.Linear(in_features, hidden_features) 29 | self.act = act_layer() 30 | self.fc2 = nn.Linear(hidden_features, out_features) 31 | self.drop = nn.Dropout(drop) 32 | self.apply(self._init_weights) 33 | 34 | def _init_weights(self, m): 35 | if isinstance(m, nn.Linear): 36 | trunc_normal_(m.weight, std=.02) 37 | if isinstance(m, nn.Linear) and m.bias is not None: 38 | nn.init.constant_(m.bias, 0) 39 | elif isinstance(m, nn.LayerNorm): 40 | nn.init.constant_(m.bias, 0) 41 | nn.init.constant_(m.weight, 1.0) 42 | 43 | def forward(self, x): 44 | x = self.fc1(x) 45 | x = self.act(x) 46 | x = self.drop(x) 47 | x = self.fc2(x) 48 | x = self.drop(x) 49 | return x 50 | 51 | 52 | class GPSA(nn.Module): 53 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 54 | locality_strength=1., use_local_init=True): 55 | super().__init__() 56 | self.num_heads = num_heads 57 | self.dim = dim 58 | head_dim = dim // num_heads 59 | self.scale = qk_scale or head_dim ** -0.5 60 | 61 | self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) 62 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 63 | 64 | self.attn_drop = nn.Dropout(attn_drop) 65 | self.proj = nn.Linear(dim, dim) 66 | self.pos_proj = nn.Linear(3, num_heads) 67 | self.proj_drop = nn.Dropout(proj_drop) 68 | self.locality_strength = locality_strength 69 | self.gating_param = nn.Parameter(torch.ones(self.num_heads)) 70 | self.apply(self._init_weights) 71 | if use_local_init: 72 | self.local_init(locality_strength=locality_strength) 73 | 74 | def _init_weights(self, m): 75 | if isinstance(m, nn.Linear): 76 | trunc_normal_(m.weight, std=.02) 77 | if isinstance(m, nn.Linear) and m.bias is not None: 78 | nn.init.constant_(m.bias, 0) 79 | elif isinstance(m, nn.LayerNorm): 80 | nn.init.constant_(m.bias, 0) 81 | nn.init.constant_(m.weight, 1.0) 82 | 83 | def forward(self, x): 84 | B, N, C = x.shape 85 | if not hasattr(self, 'rel_indices') or self.rel_indices.size(1)!=N: 86 | self.get_rel_indices(N) 87 | 88 | attn = self.get_attention(x) 89 | v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 90 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 91 | x = self.proj(x) 92 | x = self.proj_drop(x) 93 | return x 94 | 95 | def get_attention(self, x): 96 | B, N, C = x.shape 97 | qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 98 | q, k = qk[0], qk[1] 99 | pos_score = self.rel_indices.expand(B, -1, -1,-1) 100 | pos_score = self.pos_proj(pos_score).permute(0,3,1,2) 101 | patch_score = (q @ k.transpose(-2, -1)) * self.scale 102 | patch_score = patch_score.softmax(dim=-1) 103 | pos_score = pos_score.softmax(dim=-1) 104 | 105 | gating = self.gating_param.view(1,-1,1,1) 106 | attn = (1.-torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score 107 | attn /= attn.sum(dim=-1).unsqueeze(-1) 108 | attn = self.attn_drop(attn) 109 | return attn 110 | 111 | def get_attention_map(self, x, return_map = 
False): 112 | 113 | attn_map = self.get_attention(x).mean(0) # average over batch 114 | distances = self.rel_indices.squeeze()[:,:,-1]**.5 115 | dist = torch.einsum('nm,hnm->h', (distances, attn_map)) 116 | dist /= distances.size(0) 117 | if return_map: 118 | return dist, attn_map 119 | else: 120 | return dist 121 | 122 | def local_init(self, locality_strength=1.): 123 | 124 | self.v.weight.data.copy_(torch.eye(self.dim)) 125 | locality_distance = 1 #max(1,1/locality_strength**.5) 126 | 127 | kernel_size = int(self.num_heads**.5) 128 | center = (kernel_size-1)/2 if kernel_size%2==0 else kernel_size//2 129 | for h1 in range(kernel_size): 130 | for h2 in range(kernel_size): 131 | position = h1+kernel_size*h2 132 | self.pos_proj.weight.data[position,2] = -1 133 | self.pos_proj.weight.data[position,1] = 2*(h1-center)*locality_distance 134 | self.pos_proj.weight.data[position,0] = 2*(h2-center)*locality_distance 135 | self.pos_proj.weight.data *= locality_strength 136 | 137 | def get_rel_indices(self, num_patches): 138 | img_size = int(num_patches**.5) 139 | rel_indices = torch.zeros(1, num_patches, num_patches, 3) 140 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) 141 | indx = ind.repeat(img_size,img_size) 142 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) 143 | indd = indx**2 + indy**2 144 | rel_indices[:,:,:,2] = indd.unsqueeze(0) 145 | rel_indices[:,:,:,1] = indy.unsqueeze(0) 146 | rel_indices[:,:,:,0] = indx.unsqueeze(0) 147 | device = self.qk.weight.device 148 | self.rel_indices = rel_indices.to(device) 149 | 150 | 151 | class MHSA(nn.Module): 152 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 153 | super().__init__() 154 | self.num_heads = num_heads 155 | head_dim = dim // num_heads 156 | self.scale = qk_scale or head_dim ** -0.5 157 | 158 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 159 | self.attn_drop = nn.Dropout(attn_drop) 160 | self.proj = nn.Linear(dim, dim) 161 | self.proj_drop = nn.Dropout(proj_drop) 162 | self.apply(self._init_weights) 163 | 164 | def _init_weights(self, m): 165 | if isinstance(m, nn.Linear): 166 | trunc_normal_(m.weight, std=.02) 167 | if isinstance(m, nn.Linear) and m.bias is not None: 168 | nn.init.constant_(m.bias, 0) 169 | elif isinstance(m, nn.LayerNorm): 170 | nn.init.constant_(m.bias, 0) 171 | nn.init.constant_(m.weight, 1.0) 172 | 173 | def get_attention_map(self, x, return_map = False): 174 | B, N, C = x.shape 175 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 176 | q, k, v = qkv[0], qkv[1], qkv[2] 177 | attn_map = (q @ k.transpose(-2, -1)) * self.scale 178 | attn_map = attn_map.softmax(dim=-1).mean(0) 179 | 180 | img_size = int(N**.5) 181 | ind = torch.arange(img_size).view(1,-1) - torch.arange(img_size).view(-1, 1) 182 | indx = ind.repeat(img_size,img_size) 183 | indy = ind.repeat_interleave(img_size,dim=0).repeat_interleave(img_size,dim=1) 184 | indd = indx**2 + indy**2 185 | distances = indd**.5 186 | distances = distances.to('cuda') 187 | 188 | dist = torch.einsum('nm,hnm->h', (distances, attn_map)) 189 | dist /= N 190 | 191 | if return_map: 192 | return dist, attn_map 193 | else: 194 | return dist 195 | 196 | def forward(self, x): 197 | B, N, C = x.shape 198 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 199 | q, k, v = qkv[0], qkv[1], qkv[2] 200 | 201 | attn = (q @ k.transpose(-2, -1)) * self.scale 202 | attn = 
attn.softmax(dim=-1) 203 | attn = self.attn_drop(attn) 204 | 205 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 206 | x = self.proj(x) 207 | x = self.proj_drop(x) 208 | return x 209 | 210 | 211 | class Block(nn.Module): 212 | 213 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 214 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): 215 | super().__init__() 216 | self.norm1 = norm_layer(dim) 217 | self.use_gpsa = use_gpsa 218 | if self.use_gpsa: 219 | self.attn = GPSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) 220 | else: 221 | self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, **kwargs) 222 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 223 | self.norm2 = norm_layer(dim) 224 | mlp_hidden_dim = int(dim * mlp_ratio) 225 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 226 | 227 | def forward(self, x): 228 | x = x + self.drop_path(self.attn(self.norm1(x))) 229 | x = x + self.drop_path(self.mlp(self.norm2(x))) 230 | return x 231 | 232 | 233 | class PatchEmbed(nn.Module): 234 | 235 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 236 | super().__init__() 237 | img_size = to_2tuple(img_size) 238 | patch_size = to_2tuple(patch_size) 239 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 240 | self.img_size = img_size 241 | self.patch_size = patch_size 242 | self.num_patches = num_patches 243 | 244 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 245 | self.apply(self._init_weights) 246 | def forward(self, x): 247 | B, C, H, W = x.shape 248 | assert H == self.img_size[0] and W == self.img_size[1], \ 249 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
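Stepping back to GPSA's get_attention above: each head learns a scalar gate that blends content-based attention (patch_score) with a purely positional attention (pos_score, built from relative patch offsets via pos_proj). A shape-level sketch of that gating with illustrative tensors (not the repo module):

```python
import torch

B, H, N = 2, 4, 16  # batch, heads, tokens
patch_score = torch.softmax(torch.randn(B, H, N, N), dim=-1)  # content attention
pos_score = torch.softmax(torch.randn(B, H, N, N), dim=-1)    # positional attention
gating_param = torch.zeros(H)  # as in GPSA; sigmoid(0) = 0.5 mixes both equally

g = torch.sigmoid(gating_param).view(1, -1, 1, 1)
attn = (1.0 - g) * patch_score + g * pos_score
attn = attn / attn.sum(dim=-1, keepdim=True)  # renormalize rows, as GPSA does
print(attn.shape)  # (2, 4, 16, 16); each row sums to 1
```

With local_init, the positional branch starts out imitating a convolutional locality pattern, and training can move the gate toward content-based attention per head.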
250 | x = self.proj(x).flatten(2).transpose(1, 2) 251 | return x 252 | def _init_weights(self, m): 253 | if isinstance(m, nn.Linear): 254 | trunc_normal_(m.weight, std=.02) 255 | if isinstance(m, nn.Linear) and m.bias is not None: 256 | nn.init.constant_(m.bias, 0) 257 | elif isinstance(m, nn.LayerNorm): 258 | nn.init.constant_(m.bias, 0) 259 | nn.init.constant_(m.weight, 1.0) 260 | 261 | 262 | class HybridEmbed(nn.Module): 263 | 264 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 265 | super().__init__() 266 | assert isinstance(backbone, nn.Module) 267 | img_size = to_2tuple(img_size) 268 | self.img_size = img_size 269 | self.backbone = backbone 270 | if feature_size is None: 271 | with torch.no_grad(): 272 | training = backbone.training 273 | if training: 274 | backbone.eval() 275 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))[-1] 276 | feature_size = o.shape[-2:] 277 | feature_dim = o.shape[1] 278 | backbone.train(training) 279 | else: 280 | feature_size = to_2tuple(feature_size) 281 | feature_dim = self.backbone.feature_info.channels()[-1] 282 | self.num_patches = feature_size[0] * feature_size[1] 283 | self.proj = nn.Linear(feature_dim, embed_dim) 284 | self.apply(self._init_weights) 285 | 286 | def forward(self, x): 287 | x = self.backbone(x)[-1] 288 | x = x.flatten(2).transpose(1, 2) 289 | x = self.proj(x) 290 | return x 291 | 292 | 293 | @MODEL.register() 294 | class ConViT(BaseTransformerModel): 295 | 296 | def __init__(self, in_chans=cfg.MODEL.IN_CHANNELS, qkv_bias=False, qk_scale=None, hybrid_backbone=None, use_pos_embed=True): 297 | super(ConViT, self).__init__() 298 | self.local_up_to_layer = cfg.CONVIT.LOCAL_LAYERS 299 | self.locality_strength = cfg.CONVIT.LOCALITY_STRENGTH 300 | self.use_pos_embed = use_pos_embed 301 | 302 | if hybrid_backbone is not None: 303 | self.patch_embed = HybridEmbed( 304 | hybrid_backbone, img_size=self.img_size, in_chans=in_chans, embed_dim=self.hidden_dim) 305 | else: 306 | self.patch_embed = PatchEmbed( 307 | img_size=self.img_size, patch_size=self.patch_size, in_chans=in_chans, embed_dim=self.hidden_dim) 308 | self.num_patches = self.patch_embed.num_patches 309 | 310 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) 311 | self.pos_drop = nn.Dropout(p=self.drop_rate) 312 | 313 | if self.use_pos_embed: 314 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, self.hidden_dim)) 315 | trunc_normal_(self.pos_embed, std=.02) 316 | 317 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] # stochastic depth decay rule 318 | self.blocks = nn.ModuleList([ 319 | Block( 320 | dim=self.hidden_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 321 | drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=layernorm, 322 | use_gpsa=True, 323 | locality_strength=self.locality_strength) 324 | if i < self.local_up_to_layer else 325 | Block( 326 | dim=self.hidden_dim, num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 327 | drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[i], norm_layer=layernorm, 328 | use_gpsa=False) 329 | for i in range(self.depth)]) 330 | self.initialize_hooks(self.blocks) 331 | 332 | self.norm = layernorm(self.hidden_dim) 333 | self.feature_info = [dict(num_chs=self.hidden_dim, reduction=0, module='head')] 334 | self.head = nn.Linear(self.hidden_dim, self.num_classes) if self.num_classes > 0 else 
nn.Identity() 335 | 336 | trunc_normal_(self.cls_token, std=.02) 337 | self.head.apply(self._init_weights) 338 | 339 | def _feature_hook(self, module, inputs, outputs): 340 | feat_size = int(self.num_patches ** 0.5) 341 | if outputs.size(1) == self.num_patches: 342 | x = outputs.view(outputs.size(0), feat_size, feat_size, self.hidden_dim) 343 | else: 344 | x = outputs[:, 1:].view(outputs.size(0), feat_size, feat_size, self.hidden_dim) 345 | x = x.permute(0, 3, 1, 2).contiguous() 346 | self.features.append(x) 347 | 348 | def _init_weights(self, m): 349 | if isinstance(m, nn.Linear): 350 | trunc_normal_(m.weight, std=.02) 351 | if isinstance(m, nn.Linear) and m.bias is not None: 352 | nn.init.constant_(m.bias, 0) 353 | elif isinstance(m, nn.LayerNorm): 354 | nn.init.constant_(m.bias, 0) 355 | nn.init.constant_(m.weight, 1.0) 356 | 357 | def forward_features(self, x): 358 | B = x.shape[0] 359 | x = self.patch_embed(x) 360 | 361 | cls_tokens = self.cls_token.expand(B, -1, -1) 362 | 363 | if self.use_pos_embed: 364 | x = x + self.pos_embed 365 | x = self.pos_drop(x) 366 | 367 | for u, blk in enumerate(self.blocks): 368 | if u == self.local_up_to_layer : 369 | x = torch.cat((cls_tokens, x), dim=1) 370 | x = blk(x) 371 | 372 | x = self.norm(x) 373 | return x[:, 0] 374 | 375 | def forward(self, x): 376 | x = self.forward_features(x) 377 | x = self.head(x) 378 | return x 379 | -------------------------------------------------------------------------------- /pycls/models/transformers/deit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | 4 | """ 5 | Modified from the official implementation of DeiT. 6 | https://github.com/facebookresearch/deit/blob/main/models.py 7 | """ 8 | 9 | import torch 10 | import torch.nn as nn 11 | from timm.models.layers import trunc_normal_ 12 | 13 | from ..build import MODEL 14 | from pycls.core.config import cfg 15 | from .base import BaseTransformerModel 16 | from .common import PatchEmbedding, TransformerLayer, layernorm 17 | 18 | 19 | @MODEL.register() 20 | class DeiT(BaseTransformerModel): 21 | 22 | def __init__(self): 23 | super(DeiT, self).__init__() 24 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) 25 | 26 | self.patch_embed = PatchEmbedding(img_size=self.img_size, patch_size=self.patch_size, in_channels=self.in_channels, out_channels=self.hidden_dim) 27 | self.num_patches = self.patch_embed.num_patches 28 | self.num_tokens = 1 + cfg.DISTILLATION.ENABLE_LOGIT 29 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + self.num_tokens, self.hidden_dim)) 30 | self.pe_dropout = nn.Dropout(p=self.drop_rate) 31 | 32 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] # stochastic depth decay rule 33 | self.layers = nn.ModuleList([TransformerLayer( 34 | in_channels=self.hidden_dim, 35 | num_heads=self.num_heads, 36 | qkv_bias=True, 37 | mlp_ratio=self.mlp_ratio, 38 | drop_rate=self.drop_rate, 39 | attn_drop_rate=self.attn_drop_rate, 40 | drop_path_rate=dpr[i]) for i in range(self.depth)]) 41 | self.initialize_hooks(self.layers) 42 | 43 | self.norm = layernorm(self.hidden_dim) 44 | self.apply(self._init_weights) 45 | 46 | self.head = nn.Linear(self.hidden_dim, self.num_classes) 47 | nn.init.zeros_(self.head.weight) 48 | nn.init.constant_(self.head.bias, 0) 49 | 50 | trunc_normal_(self.cls_token, std=.02) 51 | trunc_normal_(self.pos_embed, std=.02) 52 | self.distill_logits = None 53 | 54 | 
self.distill_token = None 55 | self.distill_head = None 56 | if cfg.DISTILLATION.ENABLE_LOGIT: 57 | self.distill_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) 58 | self.distill_head = nn.Linear(self.hidden_dim, self.num_classes) 59 | nn.init.zeros_(self.distill_head.weight) 60 | nn.init.constant_(self.distill_head.bias, 0) 61 | trunc_normal_(self.distill_token, std=.02) 62 | 63 | def _feature_hook(self, module, inputs, outputs): 64 | feat_size = int(self.num_patches ** 0.5) 65 | x = outputs[:, self.num_tokens:].view(outputs.size(0), feat_size, feat_size, self.hidden_dim) 66 | x = x.permute(0, 3, 1, 2).contiguous() 67 | self.features.append(x) 68 | 69 | def _init_weights(self, module): 70 | if isinstance(module, nn.Linear): 71 | trunc_normal_(module.weight, std=.02) 72 | if module.bias is not None: 73 | nn.init.zeros_(module.bias) 74 | elif isinstance(module, nn.LayerNorm): 75 | nn.init.zeros_(module.bias) 76 | nn.init.ones_(module.weight) 77 | 78 | def forward(self, x): 79 | x = self.patch_embed(x) 80 | if self.num_tokens == 1: 81 | x = torch.cat([self.cls_token.repeat(x.size(0), 1, 1), x], dim=1) 82 | else: 83 | x = torch.cat([self.cls_token.repeat(x.size(0), 1, 1), self.distill_token.repeat(x.size(0), 1, 1), x], dim=1) 84 | x = self.pe_dropout(x + self.pos_embed) 85 | 86 | for layer in self.layers: 87 | x = layer(x) 88 | 89 | x = self.norm(x) 90 | logits = self.head(x[:, 0]) 91 | 92 | if self.num_tokens == 1: 93 | return logits 94 | 95 | self.distill_logits = None 96 | self.distill_logits = self.distill_head(x[:, 1]) 97 | 98 | if self.training: 99 | return logits 100 | else: 101 | return (logits + self.distill_logits) / 2 102 | -------------------------------------------------------------------------------- /pycls/models/transformers/pit.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021-present NAVER Corp. 2 | # Apache License v2.0 3 | 4 | """ 5 | Modified from the official implementation of PiT. 
6 | https://github.com/naver-ai/pit/blob/master/pit.py 7 | """ 8 | 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | from einops import rearrange 14 | from timm.models.layers import trunc_normal_ 15 | 16 | from ..build import MODEL 17 | from pycls.core.config import cfg 18 | from .base import BaseTransformerModel 19 | from .common import TransformerLayer, layernorm 20 | 21 | 22 | class Transformer(nn.Module): 23 | 24 | def __init__(self, embed_dim, depth, heads, mlp_ratio, 25 | drop_rate=.0, attn_drop_rate=.0, drop_path_prob=None): 26 | super(Transformer, self).__init__() 27 | self.layers = nn.ModuleList([]) 28 | 29 | if drop_path_prob is None: 30 | drop_path_prob = [0.0 for _ in range(depth)] 31 | 32 | self.blocks = nn.ModuleList([ 33 | TransformerLayer( 34 | in_channels=embed_dim, 35 | num_heads=heads, 36 | qkv_bias=True, 37 | mlp_ratio=mlp_ratio, 38 | drop_rate=drop_rate, 39 | attn_drop_rate=attn_drop_rate, 40 | drop_path_rate=drop_path_prob[i]) 41 | for i in range(depth)]) 42 | 43 | def forward(self, x, cls_tokens): 44 | h, w = x.shape[2:4] 45 | x = rearrange(x, 'b c h w -> b (h w) c') 46 | 47 | token_length = cls_tokens.shape[1] 48 | x = torch.cat((cls_tokens, x), dim=1) 49 | for blk in self.blocks: 50 | blk.shape_info = (token_length, h, w) 51 | x = blk(x) 52 | 53 | cls_tokens = x[:, :token_length] 54 | x = x[:, token_length:] 55 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 56 | 57 | return x, cls_tokens 58 | 59 | 60 | class conv_head_pooling(nn.Module): 61 | 62 | def __init__(self, in_feature, out_feature, stride, 63 | padding_mode='zeros'): 64 | super(conv_head_pooling, self).__init__() 65 | 66 | self.conv = nn.Conv2d(in_feature, out_feature, kernel_size=stride + 1, 67 | padding=stride // 2, stride=stride, 68 | padding_mode=padding_mode, groups=in_feature) 69 | self.fc = nn.Linear(in_feature, out_feature) 70 | 71 | def forward(self, x, cls_token): 72 | 73 | x = self.conv(x) 74 | cls_token = self.fc(cls_token) 75 | 76 | return x, cls_token 77 | 78 | 79 | class conv_embedding(nn.Module): 80 | 81 | def __init__(self, in_channels, out_channels, patch_size, 82 | stride, padding): 83 | super(conv_embedding, self).__init__() 84 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=patch_size, 85 | stride=stride, padding=padding, bias=True) 86 | 87 | def forward(self, x): 88 | x = self.conv(x) 89 | return x 90 | 91 | 92 | @MODEL.register() 93 | class PiT(BaseTransformerModel): 94 | 95 | def __init__(self): 96 | super(PiT, self).__init__() 97 | self.stride = cfg.PIT.STRIDE 98 | total_block = sum(self.depth) 99 | block_idx = 0 100 | embed_size = math.floor((self.img_size - self.patch_size) / self.stride + 1) 101 | 102 | self.pos_embed = nn.Parameter( 103 | torch.randn(1, self.hidden_dim[0], embed_size, embed_size), 104 | requires_grad=True 105 | ) 106 | self.patch_embed = conv_embedding(self.in_channels, self.hidden_dim[0], 107 | self.patch_size, self.stride, 0) 108 | 109 | self.cls_token = nn.Parameter(torch.randn(1, 1, self.hidden_dim[0])) 110 | self.pos_drop = nn.Dropout(p=self.drop_rate) 111 | 112 | self.transformers = nn.ModuleList([]) 113 | self.pools = nn.ModuleList([]) 114 | 115 | for stage in range(len(self.depth)): 116 | drop_path_prob = [self.drop_path_rate * i / total_block 117 | for i in range(block_idx, block_idx + self.depth[stage])] 118 | block_idx += self.depth[stage] 119 | 120 | self.transformers.append( 121 | Transformer(self.hidden_dim[stage], self.depth[stage], self.num_heads[stage], 122 | self.mlp_ratio, 123 | self.drop_rate, 
self.attn_drop_rate, drop_path_prob) 124 | ) 125 | if stage < len(self.depth) - 1: 126 | self.pools.append( 127 | conv_head_pooling(self.hidden_dim[stage], self.hidden_dim[stage + 1], stride=2)) 128 | 129 | layers = [[m for m in t.blocks] for t in self.transformers] 130 | layers = sum(layers, []) 131 | self.initialize_hooks(layers) 132 | 133 | self.norm = layernorm(self.hidden_dim[-1]) 134 | self.embed_dim = self.hidden_dim[-1] 135 | 136 | self.head = nn.Linear(self.hidden_dim[-1], self.num_classes) 137 | 138 | trunc_normal_(self.pos_embed, std=.02) 139 | trunc_normal_(self.cls_token, std=.02) 140 | self.apply(self._init_weights) 141 | 142 | def _feature_hook(self, module, inputs, outputs): 143 | token_length, h, w = module.shape_info 144 | x = outputs[:, token_length:] 145 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 146 | self.features.append(x) 147 | 148 | def _init_weights(self, m): 149 | if isinstance(m, nn.LayerNorm): 150 | nn.init.constant_(m.bias, 0) 151 | nn.init.constant_(m.weight, 1.0) 152 | 153 | def forward_features(self, x): 154 | x = self.patch_embed(x) 155 | 156 | pos_embed = self.pos_embed 157 | x = self.pos_drop(x + pos_embed) 158 | cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) 159 | 160 | for stage in range(len(self.pools)): 161 | x, cls_tokens = self.transformers[stage](x, cls_tokens) 162 | x, cls_tokens = self.pools[stage](x, cls_tokens) 163 | x, cls_tokens = self.transformers[-1](x, cls_tokens) 164 | 165 | cls_tokens = self.norm(cls_tokens) 166 | 167 | return cls_tokens 168 | 169 | def forward(self, x): 170 | cls_token = self.forward_features(x) 171 | cls_token = self.head(cls_token[:, 0]) 172 | return cls_token 173 | -------------------------------------------------------------------------------- /pycls/models/transformers/pvt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from the official implementation of PVT. 3 | https://github.com/whai362/PVT/blob/v2/classification/pvt.py 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | 11 | from ..build import MODEL 12 | from .common import layernorm 13 | from pycls.core.config import cfg 14 | from .base import BaseTransformerModel 15 | 16 | 17 | class Mlp(nn.Module): 18 | 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.act = act_layer() 25 | self.fc2 = nn.Linear(hidden_features, out_features) 26 | self.drop = nn.Dropout(drop) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | x = self.act(x) 31 | x = self.drop(x) 32 | x = self.fc2(x) 33 | x = self.drop(x) 34 | return x 35 | 36 | 37 | class Attention(nn.Module): 38 | 39 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1): 40 | super().__init__() 41 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
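The sr_ratio machinery this Attention class sets up is PVT's spatial-reduction attention: keys and values are computed on a grid downsampled by sr_ratio via a strided conv (see forward below), cutting the attention cost by roughly a factor of sr_ratio squared. A standalone sketch of the token-count effect, with illustrative sizes:

```python
import torch
import torch.nn as nn

B, C, H, W, sr_ratio = 2, 64, 28, 28, 4
x = torch.randn(B, H * W, C)                    # 784 query tokens

x_ = x.permute(0, 2, 1).reshape(B, C, H, W)     # tokens -> feature map
sr = nn.Conv2d(C, C, kernel_size=sr_ratio, stride=sr_ratio)
x_ = sr(x_).reshape(B, C, -1).permute(0, 2, 1)  # reduced tokens for K/V
print(x.shape, x_.shape)  # (2, 784, 64) vs (2, 49, 64): 16x fewer K/V tokens
```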
42 | 43 | self.dim = dim 44 | self.num_heads = num_heads 45 | head_dim = dim // num_heads 46 | self.scale = qk_scale or head_dim ** -0.5 47 | 48 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 49 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 50 | self.attn_drop = nn.Dropout(attn_drop) 51 | self.proj = nn.Linear(dim, dim) 52 | self.proj_drop = nn.Dropout(proj_drop) 53 | 54 | self.sr_ratio = sr_ratio 55 | if sr_ratio > 1: 56 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 57 | self.norm = layernorm(dim) 58 | 59 | def forward(self, x, H, W): 60 | B, N, C = x.shape 61 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 62 | 63 | if self.sr_ratio > 1: 64 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 65 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 66 | x_ = self.norm(x_) 67 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 68 | else: 69 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 70 | k, v = kv[0], kv[1] 71 | 72 | attn = (q @ k.transpose(-2, -1)) * self.scale 73 | attn = attn.softmax(dim=-1) 74 | attn = self.attn_drop(attn) 75 | 76 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 77 | x = self.proj(x) 78 | x = self.proj_drop(x) 79 | 80 | return x 81 | 82 | 83 | class Block(nn.Module): 84 | 85 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 86 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1): 87 | super().__init__() 88 | self.norm1 = norm_layer(dim) 89 | self.attn = Attention( 90 | dim, 91 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 92 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio) 93 | self.drop_path_rate = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 94 | self.norm2 = norm_layer(dim) 95 | mlp_hidden_dim = int(dim * mlp_ratio) 96 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 97 | 98 | def forward(self, x, H, W): 99 | x = x + self.drop_path_rate(self.attn(self.norm1(x), H, W)) 100 | x = x + self.drop_path_rate(self.mlp(self.norm2(x))) 101 | 102 | return x 103 | 104 | 105 | class PatchEmbed(nn.Module): 106 | 107 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 108 | super().__init__() 109 | img_size = to_2tuple(img_size) 110 | patch_size = to_2tuple(patch_size) 111 | 112 | self.img_size = img_size 113 | self.patch_size = patch_size 114 | self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] 115 | self.num_patches = self.H * self.W 116 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 117 | self.norm = layernorm(embed_dim) 118 | 119 | def forward(self, x): 120 | B, C, H, W = x.shape 121 | 122 | x = self.proj(x).flatten(2).transpose(1, 2) 123 | x = self.norm(x) 124 | H, W = H // self.patch_size[0], W // self.patch_size[1] 125 | 126 | return x, (H, W) 127 | 128 | 129 | @MODEL.register() 130 | class PVT(BaseTransformerModel): 131 | 132 | def __init__(self): 133 | super(PVT, self).__init__() 134 | self.sr_ratio = cfg.PVT.SR_RATIO 135 | self.num_stages = len(self.hidden_dim) 136 | 137 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depth))] # stochastic depth decay rule 138 | cur = 0 139 | 140 | for i in range(self.num_stages): 141 | patch_embed = PatchEmbed(img_size=self.img_size if i == 0 else self.img_size // (2 ** (i + 1)), 142 | patch_size=self.patch_size[i], 143 | in_chans=self.in_channels if i == 0 else self.hidden_dim[i - 1], 144 | embed_dim=self.hidden_dim[i]) 145 | num_patches = patch_embed.num_patches if i != self.num_stages - 1 else patch_embed.num_patches + 1 146 | pos_embed = nn.Parameter(torch.zeros(1, num_patches, self.hidden_dim[i])) 147 | pos_drop = nn.Dropout(p=self.drop_rate) 148 | 149 | block = nn.ModuleList([Block( 150 | dim=self.hidden_dim[i], num_heads=self.num_heads[i], mlp_ratio=self.mlp_ratio[i], qkv_bias=True, 151 | qk_scale=None, drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[cur + j], 152 | norm_layer=layernorm, sr_ratio=self.sr_ratio[i]) 153 | for j in range(self.depth[i])]) 154 | cur += self.depth[i] 155 | 156 | setattr(self, f"patch_embed{i + 1}", patch_embed) 157 | setattr(self, f"pos_embed{i + 1}", pos_embed) 158 | setattr(self, f"pos_drop{i + 1}", pos_drop) 159 | setattr(self, f"block{i + 1}", block) 160 | 161 | layers = [[m for m in getattr(self, f'block{i + 1}')] for i in range(self.num_stages)] 162 | layers = sum(layers, []) 163 | self.initialize_hooks(layers) 164 | 165 | self.norm = layernorm(self.hidden_dim[-1]) 166 | 167 | # cls_token 168 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim[-1])) 169 | 170 | # classification head 171 | self.head = nn.Linear(self.hidden_dim[-1], self.num_classes) 172 | 173 | # init weights 174 | for i in range(self.num_stages): 175 | pos_embed = getattr(self, f"pos_embed{i + 1}") 176 | trunc_normal_(pos_embed, std=.02) 177 | trunc_normal_(self.cls_token, std=.02) 178 | self.apply(self._init_weights) 179 | 180 | def _feature_hook(self, module, inputs, outputs): 181 | _, H, W = inputs 182 | if outputs.size(1) == H * W: 183 | x = outputs.view(outputs.size(0), H, W, outputs.size(-1)) 184 | else: 185 | x = outputs[:, 1:].view(outputs.size(0), H, W,
outputs.size(-1)) 186 | x = x.permute(0, 3, 1, 2).contiguous() 187 | self.features.append(x) 188 | 189 | def _init_weights(self, m): 190 | if isinstance(m, nn.Linear): 191 | trunc_normal_(m.weight, std=.02) 192 | if isinstance(m, nn.Linear) and m.bias is not None: 193 | nn.init.constant_(m.bias, 0) 194 | elif isinstance(m, nn.LayerNorm): 195 | nn.init.constant_(m.bias, 0) 196 | nn.init.constant_(m.weight, 1.0) 197 | 198 | def _get_pos_embed(self, pos_embed, patch_embed, H, W): 199 | if H * W == self.patch_embed1.num_patches: 200 | return pos_embed 201 | else: 202 | return F.interpolate( 203 | pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2), 204 | size=(H, W), mode="bilinear", align_corners=False).reshape(1, -1, H * W).permute(0, 2, 1) 205 | 206 | def forward_features(self, x): 207 | B = x.shape[0] 208 | 209 | for i in range(self.num_stages): 210 | patch_embed = getattr(self, f"patch_embed{i + 1}") 211 | pos_embed = getattr(self, f"pos_embed{i + 1}") 212 | pos_drop = getattr(self, f"pos_drop{i + 1}") 213 | block = getattr(self, f"block{i + 1}") 214 | x, (H, W) = patch_embed(x) 215 | 216 | if i == self.num_stages - 1: 217 | cls_tokens = self.cls_token.expand(B, -1, -1) 218 | x = torch.cat((cls_tokens, x), dim=1) 219 | pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W) 220 | pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1) 221 | else: 222 | pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W) 223 | 224 | x = pos_drop(x + pos_embed) 225 | for blk in block: 226 | x = blk(x, H, W) 227 | if i != self.num_stages - 1: 228 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 229 | 230 | x = self.norm(x) 231 | 232 | return x[:, 0] 233 | 234 | def forward(self, x): 235 | x = self.forward_features(x) 236 | x = self.head(x) 237 | return x 238 | -------------------------------------------------------------------------------- /pycls/models/transformers/pvt_v2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from the official implementation of PVTv2. 
3 | https://github.com/whai362/PVT/blob/v2/classification/pvt_v2.py 4 | """ 5 | 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 10 | 11 | from ..build import MODEL 12 | from .common import layernorm 13 | from pycls.core.config import cfg 14 | from .base import BaseTransformerModel 15 | 16 | 17 | class Mlp(nn.Module): 18 | 19 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 20 | super().__init__() 21 | out_features = out_features or in_features 22 | hidden_features = hidden_features or in_features 23 | self.fc1 = nn.Linear(in_features, hidden_features) 24 | self.dwconv = DWConv(hidden_features) 25 | self.act = act_layer() 26 | self.fc2 = nn.Linear(hidden_features, out_features) 27 | self.drop = nn.Dropout(drop) 28 | self.linear = linear 29 | if self.linear: 30 | self.relu = nn.ReLU(inplace=True) 31 | self.apply(self._init_weights) 32 | 33 | def _init_weights(self, m): 34 | if isinstance(m, nn.Linear): 35 | trunc_normal_(m.weight, std=.02) 36 | if isinstance(m, nn.Linear) and m.bias is not None: 37 | nn.init.constant_(m.bias, 0) 38 | elif isinstance(m, nn.LayerNorm): 39 | nn.init.constant_(m.bias, 0) 40 | nn.init.constant_(m.weight, 1.0) 41 | elif isinstance(m, nn.Conv2d): 42 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 43 | fan_out //= m.groups 44 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 45 | if m.bias is not None: 46 | m.bias.data.zero_() 47 | 48 | def forward(self, x, H, W): 49 | x = self.fc1(x) 50 | if self.linear: 51 | x = self.relu(x) 52 | x = self.dwconv(x, H, W) 53 | x = self.act(x) 54 | x = self.drop(x) 55 | x = self.fc2(x) 56 | x = self.drop(x) 57 | return x 58 | 59 | 60 | class Attention(nn.Module): 61 | 62 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False): 63 | super().__init__() 64 | assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." 
65 | 66 | self.dim = dim 67 | self.num_heads = num_heads 68 | head_dim = dim // num_heads 69 | self.scale = qk_scale or head_dim ** -0.5 70 | 71 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 72 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 73 | self.attn_drop = nn.Dropout(attn_drop) 74 | self.proj = nn.Linear(dim, dim) 75 | self.proj_drop = nn.Dropout(proj_drop) 76 | 77 | self.linear = linear 78 | self.sr_ratio = sr_ratio 79 | if not linear: 80 | if sr_ratio > 1: 81 | self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) 82 | self.norm = layernorm(dim) 83 | else: 84 | self.pool = nn.AdaptiveAvgPool2d(7) 85 | self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) 86 | self.norm = layernorm(dim) 87 | self.act = nn.GELU() 88 | self.apply(self._init_weights) 89 | 90 | def _init_weights(self, m): 91 | if isinstance(m, nn.Linear): 92 | trunc_normal_(m.weight, std=.02) 93 | if isinstance(m, nn.Linear) and m.bias is not None: 94 | nn.init.constant_(m.bias, 0) 95 | elif isinstance(m, nn.LayerNorm): 96 | nn.init.constant_(m.bias, 0) 97 | nn.init.constant_(m.weight, 1.0) 98 | elif isinstance(m, nn.Conv2d): 99 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 100 | fan_out //= m.groups 101 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 102 | if m.bias is not None: 103 | m.bias.data.zero_() 104 | 105 | def forward(self, x, H, W): 106 | B, N, C = x.shape 107 | q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 108 | 109 | if not self.linear: 110 | if self.sr_ratio > 1: 111 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 112 | x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1) 113 | x_ = self.norm(x_) 114 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 115 | else: 116 | kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 117 | else: 118 | x_ = x.permute(0, 2, 1).reshape(B, C, H, W) 119 | x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1) 120 | x_ = self.norm(x_) 121 | x_ = self.act(x_) 122 | kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 123 | k, v = kv[0], kv[1] 124 | 125 | attn = (q @ k.transpose(-2, -1)) * self.scale 126 | attn = attn.softmax(dim=-1) 127 | attn = self.attn_drop(attn) 128 | 129 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 130 | x = self.proj(x) 131 | x = self.proj_drop(x) 132 | 133 | return x 134 | 135 | 136 | class Block(nn.Module): 137 | 138 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 139 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, linear=False): 140 | super().__init__() 141 | self.norm1 = norm_layer(dim) 142 | self.attn = Attention( 143 | dim, 144 | num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, 145 | attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio, linear=linear) 146 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 147 | self.norm2 = norm_layer(dim) 148 | mlp_hidden_dim = int(dim * mlp_ratio) 149 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, linear=linear) 150 | 151 | self.apply(self._init_weights) 152 | 153 | def _init_weights(self, m): 154 | if isinstance(m, nn.Linear): 155 | trunc_normal_(m.weight, std=.02) 156 | if isinstance(m, nn.Linear) and m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | elif isinstance(m, nn.LayerNorm): 159 | nn.init.constant_(m.bias, 0) 160 | nn.init.constant_(m.weight, 1.0) 161 | elif isinstance(m, nn.Conv2d): 162 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 163 | fan_out //= m.groups 164 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 165 | if m.bias is not None: 166 | m.bias.data.zero_() 167 | 168 | def forward(self, x, H, W): 169 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 170 | x = x + self.drop_path(self.mlp(self.norm2(x), H, W)) 171 | 172 | return x 173 | 174 | 175 | class OverlapPatchEmbed(nn.Module): 176 | """ Image to Patch Embedding 177 | """ 178 | 179 | def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): 180 | super().__init__() 181 | 182 | img_size = to_2tuple(img_size) 183 | patch_size = to_2tuple(patch_size) 184 | 185 | assert max(patch_size) > stride, "Set larger patch_size than stride" 186 | 187 | self.img_size = img_size 188 | self.patch_size = patch_size 189 | self.H, self.W = img_size[0] // stride, img_size[1] // stride 190 | self.num_patches = self.H * self.W 191 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 192 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 193 | self.norm = layernorm(embed_dim) 194 | 195 | self.apply(self._init_weights) 196 | 197 | def _init_weights(self, m): 198 | if isinstance(m, nn.Linear): 199 | trunc_normal_(m.weight, std=.02) 200 | if isinstance(m, nn.Linear) and m.bias is not None: 201 | nn.init.constant_(m.bias, 0) 202 | elif isinstance(m, nn.LayerNorm): 203 | nn.init.constant_(m.bias, 0) 204 | nn.init.constant_(m.weight, 1.0) 205 | elif isinstance(m, nn.Conv2d): 206 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 207 | fan_out //= m.groups 208 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 209 | if m.bias is not None: 210 | m.bias.data.zero_() 211 | 212 | def forward(self, x): 213 | x = self.proj(x) 214 | _, _, H, W = x.shape 215 | x = x.flatten(2).transpose(1, 2) 216 | x = self.norm(x) 217 | 218 | return x, H, W 219 | 220 | 221 | @MODEL.register() 222 | class PVTv2(BaseTransformerModel): 223 | 224 | def __init__(self): 225 | super(PVTv2, self).__init__() 226 | self.sr_ratio = cfg.PVT.SR_RATIO 227 | self.num_stages = len(self.hidden_dim) 228 | 229 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depth))] # stochastic depth decay rule 230 | cur = 0 231 | 232 | for i in range(self.num_stages): 233 | patch_embed = OverlapPatchEmbed(img_size=self.img_size if i == 0 else self.img_size // (2 ** (i + 1)), 234 | patch_size=self.patch_size[i], 235 | stride=self.patch_stride[i], 236 | in_chans=self.in_channels if i == 0 else self.hidden_dim[i - 1], 237 | embed_dim=self.hidden_dim[i]) 238 | 239 | block = nn.ModuleList([Block( 240 | dim=self.hidden_dim[i], num_heads=self.num_heads[i], mlp_ratio=self.mlp_ratio[i], qkv_bias=True, qk_scale=None, 241 | drop=self.drop_rate, attn_drop=self.attn_drop_rate, drop_path=dpr[cur + j], norm_layer=layernorm, 242 | sr_ratio=self.sr_ratio[i], linear=False) 243 | 
for j in range(self.depth[i])]) 244 | norm = layernorm(self.hidden_dim[i]) 245 | cur += self.depth[i] 246 | 247 | setattr(self, f"patch_embed{i + 1}", patch_embed) 248 | setattr(self, f"block{i + 1}", block) 249 | setattr(self, f"norm{i + 1}", norm) 250 | 251 | layers = [[m for m in getattr(self, f'block{i + 1}')] for i in range(self.num_stages)] 252 | layers = sum(layers, []) 253 | self.initialize_hooks(layers) 254 | 255 | # classification head 256 | self.head = nn.Linear(self.hidden_dim[-1], self.num_classes) 257 | self.apply(self._init_weights) 258 | 259 | def _feature_hook(self, module, inp, out): 260 | _, H, W = inp 261 | feat = out.view(out.size(0), H, W, out.size(-1)) 262 | feat = feat.permute(0, 3, 1, 2).contiguous() 263 | self.features.append(feat) 264 | 265 | def _init_weights(self, m): 266 | if isinstance(m, nn.Linear): 267 | trunc_normal_(m.weight, std=.02) 268 | if isinstance(m, nn.Linear) and m.bias is not None: 269 | nn.init.constant_(m.bias, 0) 270 | elif isinstance(m, nn.LayerNorm): 271 | nn.init.constant_(m.bias, 0) 272 | nn.init.constant_(m.weight, 1.0) 273 | elif isinstance(m, nn.Conv2d): 274 | fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 275 | fan_out //= m.groups 276 | m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) 277 | if m.bias is not None: 278 | m.bias.data.zero_() 279 | 280 | def forward_features(self, x): 281 | B = x.shape[0] 282 | 283 | for i in range(self.num_stages): 284 | patch_embed = getattr(self, f"patch_embed{i + 1}") 285 | block = getattr(self, f"block{i + 1}") 286 | norm = getattr(self, f"norm{i + 1}") 287 | x, H, W = patch_embed(x) 288 | for blk in block: 289 | x = blk(x, H, W) 290 | x = norm(x) 291 | if i != self.num_stages - 1: 292 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 293 | 294 | return x.mean(dim=1) 295 | 296 | def forward(self, x): 297 | x = self.forward_features(x) 298 | x = self.head(x) 299 | 300 | return x 301 | 302 | 303 | class DWConv(nn.Module): 304 | def __init__(self, dim=768): 305 | super(DWConv, self).__init__() 306 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 307 | 308 | def forward(self, x, H, W): 309 | B, N, C = x.shape 310 | x = x.transpose(1, 2).view(B, C, H, W) 311 | x = self.dwconv(x) 312 | x = x.flatten(2).transpose(1, 2) 313 | 314 | return x 315 | -------------------------------------------------------------------------------- /pycls/models/transformers/t2t_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | 7 | """ 8 | Modified from the official implementation of T2T-ViT. 
9 | https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py 10 | """ 11 | 12 | import math 13 | import torch 14 | import numpy as np 15 | import torch.nn as nn 16 | 17 | from ..build import MODEL 18 | from pycls.core.config import cfg 19 | from .base import BaseTransformerModel 20 | from .common import MLP, TransformerLayer, layernorm 21 | 22 | 23 | class PerformerAttention(nn.Module): 24 | 25 | def __init__(self, 26 | in_channels, 27 | out_channels, 28 | num_heads, 29 | drop_rate=0.1, 30 | kernel_ratio=0.5): 31 | super(PerformerAttention, self).__init__() 32 | assert out_channels % num_heads == 0 33 | self.in_channels = in_channels 34 | self.out_channels = out_channels 35 | self.num_heads = num_heads 36 | self.head_channels = out_channels // num_heads 37 | 38 | self.qkv_transform = nn.Linear(in_channels, out_channels * 3) 39 | self.projection = nn.Linear(out_channels, out_channels) 40 | self.dropout = nn.Dropout(drop_rate) 41 | 42 | self.m = int(out_channels * kernel_ratio) 43 | self.w = torch.randn(self.m, out_channels) 44 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) 45 | 46 | def prm_exp(self, x): 47 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 48 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w) 49 | return torch.exp(wtx - xd) / math.sqrt(self.m) 50 | 51 | def forward(self, x): 52 | k, q, v = torch.split(self.qkv_transform(x), self.out_channels, dim=-1)  # one out_channels-wide chunk each; splitting by head_channels yields 3 * num_heads chunks and fails to unpack when num_heads > 1 53 | kp, qp = self.prm_exp(k), self.prm_exp(q) 54 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) 55 | kptv = torch.einsum('bin,bim->bnm', v.float(), kp) 56 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.out_channels) + 1e-8) 57 | y = self.dropout(self.projection(y)) 58 | if self.in_channels != self.out_channels: 59 | y = v + y 60 | return y 61 | 62 | 63 | class PerformerLayer(nn.Module): 64 | 65 | def __init__(self, 66 | in_channels, 67 | num_heads, 68 | out_channels=None, 69 | hidden_ratio=1., 70 | drop_rate=0., 71 | kernel_ratio=0.5): 72 | super(PerformerLayer, self).__init__() 73 | if out_channels is None: 74 | out_channels = in_channels 75 | self.in_channels = in_channels 76 | self.out_channels = out_channels 77 | 78 | self.norm1 = layernorm(in_channels) 79 | self.attn = PerformerAttention( 80 | in_channels=in_channels, 81 | out_channels=out_channels, 82 | num_heads=num_heads, 83 | drop_rate=drop_rate, 84 | kernel_ratio=kernel_ratio) 85 | self.norm2 = layernorm(out_channels) 86 | self.mlp = MLP( 87 | in_channels=out_channels, 88 | out_channels=out_channels, 89 | drop_rate=drop_rate, 90 | hidden_ratio=hidden_ratio) 91 | 92 | def forward(self, x): 93 | if self.in_channels == self.out_channels: 94 | x = x + self.attn(self.norm1(x)) 95 | else: 96 | x = self.attn(self.norm1(x)) 97 | x = x + self.mlp(self.norm2(x)) 98 | return x 99 | 100 | 101 | class Token2TokenModule(nn.Module): 102 | 103 | def __init__(self, 104 | in_channels, 105 | out_channels, 106 | img_size): 107 | super(Token2TokenModule, self).__init__() 108 | self.in_channels = in_channels 109 | self.out_channels = out_channels 110 | self.img_size = (img_size, img_size) 111 | self.token_channels = cfg.T2T.TOKEN_DIM 112 | self.kernel_size = cfg.T2T.KERNEL_SIZE 113 | self.stride = cfg.T2T.STRIDE 114 | self.padding = cfg.T2T.PADDING 115 | assert len(self.kernel_size) == len(self.stride) 116 | 117 | self.soft_split0 = nn.Unfold( 118 | kernel_size=self.kernel_size[0], 119 | stride=self.stride[0], 120 | padding=self.padding[0]) 121 | 122 |
self.soft_split = nn.ModuleList() 123 | self.attention = nn.ModuleList() 124 | cur_channels = in_channels * self.kernel_size[0] ** 2 125 | for i in range(1, len(self.kernel_size)): 126 | soft_split, attention = self._make_layer( 127 | in_channels=cur_channels, 128 | out_channels=self.token_channels, 129 | kernel_size=self.kernel_size[i], 130 | stride=self.stride[i], 131 | padding=self.padding[i]) 132 | self.soft_split.append(soft_split) 133 | self.attention.append(attention) 134 | cur_channels = self.token_channels * self.kernel_size[i] ** 2 135 | self.projection = nn.Linear(cur_channels, out_channels) 136 | 137 | def _make_layer(self, in_channels, out_channels, kernel_size, stride, padding): 138 | soft_split = nn.Unfold( 139 | kernel_size=kernel_size, 140 | stride=stride, 141 | padding=padding) 142 | attention = PerformerLayer( 143 | in_channels=in_channels, 144 | out_channels=out_channels, 145 | num_heads=1, 146 | hidden_ratio=1, 147 | kernel_ratio=0.5) 148 | return soft_split, attention 149 | 150 | def forward(self, x): 151 | H, W = x.shape[-2:] 152 | ratio = H / W 153 | x = self.soft_split0(x).transpose(-1, -2) 154 | for attention, soft_split in zip(self.attention, self.soft_split): 155 | x = attention(x).transpose(-1, -2) 156 | N, C, L = x.shape 157 | W = int(L ** 0.5 / ratio) 158 | H = L // W 159 | x = x.view(N, C, H, W) 160 | x = soft_split(x).transpose(-1, -2) 161 | x = self.projection(x) 162 | return x 163 | 164 | 165 | @MODEL.register() 166 | class T2TViT(BaseTransformerModel): 167 | 168 | def __init__(self): 169 | super(T2TViT, self).__init__() 170 | self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) 171 | 172 | self.t2t_module = Token2TokenModule( 173 | in_channels=self.in_channels, 174 | out_channels=self.hidden_dim, 175 | img_size=self.img_size) 176 | 177 | feat_size = self.img_size 178 | for stride in cfg.T2T.STRIDE: 179 | feat_size = feat_size // stride 180 | self.num_patches = feat_size ** 2 181 | pe = self._get_position_embedding(self.num_patches + 1, self.hidden_dim) 182 | self.pos_embed = nn.Parameter(pe, requires_grad=False) 183 | self.pe_dropout = nn.Dropout(self.drop_rate) 184 | 185 | self.layers = nn.ModuleList() 186 | dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, self.depth)] 187 | self.layers.extend([TransformerLayer( 188 | in_channels=self.hidden_dim, 189 | num_heads=self.num_heads, 190 | mlp_ratio=self.mlp_ratio, 191 | drop_rate=self.drop_rate, 192 | attn_drop_rate=self.attn_drop_rate, 193 | drop_path_rate=dpr[i]) for i in range(self.depth)]) 194 | self.initialize_hooks(self.layers) 195 | 196 | self.norm = layernorm(self.hidden_dim) 197 | self.head = nn.Linear(self.hidden_dim, self.num_classes) 198 | 199 | nn.init.normal_(self.cls_token, std=.02) 200 | self.apply(self._init_weights) 201 | 202 | def _feature_hook(self, module, inputs, outputs): 203 | feat_size = int(self.num_patches ** 0.5) 204 | x = outputs[:, 1:].view(outputs.size(0), feat_size, feat_size, self.hidden_dim) 205 | x = x.permute(0, 3, 1, 2).contiguous() 206 | self.features.append(x) 207 | 208 | def _init_weights(self, m): 209 | if isinstance(m, nn.Linear): 210 | nn.init.normal_(m.weight, std=.02) 211 | if m.bias is not None: 212 | nn.init.constant_(m.bias, 0) 213 | elif isinstance(m, nn.LayerNorm): 214 | nn.init.constant_(m.bias, 0) 215 | nn.init.constant_(m.weight, 1.0) 216 | 217 | def _get_position_embedding(self, n_position, d_hid): 218 | 219 | def get_position_angle_vec(position): 220 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in 
range(d_hid)] 221 | 222 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 223 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 224 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 225 | 226 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 227 | 228 | def forward(self, x): 229 | x = self.t2t_module(x) 230 | x = torch.cat([self.cls_token.repeat(x.size(0), 1, 1), x], dim=1) 231 | x = self.pe_dropout(x + self.pos_embed) 232 | 233 | for layer in self.layers: 234 | x = layer(x) 235 | 236 | x = self.norm(x) 237 | x = self.head(x[:, 0]) 238 | 239 | return x 240 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black==19.3b0 2 | isort==4.3.21 3 | iopath 4 | flake8 5 | pyyaml 6 | matplotlib 7 | numpy 8 | opencv-python==4.2.0.34 9 | parameterized 10 | setuptools 11 | simplejson 12 | submitit 13 | yacs 14 | yattag 15 | einops 16 | scipy 17 | tensorboard 18 | timm -------------------------------------------------------------------------------- /run_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | # Copyright (c) Facebook, Inc. and its affiliates. 4 | # 5 | # This source code is licensed under the MIT license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Modified from pycls. 10 | https://github.com/facebookresearch/pycls/blob/main/tools/run_net.py 11 | """ 12 | 13 | import argparse 14 | import sys 15 | import os 16 | 17 | import pycls.core.config as config 18 | import pycls.core.distributed as dist 19 | import pycls.core.trainer as trainer 20 | from pycls.core.config import cfg 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser(description="Run a model.") 25 | help_s, choices = "Run mode", ["train", "test", "time"] 26 | parser.add_argument("--mode", help=help_s, choices=choices, required=True, type=str) 27 | help_s = "Config file location" 28 | parser.add_argument("--cfg", help=help_s, required=True, type=str) 29 | help_s = "See pycls/core/config.py for all options" 30 | parser.add_argument("opts", help=help_s, default=None, nargs=argparse.REMAINDER) 31 | if len(sys.argv) == 1: 32 | parser.print_help() 33 | sys.exit(1) 34 | return parser.parse_args() 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | mode = args.mode 40 | config.load_cfg(args.cfg) 41 | cfg.merge_from_list(args.opts) 42 | if cfg.OUT_DIR is None: 43 | out_dir = os.path.join('work_dirs', os.path.splitext(os.path.basename(args.cfg))[0]) 44 | cfg.OUT_DIR = out_dir 45 | config.assert_cfg() 46 | cfg.freeze() 47 | if not os.path.exists(cfg.OUT_DIR): 48 | os.makedirs(cfg.OUT_DIR) 49 | if mode == "train": 50 | dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.train_model) 51 | elif mode == "test": 52 | dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.test_model) 53 | elif mode == "time": 54 | dist.multi_proc_run(num_proc=cfg.NUM_GPUS, fun=trainer.time_model) 55 | 56 | 57 | if __name__ == "__main__": 58 | main() 59 | --------------------------------------------------------------------------------
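
A note on usage, inferred only from run_net.py above: the script drives everything from a YAML config, so a typical launch looks like python run_net.py --mode train --cfg path/to/config.yaml NUM_GPUS 1, where the config path is a placeholder. Any trailing KEY VALUE pairs are merged into cfg via cfg.merge_from_list, and when the config leaves OUT_DIR unset it defaults to work_dirs/<config basename>. Substituting --mode test or --mode time dispatches to trainer.test_model or trainer.time_model through the same dist.multi_proc_run wrapper.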
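
To make the spatial-reduction attention in pvt.py concrete, here is a minimal smoke-test sketch. It assumes the repository root is on PYTHONPATH and that layernorm from .common behaves like nn.LayerNorm; the tensor sizes are illustrative rather than taken from any config.

import torch

from pycls.models.transformers.pvt import Attention

B, H, W, C = 2, 8, 8, 64
x = torch.randn(B, H * W, C)  # token sequence of an 8x8 feature map

# With sr_ratio=4, keys and values are computed from the feature map after a
# stride-4 convolution, i.e. from 2x2 = 4 tokens instead of all 64, while the
# 64 queries stay at full resolution.
attn = Attention(dim=C, num_heads=2, qkv_bias=True, sr_ratio=4).eval()
with torch.no_grad():
    out = attn(x, H, W)
print(out.shape)  # torch.Size([2, 64, 64]) -- same shape as the input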
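
Along the same lines, a small sketch of the Performer attention that the Token2Token module in t2t_vit.py relies on to keep its long unfolded token sequences affordable; the same PYTHONPATH assumption applies, and the shapes are again illustrative.

import torch

from pycls.models.transformers.t2t_vit import PerformerAttention

B, T, C = 2, 196, 64
x = torch.randn(B, T, C)

# kernel_ratio=0.5 draws m = C // 2 random features for the softmax kernel
# approximation, so attention costs O(T * m) rather than the O(T^2) of exact
# softmax attention.
perf = PerformerAttention(in_channels=C, out_channels=C, num_heads=1,
                          kernel_ratio=0.5).eval()
with torch.no_grad():
    y = perf(x)
print(y.shape)  # torch.Size([2, 196, 64])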