├── .gitmodules ├── LICENSE ├── README.md ├── assests └── architecture.png └── detection ├── README.md ├── configs ├── _base_ │ ├── datasets │ │ ├── coco_detection.py │ │ ├── coco_instance.py │ │ └── coco_instance_ms.py │ ├── default_runtime.py │ ├── models │ │ └── mask_rcnn_r50_fpn.py │ └── schedules │ │ ├── schedule_1x.py │ │ ├── schedule_20e.py │ │ └── schedule_2x.py ├── mask_rcnn_lightvit_base_fpn_1x_coco.py ├── mask_rcnn_lightvit_base_fpn_3x_ms_coco.py ├── mask_rcnn_lightvit_small_fpn_1x_coco.py ├── mask_rcnn_lightvit_small_fpn_3x_ms_coco.py ├── mask_rcnn_lightvit_tiny_fpn_1x_coco.py └── mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.py ├── dist_test.sh ├── dist_train.sh ├── get_flops.py ├── lightvit.py ├── lightvit_fpn.py ├── slurm_test.sh ├── slurm_train.sh ├── test.py └── train.py /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "classification"] 2 | path = classification 3 | url = https://github.com/hunto/image_classification_sota.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 LightViT contributors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LightViT 2 | Official implementation for paper "[LightViT: Towards Light-Weight Convolution-Free Vision Transformers](https://arxiv.org/abs/2207.05557)". 3 | 4 | By Tao Huang, Lang Huang, Shan You, Fei Wang, Chen Qian, Chang Xu. 5 | 6 | 7 | ## Updates 8 | ### July 26, 2022 9 | Code for COCO detection was released. 10 | 11 | ### July 14, 2022 12 | Code for ImageNet training was released. 
13 | 14 | 15 | ## Introduction 16 |

17 | ![LightViT architecture](assests/architecture.png) 18 |

19 | 20 | ## Results on ImageNet-1K 21 | 22 | |model|resolution|acc@1|acc@5|#params|FLOPs|ckpt|log| 23 | |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:| 24 | |LightViT-T|224x224|78.7|94.4|9.4M|0.7G|[google drive](https://drive.google.com/file/d/1NasAnHYK6bkmj0-rzYvxQSEUiKAOW3j9/view?usp=sharing)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/log_lightvit_tiny.csv)| 25 | |LightViT-S|224x224|80.9|95.3|19.2M|1.7G|[google drive](https://drive.google.com/file/d/16EZUth-wZ7rKR6tI67_Smp6GGzaODp16/view?usp=sharing)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/log_lightvit_small.csv)| 26 | |LightViT-B|224x224|82.1|95.9|35.2M|3.9G|[google drive](https://drive.google.com/file/d/1MpVvfo8AiJMmz8CJZ7lBnDyoBJUEDxeL/view?usp=sharing)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/log_lightvit_base.csv)| 27 | 28 | ### Preparation 29 | 1. Clone training code 30 | ```shell 31 | git clone https://github.com/hunto/LightViT.git --recurse-submodules 32 | cd LightViT/classification 33 | ``` 34 | 35 | **The code of LightViT model can be found in [lib/models/lightvit.py](https://github.com/hunto/image_classification_sota/blob/main/lib/models/lightvit.py)** . 36 | 37 | 2. Requirements 38 | ```shell 39 | torch>=1.3.0 40 | # if you want to use torch.cuda.amp for mixed-precision training, the lowest torch version is 1.5.0 41 | timm==0.5.4 42 | ``` 43 | 3. Prepare your datasets following [this link](https://github.com/hunto/image_classification_sota#prepare-datasets). 44 | 45 | ### Evaluation 46 | You can evaluate our results using the provided checkpoints. First download the checkpoints into your machine, then run 47 | ```shell 48 | sh tools/dist_run.sh tools/test.py ${NUM_GPUS} configs/strategies/lightvit/config.yaml timm_lightvit_tiny --drop-path-rate 0.1 --experiment lightvit_tiny_test --resume ${ckpt_file_path} 49 | ``` 50 | 51 | ### Train from scratch on ImageNet-1K 52 | ```shell 53 | sh tools/dist_train.sh 8 configs/strategies/lightvit/config.yaml ${MODEL} --drop-path-rate 0.1 --experiment lightvit_tiny 54 | ``` 55 | ${MODEL} can be `timm_lightvit_tiny`, `timm_lightvit_small`, `timm_lightvit_base` . 56 | 57 | For `timm_lightvit_base`, we added `--amp` option to use mixed-precision training, and **set `drop_path_rate` to 0.3**. 58 | 59 | ### Throughput 60 | ```shell 61 | sh tools/dist_run.sh tools/speed_test.py 1 configs/strategies/lightvit/config.yaml ${MODEL} --drop-path-rate 0.1 --batch-size 1024 62 | ``` 63 | or 64 | ```shell 65 | python tools/speed_test.py -c configs/strategies/lightvit/config.yaml --model ${MODEL} --drop-path-rate 0.1 --batch-size 1024 66 | ``` 67 | 68 | ## Results on COCO 69 | We conducted experiments on COCO object detection & instance segmentation tasks, see [detection/README.md](detection/README.md) for details. 70 | 71 | ## License 72 | This project is released under the [Apache 2.0 license](LICENSE). 
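As a stand-alone supplement to the Evaluation and Throughput sections above, the sketch below times raw forward passes with plain PyTorch. `measure_throughput` is an illustrative helper rather than part of this repo, and `model` can be any LightViT variant built from `lib/models/lightvit.py` in the classification submodule; the provided `tools/speed_test.py` remains the recommended way to benchmark.

```python
import time
import torch

@torch.no_grad()
def measure_throughput(model, batch_size=64, img_size=224, warmup=10, iters=50):
    """Rough images/second for a classification model on a single device."""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device).eval()
    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    for _ in range(warmup):        # warm-up: cuDNN autotuning, allocator, etc.
        model(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        model(x)
    if device == 'cuda':
        torch.cuda.synchronize()
    return batch_size * iters / (time.time() - start)
```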
73 | 74 | ## Citation 75 | ``` 76 | @article{huang2022lightvit, 77 | title = {LightViT: Towards Light-Weight Convolution-Free Vision Transformers}, 78 | author = {Huang, Tao and Huang, Lang and You, Shan and Wang, Fei and Qian, Chen and Xu, Chang}, 79 | journal = {arXiv preprint arXiv:2207.05557}, 80 | year = {2022} 81 | } 82 | ``` -------------------------------------------------------------------------------- /assests/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hunto/LightViT/e2452a17dcc06d426eb8979c4e6b3a5678091368/assests/architecture.png -------------------------------------------------------------------------------- /detection/README.md: -------------------------------------------------------------------------------- 1 | # Applying LightViT to Object Detection 2 | 3 | ## Usage 4 | ### Preparations 5 | * Install mmdetection `v2.14.0` and mmcv-full 6 | ```shell 7 | pip install mmdet==2.14.0 8 | ``` 9 | For mmcv-full, please check [here](https://mmdetection.readthedocs.io/en/latest/get_started.html#install-mmdetection) and [compatibility](https://mmdetection.readthedocs.io/en/latest/compatibility.html). 10 | 11 | * Put the COCO dataset into the `./data` folder following [[this url]](https://mmdetection.readthedocs.io/en/latest/1_exist_data_model.html#prepare-datasets). 12 | 13 | **Note:** If you use a different mmdet version, please replace `configs`, `train.py`, and `test.py` with the corresponding files of that version, then add `import lightvit` and `import lightvit_fpn` to `train.py` to register the LightViT backbone and FPN neck. 14 | 15 | ## Train Mask R-CNN with LightViT backbones 16 | |backbone|params (M)|FLOPs (G)|setting|box AP|mask AP|config|log|ckpt| 17 | |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:| 18 | |LightViT-T|28|187|1x|37.8|35.9|[config](configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_tiny_fpn_1x_coco.log.json)|[google drive](https://drive.google.com/file/d/15sSit0VpcsmL3Pr-fdjM6nRdmSZre0zh/view?usp=sharing)| 19 | |LightViT-S|38|204|1x|40.0|37.4|[config](configs/mask_rcnn_lightvit_small_fpn_1x_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_small_fpn_1x_coco.log.json)|[google drive](https://drive.google.com/file/d/146RWYNShe5IrltvmgWoelj27LaRQBcXE/view?usp=sharing)| 20 | |LightViT-B|54|240|1x|41.7|38.8|[config](configs/mask_rcnn_lightvit_base_fpn_1x_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_base_fpn_1x_coco.log.json)|[google drive](https://drive.google.com/file/d/19IHkVO_Qy__NHe0syrUpIKrhGNoS9VCd/view?usp=sharing)| 21 | |LightViT-T|28|187|3x+ms|41.5|38.4|[config](configs/mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.log.json)|[google drive](https://drive.google.com/file/d/1EfhtuCVCoprbwb_6W-32z_po4hGvAzu8/view?usp=sharing)| 22 | |LightViT-S|38|204|3x+ms|43.2|39.9|[config](configs/mask_rcnn_lightvit_small_fpn_3x_ms_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_small_fpn_3x_ms_coco.log.json)|[google drive](https://drive.google.com/file/d/16l5lXYwz3Anx28ncqPfc3mir-nyZyRd6/view?usp=sharing)| 23 |
|LightViT-B|54|240|3x+ms|45.0|41.2|[config](configs/mask_rcnn_lightvit_base_fpn_3x_ms_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_base_fpn_3x_ms_coco.log.json)|[google drive](https://drive.google.com/file/d/1hLKdemruEKW2DO0227ZNm1zSZCqMWVfi/view?usp=sharing)| 24 | 25 | 26 | ### Training script: 27 | 28 | ```shell 29 | sh dist_train.sh ${CONFIG} 8 work_dirs/${EXP_NAME} 30 | ``` -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_detection.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 9 | dict(type='RandomFlip', flip_ratio=0.5), 10 | dict(type='Normalize', **img_norm_cfg), 11 | dict(type='Pad', size_divisor=32), 12 | dict(type='DefaultFormatBundle'), 13 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 14 | ] 15 | test_pipeline = [ 16 | dict(type='LoadImageFromFile'), 17 | dict( 18 | type='MultiScaleFlipAug', 19 | img_scale=(1333, 800), 20 | flip=False, 21 | transforms=[ 22 | dict(type='Resize', keep_ratio=True), 23 | dict(type='RandomFlip'), 24 | dict(type='Normalize', **img_norm_cfg), 25 | dict(type='Pad', size_divisor=32), 26 | dict(type='ImageToTensor', keys=['img']), 27 | dict(type='Collect', keys=['img']), 28 | ]) 29 | ] 30 | data = dict( 31 | samples_per_gpu=2, 32 | workers_per_gpu=6, 33 | train=dict( 34 | type=dataset_type, 35 | ann_file=data_root + 'annotations/instances_train2017.json', 36 | img_prefix=data_root + 'train2017/', 37 | pipeline=train_pipeline), 38 | val=dict( 39 | type=dataset_type, 40 | ann_file=data_root + 'annotations/instances_val2017.json', 41 | img_prefix=data_root + 'val2017/', 42 | pipeline=test_pipeline), 43 | test=dict( 44 | type=dataset_type, 45 | ann_file=data_root + 'annotations/instances_val2017.json', 46 | img_prefix=data_root + 'val2017/', 47 | pipeline=test_pipeline)) 48 | evaluation = dict(interval=1, metric='bbox') 49 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_instance.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'CocoDataset' 2 | data_root = 'data/coco/' 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 5 | train_pipeline = [ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 8 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 9 | dict(type='RandomFlip', flip_ratio=0.5), 10 | dict(type='Normalize', **img_norm_cfg), 11 | dict(type='Pad', size_divisor=32), 12 | dict(type='DefaultFormatBundle'), 13 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 14 | ] 15 | test_pipeline = [ 16 | dict(type='LoadImageFromFile'), 17 | dict( 18 | type='MultiScaleFlipAug', 19 | img_scale=(1333, 800), 20 | flip=False, 21 | transforms=[ 22 | dict(type='Resize', keep_ratio=True), 23 | dict(type='RandomFlip'), 24 | dict(type='Normalize', **img_norm_cfg), 25 | dict(type='Pad', size_divisor=32), 26 | dict(type='ImageToTensor', keys=['img']), 27 | dict(type='Collect', keys=['img']), 28 | ]) 29 | ] 30 
| data = dict( 31 | samples_per_gpu=2, 32 | workers_per_gpu=6, 33 | train=dict( 34 | type=dataset_type, 35 | ann_file=data_root + 'annotations/instances_train2017.json', 36 | img_prefix=data_root + 'train2017/', 37 | pipeline=train_pipeline), 38 | val=dict( 39 | type=dataset_type, 40 | ann_file=data_root + 'annotations/instances_val2017.json', 41 | img_prefix=data_root + 'val2017/', 42 | pipeline=test_pipeline), 43 | test=dict( 44 | type=dataset_type, 45 | ann_file=data_root + 'annotations/instances_val2017.json', 46 | img_prefix=data_root + 'val2017/', 47 | pipeline=test_pipeline)) 48 | evaluation = dict(metric=['bbox', 'segm']) 49 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_instance_ms.py: -------------------------------------------------------------------------------- 1 | 2 | img_norm_cfg = dict( 3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 4 | 5 | # augmentation strategy originates from DETR / Sparse RCNN 6 | train_pipeline = [ 7 | dict(type='LoadImageFromFile'), 8 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 9 | dict(type='RandomFlip', flip_ratio=0.5), 10 | dict(type='AutoAugment', 11 | policies=[ 12 | [ 13 | dict(type='Resize', 14 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), 15 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 16 | (736, 1333), (768, 1333), (800, 1333)], 17 | multiscale_mode='value', 18 | keep_ratio=True) 19 | ], 20 | [ 21 | dict(type='Resize', 22 | img_scale=[(400, 1333), (500, 1333), (600, 1333)], 23 | multiscale_mode='value', 24 | keep_ratio=True), 25 | dict(type='RandomCrop', 26 | crop_type='absolute_range', 27 | crop_size=(384, 600), 28 | allow_negative_crop=True), 29 | dict(type='Resize', 30 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 31 | (576, 1333), (608, 1333), (640, 1333), 32 | (672, 1333), (704, 1333), (736, 1333), 33 | (768, 1333), (800, 1333)], 34 | multiscale_mode='value', 35 | override=True, 36 | keep_ratio=True) 37 | ] 38 | ]), 39 | dict(type='Normalize', **img_norm_cfg), 40 | dict(type='Pad', size_divisor=32), 41 | dict(type='DefaultFormatBundle'), 42 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 43 | ] 44 | #data = dict(train=dict(pipeline=train_pipeline)) 45 | dataset_type = 'CocoDataset' 46 | data_root = 'data/coco/' 47 | test_pipeline = [ 48 | dict(type='LoadImageFromFile'), 49 | dict( 50 | type='MultiScaleFlipAug', 51 | img_scale=(1333, 800), 52 | flip=False, 53 | transforms=[ 54 | dict(type='Resize', keep_ratio=True), 55 | dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), 56 | dict(type='Pad', size_divisor=32), 57 | dict(type='ImageToTensor', keys=['img']), 58 | dict(type='Collect', keys=['img']), 59 | ]) 60 | ] 61 | data = dict( 62 | samples_per_gpu=2, 63 | workers_per_gpu=6, 64 | train=dict( 65 | type=dataset_type, 66 | ann_file=data_root + 'annotations/instances_train2017.json', 67 | img_prefix=data_root + 'train2017/', 68 | pipeline=train_pipeline), 69 | val=dict( 70 | type=dataset_type, 71 | ann_file=data_root + 'annotations/instances_val2017.json', 72 | img_prefix=data_root + 'val2017/', 73 | pipeline=test_pipeline), 74 | test=dict( 75 | type=dataset_type, 76 | ann_file=data_root + 'annotations/instances_val2017.json', 77 | img_prefix=data_root + 'val2017/', 78 | pipeline=test_pipeline)) 79 | evaluation = dict(metric=['bbox', 'segm']) 80 | 81 | 82 | -------------------------------------------------------------------------------- 
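The LightViT detection configs at the top of `detection/configs/` are assembled from these `_base_` fragments by mmcv's config system. A minimal sketch for inspecting the merged configuration and building the detector outside the provided scripts, assuming it is run from the `detection/` directory with mmdet `v2.14.0` and mmcv-full installed:

```python
from mmcv import Config
from mmdet.models import build_detector

import lightvit       # registers the LightViT backbones with mmdet
import lightvit_fpn   # registers the LightViTFPN neck

# the files listed in `_base_` are merged recursively into one config object
cfg = Config.fromfile('configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py')
print(cfg.model.backbone.type)   # 'lightvit_tiny'
print(cfg.data.train.pipeline)   # training pipeline from the dataset config above

model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'),
                       test_cfg=cfg.get('test_cfg'))
```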
/detection/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=1) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | ]) 9 | # yapf:enable 10 | dist_params = dict(backend='nccl') 11 | log_level = 'INFO' 12 | load_from = None 13 | resume_from = None 14 | workflow = [('train', 1)] 15 | #fp16 = dict(loss_scale=512.) 16 | find_unused_parameters = True 17 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/mask_rcnn_r50_fpn.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='MaskRCNN', 4 | pretrained='torchvision://resnet50', 5 | backbone=dict( 6 | type='ResNet', 7 | depth=50, 8 | num_stages=4, 9 | out_indices=(0, 1, 2, 3), 10 | frozen_stages=1, 11 | norm_cfg=dict(type='BN', requires_grad=True), 12 | norm_eval=True, 13 | style='pytorch'), 14 | neck=dict( 15 | type='FPN', 16 | in_channels=[256, 512, 1024, 2048], 17 | out_channels=256, 18 | num_outs=5), 19 | rpn_head=dict( 20 | type='RPNHead', 21 | in_channels=256, 22 | feat_channels=256, 23 | anchor_generator=dict( 24 | type='AnchorGenerator', 25 | scales=[8], 26 | ratios=[0.5, 1.0, 2.0], 27 | strides=[4, 8, 16, 32, 64]), 28 | bbox_coder=dict( 29 | type='DeltaXYWHBBoxCoder', 30 | target_means=[.0, .0, .0, .0], 31 | target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict( 33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 35 | roi_head=dict( 36 | type='StandardRoIHead', 37 | bbox_roi_extractor=dict( 38 | type='SingleRoIExtractor', 39 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 40 | out_channels=256, 41 | featmap_strides=[4, 8, 16, 32]), 42 | bbox_head=dict( 43 | type='Shared2FCBBoxHead', 44 | in_channels=256, 45 | fc_out_channels=1024, 46 | roi_feat_size=7, 47 | num_classes=80, 48 | bbox_coder=dict( 49 | type='DeltaXYWHBBoxCoder', 50 | target_means=[0., 0., 0., 0.], 51 | target_stds=[0.1, 0.1, 0.2, 0.2]), 52 | reg_class_agnostic=False, 53 | loss_cls=dict( 54 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 55 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 56 | mask_roi_extractor=dict( 57 | type='SingleRoIExtractor', 58 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 59 | out_channels=256, 60 | featmap_strides=[4, 8, 16, 32]), 61 | mask_head=dict( 62 | type='FCNMaskHead', 63 | num_convs=4, 64 | in_channels=256, 65 | conv_out_channels=256, 66 | num_classes=80, 67 | loss_mask=dict( 68 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 69 | # model training and testing settings 70 | train_cfg=dict( 71 | rpn=dict( 72 | assigner=dict( 73 | type='MaxIoUAssigner', 74 | pos_iou_thr=0.7, 75 | neg_iou_thr=0.3, 76 | min_pos_iou=0.3, 77 | match_low_quality=True, 78 | ignore_iof_thr=-1), 79 | sampler=dict( 80 | type='RandomSampler', 81 | num=256, 82 | pos_fraction=0.5, 83 | neg_pos_ub=-1, 84 | add_gt_as_proposals=False), 85 | allowed_border=-1, 86 | pos_weight=-1, 87 | debug=False), 88 | rpn_proposal=dict( 89 | nms_pre=2000, 90 | max_per_img=1000, 91 | nms=dict(type='nms', iou_threshold=0.7), 92 | min_bbox_size=0), 93 | rcnn=dict( 94 | assigner=dict( 95 | type='MaxIoUAssigner', 96 | pos_iou_thr=0.5, 97 | neg_iou_thr=0.5, 98 | min_pos_iou=0.5, 99 | match_low_quality=True, 100 | 
ignore_iof_thr=-1), 101 | sampler=dict( 102 | type='RandomSampler', 103 | num=512, 104 | pos_fraction=0.25, 105 | neg_pos_ub=-1, 106 | add_gt_as_proposals=True), 107 | mask_size=28, 108 | pos_weight=-1, 109 | debug=False)), 110 | test_cfg=dict( 111 | rpn=dict( 112 | nms_pre=1000, 113 | max_per_img=1000, 114 | nms=dict(type='nms', iou_threshold=0.7), 115 | min_bbox_size=0), 116 | rcnn=dict( 117 | score_thr=0.05, 118 | nms=dict(type='nms', iou_threshold=0.5), 119 | max_per_img=100, 120 | mask_thr_binary=0.5))) 121 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[8, 11]) 11 | runner = dict(type='EpochBasedRunner', max_epochs=12) 12 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_20e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[16, 19]) 11 | runner = dict(type='EpochBasedRunner', max_epochs=20) 12 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_2x.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[16, 22]) 11 | runner = dict(type='EpochBasedRunner', max_epochs=24) 12 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_lightvit_base_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance.py', 4 | '_base_/default_runtime.py' 5 | ] 6 | model = dict( 7 | pretrained=None, 8 | backbone=dict( 9 | _delete_=True, 10 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_base_82.1.ckpt'), 11 | type='lightvit_base', 12 | drop_path_rate=0.1, 13 | out_indices=(0, 1, 2), 14 | stem_norm_eval=True, # fix the BN running stats of the stem layer 15 | ), 16 | neck=dict( 17 | type='LightViTFPN', 18 | in_channels=[128, 256, 512], 19 | out_channels=256, 20 | num_outs=5, 21 | num_extra_trans_convs=1, 22 | )) 23 | # data 24 | data = dict(samples_per_gpu=2) # 2 x 8 = 16 25 | # optimizer 26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04, 27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 28 | 'relative_position_bias_table': dict(decay_mult=0.), 29 | 'norm': dict(decay_mult=0.), 30 | 'stem.1': dict(decay_mult=0.), 31 | 'stem.4': dict(decay_mult=0.), 32 | 'stem.7': dict(decay_mult=0.), 33 | 'stem.10': 
dict(decay_mult=0.), 34 | 'global_token': dict(decay_mult=0.) 35 | })) 36 | # optimizer_config = dict(grad_clip=None) 37 | # do not use mmdet version fp16 38 | # fp16 = None 39 | optimizer_config = dict(grad_clip=None) 40 | # learning policy 41 | lr_config = dict( 42 | policy='step', 43 | warmup='linear', 44 | warmup_iters=500, 45 | warmup_ratio=0.001, 46 | step=[8, 11]) 47 | total_epochs = 12 48 | 49 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_lightvit_base_fpn_3x_ms_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance_ms.py', 4 | '_base_/schedules/schedule_1x.py', 5 | '_base_/default_runtime.py' 6 | ] 7 | 8 | model = dict( 9 | pretrained=None, 10 | backbone=dict( 11 | _delete_=True, 12 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_base_82.1.ckpt'), 13 | type='lightvit_base', 14 | drop_path_rate=0.1, 15 | out_indices=(0, 1, 2), 16 | stem_norm_eval=True, # fix the BN running stats of the stem layer 17 | ), 18 | neck=dict( 19 | type='LightViTFPN', 20 | in_channels=[128, 256, 512], 21 | out_channels=256, 22 | num_outs=5, 23 | num_extra_trans_convs=1, 24 | )) 25 | 26 | data = dict(samples_per_gpu=2) 27 | 28 | # optimizer 29 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04, 30 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 31 | 'relative_position_bias_table': dict(decay_mult=0.), 32 | 'norm': dict(decay_mult=0.), 33 | 'stem.1': dict(decay_mult=0.), 34 | 'stem.4': dict(decay_mult=0.), 35 | 'stem.7': dict(decay_mult=0.), 36 | 'stem.10': dict(decay_mult=0.), 37 | 'global_token': dict(decay_mult=0.) 
38 | })) 39 | 40 | # optimizer_config = dict(grad_clip=None) 41 | # do not use mmdet version fp16 42 | # fp16 = None 43 | # optimizer 44 | # learning policy 45 | lr_config = dict(step=[27, 33]) 46 | runner = dict(type='EpochBasedRunner', max_epochs=36) 47 | total_epochs = 36 48 | # learning policy 49 | lr_config = dict( 50 | policy='step', 51 | warmup='linear', 52 | warmup_iters=500, 53 | warmup_ratio=0.001, 54 | step=[27, 33]) 55 | total_epochs = 36 56 | 57 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_lightvit_small_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance.py', 4 | '_base_/default_runtime.py' 5 | ] 6 | model = dict( 7 | pretrained=None, 8 | backbone=dict( 9 | _delete_=True, 10 | type='lightvit_small', 11 | drop_path_rate=0.1, 12 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_small_80.9.ckpt'), 13 | out_indices=(0, 1, 2), 14 | stem_norm_eval=True, # fix the BN running stats of the stem layer 15 | ), 16 | neck=dict( 17 | type='LightViTFPN', 18 | in_channels=[96, 192, 384], 19 | out_channels=256, 20 | num_outs=5, 21 | num_extra_trans_convs=1, 22 | )) 23 | # data 24 | data = dict(samples_per_gpu=2) # 2 x 8 = 16 25 | # optimizer 26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04, 27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 28 | 'relative_position_bias_table': dict(decay_mult=0.), 29 | 'norm': dict(decay_mult=0.), 30 | 'stem.1': dict(decay_mult=0.), 31 | 'stem.4': dict(decay_mult=0.), 32 | 'stem.7': dict(decay_mult=0.), 33 | 'stem.10': dict(decay_mult=0.), 34 | 'global_token': dict(decay_mult=0.) 
35 | })) 36 | # optimizer_config = dict(grad_clip=None) 37 | # do not use mmdet version fp16 38 | # fp16 = None 39 | optimizer_config = dict(grad_clip=None) 40 | # learning policy 41 | lr_config = dict( 42 | policy='step', 43 | warmup='linear', 44 | warmup_iters=500, 45 | warmup_ratio=0.001, 46 | step=[8, 11]) 47 | total_epochs = 12 48 | 49 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_lightvit_small_fpn_3x_ms_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance_ms.py', 4 | '_base_/schedules/schedule_1x.py', 5 | '_base_/default_runtime.py' 6 | ] 7 | 8 | model = dict( 9 | pretrained=None, 10 | backbone=dict( 11 | _delete_=True, 12 | type='lightvit_small', 13 | drop_path_rate=0.1, 14 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_small_80.9.ckpt'), 15 | out_indices=(0, 1, 2), 16 | stem_norm_eval=True, # fix the BN running stats of the stem layer 17 | ), 18 | neck=dict( 19 | type='LightViTFPN', 20 | in_channels=[96, 192, 384], 21 | out_channels=256, 22 | num_outs=5, 23 | num_extra_trans_convs=1, 24 | )) 25 | 26 | data = dict(samples_per_gpu=2) 27 | 28 | # optimizer 29 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04, 30 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 31 | 'relative_position_bias_table': dict(decay_mult=0.), 32 | 'norm': dict(decay_mult=0.), 33 | 'stem.1': dict(decay_mult=0.), 34 | 'stem.4': dict(decay_mult=0.), 35 | 'stem.7': dict(decay_mult=0.), 36 | 'stem.10': dict(decay_mult=0.), 37 | 'global_token': dict(decay_mult=0.) 
38 | })) 39 | # optimizer_config = dict(grad_clip=None) 40 | # do not use mmdet version fp16 41 | # fp16 = None 42 | # optimizer 43 | # learning policy 44 | lr_config = dict(step=[27, 33]) 45 | runner = dict(type='EpochBasedRunner', max_epochs=36) 46 | total_epochs = 36 47 | # learning policy 48 | lr_config = dict( 49 | policy='step', 50 | warmup='linear', 51 | warmup_iters=500, 52 | warmup_ratio=0.001, 53 | step=[27, 33]) 54 | total_epochs = 36 55 | 56 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance.py', 4 | '_base_/default_runtime.py' 5 | ] 6 | model = dict( 7 | pretrained=None, 8 | backbone=dict( 9 | _delete_=True, 10 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_tiny_78.7.ckpt'), 11 | type='lightvit_tiny', 12 | drop_path_rate=0.1, 13 | out_indices=(0, 1, 2), 14 | stem_norm_eval=True, # fix the BN running stats of the stem layer 15 | ), 16 | neck=dict( 17 | type='LightViTFPN', 18 | in_channels=[64, 128, 256], 19 | out_channels=256, 20 | num_outs=5, 21 | num_extra_trans_convs=1, 22 | )) 23 | # data 24 | data = dict(samples_per_gpu=2) # 2 x 8 = 16 25 | # optimizer 26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04, 27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 28 | 'relative_position_bias_table': dict(decay_mult=0.), 29 | 'norm': dict(decay_mult=0.), 30 | 'stem.1': dict(decay_mult=0.), 31 | 'stem.4': dict(decay_mult=0.), 32 | 'stem.7': dict(decay_mult=0.), 33 | 'stem.10': dict(decay_mult=0.), 34 | 'global_token': dict(decay_mult=0.) 
35 | })) 36 | # optimizer_config = dict(grad_clip=None) 37 | # do not use mmdet version fp16 38 | # fp16 = None 39 | optimizer_config = dict(grad_clip=None) 40 | # learning policy 41 | lr_config = dict( 42 | policy='step', 43 | warmup='linear', 44 | warmup_iters=500, 45 | warmup_ratio=0.001, 46 | step=[8, 11]) 47 | total_epochs = 12 48 | 49 | -------------------------------------------------------------------------------- /detection/configs/mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '_base_/models/mask_rcnn_r50_fpn.py', 3 | '_base_/datasets/coco_instance_ms.py', 4 | '_base_/schedules/schedule_1x.py', 5 | '_base_/default_runtime.py' 6 | ] 7 | model = dict( 8 | pretrained=None, 9 | backbone=dict( 10 | _delete_=True, 11 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_tiny_78.7.ckpt'), 12 | type='lightvit_tiny', 13 | drop_path_rate=0.1, 14 | out_indices=(0, 1, 2), 15 | stem_norm_eval=True, # fix the BN running stats of the stem layer 16 | ), 17 | neck=dict( 18 | type='LightViTFPN', 19 | in_channels=[64, 128, 256], 20 | out_channels=256, 21 | num_outs=5, 22 | num_extra_trans_convs=1, 23 | )) 24 | 25 | # optimizer 26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04, 27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 28 | 'relative_position_bias_table': dict(decay_mult=0.), 29 | 'norm': dict(decay_mult=0.), 30 | 'stem.1': dict(decay_mult=0.), 31 | 'stem.4': dict(decay_mult=0.), 32 | 'stem.7': dict(decay_mult=0.), 33 | 'stem.10': dict(decay_mult=0.), 34 | 'global_token': dict(decay_mult=0.) 35 | })) 36 | # optimizer_config = dict(grad_clip=None) 37 | # do not use mmdet version fp16 38 | # fp16 = None 39 | # optimizer 40 | # learning policy 41 | lr_config = dict(step=[27, 33]) 42 | runner = dict(type='EpochBasedRunner', max_epochs=36) 43 | total_epochs = 36 44 | # learning policy 45 | lr_config = dict( 46 | policy='step', 47 | warmup='linear', 48 | warmup_iters=500, 49 | warmup_ratio=0.001, 50 | step=[27, 33]) 51 | total_epochs = 36 52 | 53 | -------------------------------------------------------------------------------- /detection/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /detection/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /detection/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
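# Usage sketch (illustrative invocation): report the parameter count and FLOPs of a
# detector config at a given input resolution, e.g.
#   python get_flops.py configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py --shape 1280 800
# The input is padded up to a multiple of --size-divisor (32 by default) before
# counting, and ops unsupported by mmcv's counter may be skipped, so treat the
# numbers as estimates (see the caution printed at the end of the script).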
2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv import Config, DictAction 7 | 8 | from mmdet.models import build_detector 9 | 10 | import lightvit 11 | import lightvit_fpn 12 | 13 | try: 14 | from mmcv.cnn import get_model_complexity_info 15 | except ImportError: 16 | raise ImportError('Please upgrade mmcv to >0.6.2') 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='Train a detector') 21 | parser.add_argument('config', help='train config file path') 22 | parser.add_argument( 23 | '--shape', 24 | type=int, 25 | nargs='+', 26 | default=[1280, 800], 27 | help='input image size') 28 | parser.add_argument( 29 | '--cfg-options', 30 | nargs='+', 31 | action=DictAction, 32 | help='override some settings in the used config, the key-value pair ' 33 | 'in xxx=yyy format will be merged into config file. If the value to ' 34 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 35 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 36 | 'Note that the quotation marks are necessary and that no white space ' 37 | 'is allowed.') 38 | parser.add_argument( 39 | '--size-divisor', 40 | type=int, 41 | default=32, 42 | help='Pad the input image, the minimum size that is divisible ' 43 | 'by size_divisor, -1 means do not pad the image.') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | 50 | args = parse_args() 51 | 52 | if len(args.shape) == 1: 53 | h = w = args.shape[0] 54 | elif len(args.shape) == 2: 55 | h, w = args.shape 56 | else: 57 | raise ValueError('invalid input shape') 58 | ori_shape = (3, h, w) 59 | divisor = args.size_divisor 60 | if divisor > 0: 61 | h = int(np.ceil(h / divisor)) * divisor 62 | w = int(np.ceil(w / divisor)) * divisor 63 | 64 | input_shape = (3, h, w) 65 | 66 | cfg = Config.fromfile(args.config) 67 | if args.cfg_options is not None: 68 | cfg.merge_from_dict(args.cfg_options) 69 | 70 | model = build_detector( 71 | cfg.model, 72 | train_cfg=cfg.get('train_cfg'), 73 | test_cfg=cfg.get('test_cfg')) 74 | if torch.cuda.is_available(): 75 | model.cuda() 76 | model.eval() 77 | 78 | if hasattr(model, 'forward_dummy'): 79 | model.forward = model.forward_dummy 80 | else: 81 | raise NotImplementedError( 82 | 'FLOPs counter is currently not currently supported with {}'. 83 | format(model.__class__.__name__)) 84 | 85 | flops, params = get_model_complexity_info(model, input_shape) 86 | split_line = '=' * 30 87 | 88 | if divisor > 0 and \ 89 | input_shape != ori_shape: 90 | print(f'{split_line}\nUse size divisor set input shape ' 91 | f'from {ori_shape} to {input_shape}\n') 92 | print(f'{split_line}\nInput shape: {input_shape}\n' 93 | f'Flops: {flops}\nParams: {params}\n{split_line}') 94 | print('!!!Please be cautious if you use the results in papers. 
' 95 | 'You may need to check if all ops are supported and verify that the ' 96 | 'flops computation is correct.') 97 | 98 | 99 | if __name__ == '__main__': 100 | main() 101 | -------------------------------------------------------------------------------- /detection/lightvit.py: -------------------------------------------------------------------------------- 1 | import math 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | from timm.models.layers import DropPath, trunc_normal_, drop_path 10 | 11 | from mmdet.models.builder import BACKBONES 12 | from mmcv.runner import (auto_fp16, force_fp32,) 13 | from mmcv.runner import BaseModule 14 | 15 | 16 | class ConvStem(nn.Module): 17 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_eval=False): 18 | super().__init__() 19 | 20 | self.patch_size = patch_size 21 | self.norm_eval = norm_eval 22 | 23 | stem_dim = embed_dim // 2 24 | self.stem = nn.Sequential( 25 | nn.Conv2d(in_chans, stem_dim, kernel_size=3, 26 | stride=2, padding=1, bias=False), 27 | nn.BatchNorm2d(stem_dim), 28 | nn.GELU(), 29 | nn.Conv2d(stem_dim, stem_dim, kernel_size=3, 30 | groups=stem_dim, stride=1, padding=1, bias=False), 31 | nn.BatchNorm2d(stem_dim), 32 | nn.GELU(), 33 | nn.Conv2d(stem_dim, stem_dim, kernel_size=3, 34 | groups=stem_dim, stride=1, padding=1, bias=False), 35 | nn.BatchNorm2d(stem_dim), 36 | nn.GELU(), 37 | nn.Conv2d(stem_dim, stem_dim, kernel_size=3, 38 | groups=stem_dim, stride=2, padding=1, bias=False), 39 | nn.BatchNorm2d(stem_dim), 40 | nn.GELU(), 41 | ) 42 | self.proj = nn.Conv2d(stem_dim, embed_dim, 43 | kernel_size=3, 44 | stride=2, padding=1) 45 | self.norm = nn.LayerNorm(embed_dim) 46 | 47 | def forward(self, x): 48 | stem = self.stem(x) 49 | x = self.proj(stem) 50 | _, _, H, W = x.shape 51 | x = x.flatten(2).transpose(1, 2) 52 | x = self.norm(x) 53 | return x, (H, W) 54 | 55 | def train(self, mode=True): 56 | """Convert the model into training mode while keep normalization layer 57 | freezed.""" 58 | super().train(mode) 59 | if mode and self.norm_eval: 60 | for m in self.modules(): 61 | if isinstance(m, _BatchNorm): 62 | m.eval() 63 | 64 | 65 | class BiAttn(nn.Module): 66 | def __init__(self, in_channels, act_ratio=0.25, act_fn=nn.GELU, gate_fn=nn.Sigmoid): 67 | super().__init__() 68 | reduce_channels = int(in_channels * act_ratio) 69 | self.norm = nn.LayerNorm(in_channels) 70 | self.global_reduce = nn.Linear(in_channels, reduce_channels) 71 | self.local_reduce = nn.Linear(in_channels, reduce_channels) 72 | self.act_fn = act_fn() 73 | self.channel_select = nn.Linear(reduce_channels, in_channels) 74 | self.spatial_select = nn.Linear(reduce_channels * 2, 1) 75 | self.gate_fn = gate_fn() 76 | 77 | def forward(self, x): 78 | ori_x = x 79 | x = self.norm(x) 80 | x_global = x.mean(1, keepdim=True) 81 | x_global = self.act_fn(self.global_reduce(x_global)) 82 | x_local = self.act_fn(self.local_reduce(x)) 83 | 84 | c_attn = self.channel_select(x_global) 85 | c_attn = self.gate_fn(c_attn) # [B, 1, C] 86 | s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1)) 87 | s_attn = self.gate_fn(s_attn) # [B, N, 1] 88 | 89 | attn = c_attn * s_attn # [B, N, C] 90 | return ori_x * attn 91 | 92 | 93 | class BiAttnMlp(nn.Module): 94 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 95 | super().__init__() 96 | 
out_features = out_features or in_features 97 | hidden_features = hidden_features or in_features 98 | self.fc1 = nn.Linear(in_features, hidden_features) 99 | self.act = act_layer() 100 | self.fc2 = nn.Linear(hidden_features, out_features) 101 | self.attn = BiAttn(out_features) 102 | self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | x = self.fc1(x) 106 | x = self.act(x) 107 | x = self.drop(x) 108 | x = self.fc2(x) 109 | x = self.attn(x) 110 | x = self.drop(x) 111 | return x 112 | 113 | 114 | def window_reverse( 115 | windows: torch.Tensor, 116 | original_size, 117 | window_size=(7, 7) 118 | ) -> torch.Tensor: 119 | """ Reverses the window partition. 120 | Args: 121 | windows (torch.Tensor): Window tensor of the shape [B * windows, window_size[0] * window_size[1], C]. 122 | original_size (Tuple[int, int]): Original shape. 123 | window_size (Tuple[int, int], optional): Window size which have been applied. Default (7, 7) 124 | Returns: 125 | output (torch.Tensor): Folded output tensor of the shape [B, original_size[0] * original_size[1], C]. 126 | """ 127 | # Get height and width 128 | H, W = original_size 129 | # Compute original batch size 130 | B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) 131 | # Fold grid tensor 132 | output = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) 133 | output = output.permute(0, 1, 3, 2, 4, 5).reshape(B, H * W, -1) 134 | return output 135 | 136 | 137 | def get_relative_position_index( 138 | win_h: int, 139 | win_w: int 140 | ) -> torch.Tensor: 141 | """ Function to generate pair-wise relative position index for each token inside the window. 142 | Taken from Timms Swin V1 implementation. 143 | Args: 144 | win_h (int): Window/Grid height. 145 | win_w (int): Window/Grid width. 146 | Returns: 147 | relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width]. 
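Example (illustrative): for the default 7x7 window the returned index tensor has shape [49, 49] with values in [0, 168], i.e. it indexes a bias table holding (2*7-1)*(2*7-1) = 169 relative offsets per attention head.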
148 | """ 149 | coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])) 150 | coords_flatten = torch.flatten(coords, 1) 151 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] 152 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() 153 | relative_coords[:, :, 0] += win_h - 1 154 | relative_coords[:, :, 1] += win_w - 1 155 | relative_coords[:, :, 0] *= 2 * win_w - 1 156 | return relative_coords.sum(-1) 157 | 158 | 159 | class Attention(nn.Module): 160 | def __init__(self, dim, num_tokens=1, num_heads=8, window_size=7, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 161 | super().__init__() 162 | self.num_heads = num_heads 163 | head_dim = dim // num_heads 164 | self.num_tokens = num_tokens 165 | self.window_size = window_size 166 | self.attn_area = window_size * window_size 167 | self.scale = qk_scale or head_dim ** -0.5 168 | 169 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 170 | self.kv_global = nn.Linear(dim, dim * 2, bias=qkv_bias) 171 | self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity() 172 | self.proj = nn.Linear(dim, dim) 173 | self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity() 174 | 175 | # positional embedding 176 | # Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH 177 | self.relative_position_bias_table = nn.Parameter( 178 | torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)) 179 | 180 | # Get pair-wise relative position index for each token inside the window 181 | self.register_buffer("relative_position_index", get_relative_position_index(window_size, 182 | window_size).view(-1)) 183 | # Init relative positional bias 184 | trunc_normal_(self.relative_position_bias_table, std=.02) 185 | 186 | def _get_relative_positional_bias( 187 | self 188 | ) -> torch.Tensor: 189 | """ Returns the relative positional bias. 190 | Returns: 191 | relative_position_bias (torch.Tensor): Relative positional bias. 
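Shape: [1, num_heads, window_size*window_size, window_size*window_size]; the leading 1 broadcasts across all windows when the bias is added to the local attention logits in forward_local.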
192 | """ 193 | relative_position_bias = self.relative_position_bias_table[ 194 | self.relative_position_index].view(self.attn_area, self.attn_area, -1) 195 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 196 | return relative_position_bias.unsqueeze(0) 197 | 198 | def forward_global_aggregation(self, q, k, v): 199 | """ 200 | q: global tokens 201 | k: image tokens 202 | v: image tokens 203 | """ 204 | B, _, N, _ = q.shape 205 | attn = (q @ k.transpose(-2, -1)) * self.scale 206 | attn = attn.softmax(dim=-1) 207 | attn = self.attn_drop(attn) 208 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 209 | return x 210 | 211 | def forward_local(self, q, k, v, H, W): 212 | """ 213 | q: image tokens 214 | k: image tokens 215 | v: image tokens 216 | """ 217 | B, num_heads, N, C = q.shape 218 | ws = self.window_size 219 | h_group, w_group = H // ws, W // ws 220 | 221 | # partition to windows 222 | q = q.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() 223 | q = q.view(-1, num_heads, ws*ws, C) 224 | k = k.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() 225 | k = k.view(-1, num_heads, ws*ws, C) 226 | v = v.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous() 227 | v = v.view(-1, num_heads, ws*ws, v.shape[-1]) 228 | 229 | attn = (q @ k.transpose(-2, -1)) * self.scale 230 | pos_bias = self._get_relative_positional_bias() 231 | attn = (attn + pos_bias).softmax(dim=-1) 232 | attn = self.attn_drop(attn) 233 | x = (attn @ v).transpose(1, 2).reshape(v.shape[0], ws*ws, -1) 234 | 235 | # reverse 236 | x = window_reverse(x, (H, W), (ws, ws)) 237 | return x 238 | 239 | def forward_global_broadcast(self, q, k, v): 240 | """ 241 | q: image tokens 242 | k: global tokens 243 | v: global tokens 244 | """ 245 | B, num_heads, N, _ = q.shape 246 | attn = (q @ k.transpose(-2, -1)) * self.scale 247 | attn = attn.softmax(dim=-1) 248 | attn = self.attn_drop(attn) 249 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1) 250 | return x 251 | 252 | def forward(self, x, H, W): 253 | B, N, C = x.shape 254 | NC = self.num_tokens 255 | # pad 256 | x_img, x_global = x[:, NC:], x[:, :NC] 257 | x_img = x_img.view(B, H, W, C) 258 | pad_l = pad_t = 0 259 | ws = self.window_size 260 | pad_r = (ws - W % ws) % ws 261 | pad_b = (ws - H % ws) % ws 262 | x_img = F.pad(x_img, (0, 0, pad_l, pad_r, pad_t, pad_b)) 263 | Hp, Wp = x_img.shape[1], x_img.shape[2] 264 | x_img = x_img.view(B, -1, C) 265 | x = torch.cat([x_global, x_img], dim=1) 266 | 267 | # qkv 268 | qkv = self.qkv(x) 269 | q, k, v = qkv.view(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0) 270 | 271 | # split img tokens & global tokens 272 | q_img, k_img, v_img = q[:, :, NC:], k[:, :, NC:], v[:, :, NC:] 273 | q_cls, _, _ = q[:, :, :NC], k[:, :, :NC], v[:, :, :NC] 274 | 275 | # local window attention 276 | x_img = self.forward_local(q_img, k_img, v_img, Hp, Wp) 277 | # restore to the original size 278 | x_img = x_img.view(B, Hp, Wp, -1)[:, :H, :W].reshape(B, H*W, -1) 279 | q_img = q_img.reshape(B, self.num_heads, Hp, Wp, -1)[:, :, :H, :W].reshape(B, self.num_heads, H*W, -1) 280 | k_img = k_img.reshape(B, self.num_heads, Hp, Wp, -1)[:, :, :H, :W].reshape(B, self.num_heads, H*W, -1) 281 | v_img = v_img.reshape(B, self.num_heads, Hp, Wp, -1)[:, :, :H, :W].reshape(B, self.num_heads, H*W, -1) 282 | 283 | # global aggregation 284 | x_cls = self.forward_global_aggregation(q_cls, k_img, v_img) 285 | k_cls, 
v_cls = self.kv_global(x_cls).view(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0) 286 | 287 | # gloal broadcast 288 | x_img = x_img + self.forward_global_broadcast(q_img, k_cls, v_cls) 289 | 290 | x = torch.cat([x_cls, x_img], dim=1) 291 | x = self.proj(x) 292 | return x 293 | 294 | 295 | class Block(nn.Module): 296 | 297 | def __init__(self, dim, num_heads, num_tokens=1, window_size=7, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., 298 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention=Attention, last_block=False): 299 | super().__init__() 300 | self.last_block = last_block 301 | self.norm1 = norm_layer(dim) 302 | self.attn = attention(dim, num_heads=num_heads, num_tokens=num_tokens, window_size=window_size, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) 303 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 304 | self.norm2 = norm_layer(dim) 305 | mlp_hidden_dim = int(dim * mlp_ratio) 306 | self.mlp = BiAttnMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 307 | 308 | def forward(self, x, H, W): 309 | x = x + self.drop_path(self.attn(self.norm1(x), H, W)) 310 | if self.last_block: 311 | # ignore unused global tokens in downstream tasks 312 | x = x[:, -H*W:] 313 | x = x + self.drop_path(self.mlp(self.norm2(x))) 314 | return x 315 | 316 | 317 | class ResidualMergePatch(nn.Module): 318 | def __init__(self, dim, out_dim, num_tokens=1): 319 | super().__init__() 320 | self.num_tokens = num_tokens 321 | self.norm = nn.LayerNorm(4 * dim) 322 | self.reduction = nn.Linear(4 * dim, out_dim, bias=False) 323 | self.norm2 = nn.LayerNorm(dim) 324 | self.proj = nn.Linear(dim, out_dim, bias=False) 325 | # use MaxPool3d to avoid permutations 326 | self.maxp = nn.MaxPool3d((2, 2, 1), (2, 2, 1)) 327 | self.res_proj = nn.Linear(dim, out_dim, bias=False) 328 | 329 | def forward(self, x, H, W): 330 | global_token, x = x[:, :self.num_tokens].contiguous(), x[:, self.num_tokens:].contiguous() 331 | B, L, C = x.shape 332 | 333 | x = x.view(B, H, W, C) 334 | # pad 335 | pad_l = pad_t = 0 336 | pad_r = (2 - W % 2) % 2 337 | pad_b = (2 - H % 2) % 2 338 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 339 | 340 | res = self.res_proj(self.maxp(x).view(B, -1, C)) 341 | 342 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 343 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 344 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 345 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 346 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 347 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 348 | 349 | x = self.norm(x) 350 | x = self.reduction(x) 351 | x = x + res 352 | global_token = self.proj(self.norm2(global_token)) 353 | x = torch.cat([global_token, x], 1) 354 | return x, (math.ceil(H / 2), math.ceil(W / 2)) 355 | 356 | 357 | class LightViT(BaseModule): 358 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[32, 64, 160, 256], num_layers=[2, 2, 2, 2], 359 | num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], num_tokens=8, window_size=7, neck_dim=1280, qkv_bias=True, 360 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=ConvStem, norm_layer=None, 361 | act_layer=None, weight_init='', out_indices=(0, 1, 2, 3), stem_norm_eval=False, init_cfg=None): 362 | super().__init__(init_cfg) 363 | self.num_classes = num_classes 364 | self.embed_dims = embed_dims 365 | self.num_tokens = num_tokens 366 | self.mlp_ratios = mlp_ratios 367 | self.patch_size = patch_size 368 | self.num_layers = 
num_layers 369 | self.out_indices = out_indices 370 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 371 | act_layer = act_layer or nn.GELU 372 | 373 | self.patch_embed = embed_layer( 374 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], norm_eval=stem_norm_eval) 375 | 376 | self.global_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dims[0])) 377 | 378 | stages = [] 379 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))] # stochastic depth decay rule 380 | for stage, (embed_dim, num_layer, num_head, mlp_ratio) in enumerate(zip(embed_dims, num_layers, num_heads, mlp_ratios)): 381 | blocks = [] 382 | if stage > 0: 383 | # downsample 384 | blocks.append(ResidualMergePatch(embed_dims[stage-1], embed_dim, num_tokens=num_tokens)) 385 | blocks += [ 386 | Block( 387 | dim=embed_dim, num_heads=num_head, num_tokens=num_tokens, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, 388 | attn_drop=attn_drop_rate, drop_path=dpr[sum(num_layers[:stage]) + i], norm_layer=norm_layer, act_layer=act_layer, attention=Attention, 389 | last_block=(stage==len(embed_dims)-1 and i == num_layer - 1)) 390 | for i in range(num_layer)] 391 | blocks = nn.Sequential(*blocks) 392 | stages.append(blocks) 393 | self.stages = nn.Sequential(*stages) 394 | 395 | # add a norm layer for each output 396 | for i_layer in out_indices: 397 | layer = norm_layer(embed_dims[i_layer]) 398 | layer_name = f'norm{i_layer}' 399 | self.add_module(layer_name, layer) 400 | 401 | def forward_features(self, x): 402 | outputs = [] 403 | x, (H, W) = self.patch_embed(x) 404 | x_out = x.view(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous() 405 | global_token = self.global_token.expand(x.shape[0], -1, -1) 406 | x = torch.cat((global_token, x), dim=1) 407 | 408 | for i_stage, stage in enumerate(self.stages): 409 | for block in stage: 410 | if isinstance(block, ResidualMergePatch): 411 | x, (H, W) = block(x, H, W) 412 | elif isinstance(block, Block): 413 | x = block(x, H, W) 414 | else: 415 | x = block(x) 416 | 417 | if i_stage in self.out_indices: 418 | norm_layer = getattr(self, f'norm{i_stage}') 419 | x_out = norm_layer(x[:, -H*W:]) 420 | x_out = x_out.view(-1, H, W, self.embed_dims[i_stage]).permute(0, 3, 1, 2).contiguous() 421 | outputs.append(x_out) 422 | return tuple(outputs) 423 | 424 | @auto_fp16() 425 | def forward(self, x): 426 | x = self.forward_features(x) 427 | return x 428 | 429 | 430 | @BACKBONES.register_module() 431 | class lightvit_tiny(LightViT): 432 | 433 | def __init__(self, **kwargs): 434 | model_kwargs = dict(patch_size=8, embed_dims=[64, 128, 256], num_layers=[2, 6, 6], 435 | num_heads=[2, 4, 8, ], mlp_ratios=[8, 4, 4], num_tokens=512, drop_path_rate=0.1) 436 | model_kwargs.update(kwargs) 437 | super().__init__(**model_kwargs) 438 | 439 | 440 | @BACKBONES.register_module() 441 | class lightvit_small(LightViT): 442 | 443 | def __init__(self, **kwargs): 444 | model_kwargs = dict(patch_size=8, embed_dims=[96, 192, 384], num_layers=[2, 6, 6], 445 | num_heads=[3, 6, 12, ], mlp_ratios=[8, 4, 4], num_tokens=16, drop_path_rate=0.1) 446 | model_kwargs.update(kwargs) 447 | super().__init__(**model_kwargs) 448 | 449 | 450 | @BACKBONES.register_module() 451 | class lightvit_base(LightViT): 452 | 453 | def __init__(self, **kwargs): 454 | model_kwargs = dict(patch_size=8, embed_dims=[128, 256, 512], num_layers=[3, 8, 6], 455 | num_heads=[4, 8, 16, ], mlp_ratios=[8, 4, 4], num_tokens=24, drop_path_rate=0.1) 456 | 
model_kwargs.update(kwargs) 457 | super().__init__(**model_kwargs) 458 | -------------------------------------------------------------------------------- /detection/lightvit_fpn.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from mmcv.cnn import ConvModule, xavier_init 6 | from mmcv.runner import auto_fp16 7 | 8 | from mmdet.models.builder import NECKS 9 | 10 | 11 | @NECKS.register_module() 12 | class LightViTFPN(nn.Module): 13 | r"""Feature Pyramid Network for LightViT. 14 | 15 | Args: 16 | in_channels (List[int]): Number of input channels per scale. 17 | out_channels (int): Number of output channels (used at each scale) 18 | num_outs (int): Number of output scales. 19 | start_level (int): Index of the start input backbone level used to 20 | build the feature pyramid. Default: 0. 21 | end_level (int): Index of the end input backbone level (exclusive) to 22 | build the feature pyramid. Default: -1, which means the last level. 23 | add_extra_convs (bool | str): If bool, it decides whether to add conv 24 | layers on top of the original feature maps. Default to False. 25 | If True, its actual mode is specified by `extra_convs_on_inputs`. 26 | If str, it specifies the source feature map of the extra convs. 27 | Only the following options are allowed 28 | 29 | - 'on_input': Last feat map of neck inputs (i.e. backbone feature). 30 | - 'on_lateral': Last feature map after lateral convs. 31 | - 'on_output': The last output feature map after fpn convs. 32 | extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs 33 | on the original feature from the backbone. If True, 34 | it is equivalent to `add_extra_convs='on_input'`. If False, it is 35 | equivalent to set `add_extra_convs='on_output'`. Default to True. 36 | relu_before_extra_convs (bool): Whether to apply relu before the extra 37 | conv. Default: False. 38 | no_norm_on_lateral (bool): Whether to apply norm on lateral. 39 | Default: False. 40 | num_extra_trans_convs (int): extra transposed conv on the output 41 | with largest resolution. Default: 0. 42 | conv_cfg (dict): Config dict for convolution layer. Default: None. 43 | norm_cfg (dict): Config dict for normalization layer. Default: None. 44 | act_cfg (str): Config dict for activation layer in ConvModule. 45 | Default: None. 46 | upsample_cfg (dict): Config dict for interpolate layer. 47 | Default: `dict(mode='nearest')` 48 | 49 | Example: 50 | >>> import torch 51 | >>> in_channels = [2, 3, 5, 7] 52 | >>> scales = [340, 170, 84, 43] 53 | >>> inputs = [torch.rand(1, c, s, s) 54 | ... for c, s in zip(in_channels, scales)] 55 | >>> self = FPN(in_channels, 11, len(in_channels)).eval() 56 | >>> outputs = self.forward(inputs) 57 | >>> for i in range(len(outputs)): 58 | ... 
print(f'outputs[{i}].shape = {outputs[i].shape}') 59 | outputs[0].shape = torch.Size([1, 11, 340, 340]) 60 | outputs[1].shape = torch.Size([1, 11, 170, 170]) 61 | outputs[2].shape = torch.Size([1, 11, 84, 84]) 62 | outputs[3].shape = torch.Size([1, 11, 43, 43]) 63 | """ 64 | 65 | def __init__(self, 66 | in_channels, 67 | out_channels, 68 | num_outs, 69 | start_level=0, 70 | end_level=-1, 71 | add_extra_convs=False, 72 | extra_convs_on_inputs=True, 73 | relu_before_extra_convs=False, 74 | no_norm_on_lateral=False, 75 | num_extra_trans_convs=0, 76 | conv_cfg=None, 77 | norm_cfg=None, 78 | act_cfg=None, 79 | upsample_cfg=dict(mode='nearest')): 80 | super(LightViTFPN, self).__init__() 81 | assert isinstance(in_channels, list) 82 | self.in_channels = in_channels 83 | self.out_channels = out_channels 84 | self.num_ins = len(in_channels) 85 | self.num_outs = num_outs 86 | self.relu_before_extra_convs = relu_before_extra_convs 87 | self.no_norm_on_lateral = no_norm_on_lateral 88 | self.num_extra_trans_convs = num_extra_trans_convs 89 | self.fp16_enabled = False 90 | self.upsample_cfg = upsample_cfg.copy() 91 | 92 | if end_level == -1: 93 | self.backbone_end_level = self.num_ins 94 | assert num_outs >= self.num_ins - start_level 95 | else: 96 | # if end_level < inputs, no extra level is allowed 97 | self.backbone_end_level = end_level 98 | assert end_level <= len(in_channels) 99 | assert num_outs == end_level - start_level 100 | self.start_level = start_level 101 | self.end_level = end_level 102 | self.add_extra_convs = add_extra_convs 103 | assert isinstance(add_extra_convs, (str, bool)) 104 | if isinstance(add_extra_convs, str): 105 | # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output' 106 | assert add_extra_convs in ('on_input', 'on_lateral', 'on_output') 107 | elif add_extra_convs: # True 108 | if extra_convs_on_inputs: 109 | # TODO: deprecate `extra_convs_on_inputs` 110 | warnings.simplefilter('once') 111 | warnings.warn( 112 | '"extra_convs_on_inputs" will be deprecated in v2.9.0,' 113 | 'Please use "add_extra_convs"', DeprecationWarning) 114 | self.add_extra_convs = 'on_input' 115 | else: 116 | self.add_extra_convs = 'on_output' 117 | 118 | self.lateral_convs = nn.ModuleList() 119 | self.fpn_convs = nn.ModuleList() 120 | 121 | for i in range(self.start_level, self.backbone_end_level): 122 | l_conv = ConvModule( 123 | in_channels[i], 124 | out_channels, 125 | 1, 126 | conv_cfg=conv_cfg, 127 | norm_cfg=norm_cfg if not self.no_norm_on_lateral else None, 128 | act_cfg=act_cfg, 129 | inplace=False) 130 | fpn_conv = ConvModule( 131 | out_channels, 132 | out_channels, 133 | 3, 134 | padding=1, 135 | conv_cfg=conv_cfg, 136 | norm_cfg=norm_cfg, 137 | act_cfg=act_cfg, 138 | inplace=False) 139 | 140 | self.lateral_convs.append(l_conv) 141 | self.fpn_convs.append(fpn_conv) 142 | 143 | # add extra conv layers (e.g., RetinaNet) 144 | extra_levels = num_outs - self.backbone_end_level + self.start_level 145 | assert extra_levels >= num_extra_trans_convs 146 | extra_levels -= num_extra_trans_convs 147 | if self.add_extra_convs and extra_levels >= 1: 148 | for i in range(extra_levels): 149 | if i == 0 and self.add_extra_convs == 'on_input': 150 | in_channels = self.in_channels[self.backbone_end_level - 1] 151 | else: 152 | in_channels = out_channels 153 | extra_fpn_conv = ConvModule( 154 | in_channels, 155 | out_channels, 156 | 3, 157 | stride=2, 158 | padding=1, 159 | conv_cfg=conv_cfg, 160 | norm_cfg=norm_cfg, 161 | act_cfg=act_cfg, 162 | inplace=False) 163 | 
self.fpn_convs.append(extra_fpn_conv) 164 | 165 | # add extra transposed convs 166 | self.extra_trans_convs = nn.ModuleList() 167 | self.extra_fpn_convs = nn.ModuleList() 168 | for i in range(num_extra_trans_convs): 169 | extra_trans_conv = TransposedConvModule( 170 | out_channels, 171 | out_channels, 172 | 2, 173 | stride=2, 174 | padding=0, 175 | conv_cfg=conv_cfg, 176 | norm_cfg=norm_cfg if not no_norm_on_lateral else None, 177 | act_cfg=act_cfg, 178 | inplace=False) 179 | self.extra_trans_convs.append(extra_trans_conv) 180 | extra_fpn_conv = ConvModule( 181 | out_channels, 182 | out_channels, 183 | 3, 184 | padding=1, 185 | conv_cfg=conv_cfg, 186 | norm_cfg=norm_cfg, 187 | act_cfg=act_cfg, 188 | inplace=False) 189 | self.extra_fpn_convs.append(extra_fpn_conv) 190 | 191 | # default init_weights for conv(msra) and norm in ConvModule 192 | def init_weights(self): 193 | """Initialize the weights of FPN module.""" 194 | for m in self.modules(): 195 | if isinstance(m, nn.Conv2d): 196 | xavier_init(m, distribution='uniform') 197 | 198 | @auto_fp16() 199 | def forward(self, inputs): 200 | """Forward function.""" 201 | assert len(inputs) == len(self.in_channels) 202 | 203 | # build laterals 204 | laterals = [ 205 | lateral_conv(inputs[i + self.start_level]) 206 | for i, lateral_conv in enumerate(self.lateral_convs) 207 | ] 208 | 209 | # build top-down path 210 | used_backbone_levels = len(laterals) 211 | for i in range(used_backbone_levels - 1, 0, -1): 212 | # In some cases, fixing `scale factor` (e.g. 2) is preferred, but 213 | # it cannot co-exist with `size` in `F.interpolate`. 214 | if 'scale_factor' in self.upsample_cfg: 215 | laterals[i - 1] += F.interpolate(laterals[i], 216 | **self.upsample_cfg) 217 | else: 218 | prev_shape = laterals[i - 1].shape[2:] 219 | laterals[i - 1] += F.interpolate( 220 | laterals[i], size=prev_shape, **self.upsample_cfg) 221 | 222 | # extra transposed convs for outputs with extra scales 223 | extra_laterals = [] 224 | if self.num_extra_trans_convs > 0: 225 | prev_lateral = laterals[0] 226 | for i in range(self.num_extra_trans_convs): 227 | extra_lateral = self.extra_trans_convs[i](prev_lateral) 228 | extra_laterals.insert(0, extra_lateral) 229 | prev_lateral = extra_lateral 230 | 231 | # build outputs 232 | # part 1: from original levels 233 | outs = [ 234 | self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) 235 | ] 236 | # part 2: add extra levels 237 | if self.num_outs > len(outs) + len(extra_laterals): 238 | # use max pool to get more levels on top of outputs 239 | # (e.g., Faster R-CNN, Mask R-CNN) 240 | if not self.add_extra_convs: 241 | for i in range(self.num_outs - len(extra_laterals) - used_backbone_levels): 242 | outs.append(F.max_pool2d(outs[-1], 1, stride=2)) 243 | # add conv layers on top of original feature maps (RetinaNet) 244 | else: 245 | if self.add_extra_convs == 'on_input': 246 | extra_source = inputs[self.backbone_end_level - 1] 247 | elif self.add_extra_convs == 'on_lateral': 248 | extra_source = laterals[-1] 249 | elif self.add_extra_convs == 'on_output': 250 | extra_source = outs[-1] 251 | else: 252 | raise NotImplementedError 253 | outs.append(self.fpn_convs[used_backbone_levels](extra_source)) 254 | for i in range(used_backbone_levels + 1, self.num_outs - len(extra_laterals)): 255 | if self.relu_before_extra_convs: 256 | outs.append(self.fpn_convs[i](F.relu(outs[-1]))) 257 | else: 258 | outs.append(self.fpn_convs[i](outs[-1])) 259 | 260 | # part 3: add extra transposed convs 261 | if self.num_extra_trans_convs > 0: 262 
| extra_outs = [ 263 | self.extra_fpn_convs[i](extra_laterals[i]) 264 | for i in range(self.num_extra_trans_convs) 265 | ] 266 | assert (len(extra_outs) + len(outs)) == self.num_outs, f"{len(extra_outs)} + {len(outs)} != {self.num_outs}" 267 | return tuple(extra_outs + outs) 268 | 269 | 270 | class TransposedConvModule(ConvModule): 271 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 272 | padding=0, dilation=1, groups=1, bias='auto', conv_cfg=None, 273 | norm_cfg=None, act_cfg=..., inplace=True, 274 | **kwargs): 275 | super(TransposedConvModule, self).__init__(in_channels, out_channels, kernel_size, 276 | stride, padding, dilation, groups, bias, conv_cfg, 277 | norm_cfg, act_cfg, inplace, **kwargs) 278 | 279 | self.conv = nn.ConvTranspose2d( 280 | in_channels, 281 | out_channels, 282 | kernel_size, 283 | stride=stride, 284 | padding=padding, 285 | dilation=dilation, 286 | groups=groups, 287 | bias=self.with_bias 288 | ) 289 | 290 | # Use msra init by default 291 | self.init_weights() 292 | -------------------------------------------------------------------------------- /detection/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /detection/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${@:5} 14 | 15 | PYTHONPATH="$(dirname $0)":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /detection/test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import time 5 | import warnings 6 | 7 | import mmcv 8 | import torch 9 | from mmcv import Config, DictAction 10 | from mmcv.cnn import fuse_conv_bn 11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 12 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 13 | wrap_fp16_model) 14 | 15 | from mmdet.apis import multi_gpu_test, single_gpu_test 16 | from mmdet.datasets import (build_dataloader, build_dataset, 17 | replace_ImageToTensor) 18 | from mmdet.models import build_detector 19 | 20 | import lightvit 21 | import lightvit_fpn 22 | 23 | 24 | def parse_args(): 25 | parser = 
argparse.ArgumentParser( 26 | description='MMDet test (and eval) a model') 27 | parser.add_argument('config', help='test config file path') 28 | parser.add_argument('checkpoint', help='checkpoint file') 29 | parser.add_argument( 30 | '--work-dir', 31 | help='the directory to save the file containing evaluation metrics') 32 | parser.add_argument('--out', help='output result file in pickle format') 33 | parser.add_argument( 34 | '--fuse-conv-bn', 35 | action='store_true', 36 | help='Whether to fuse conv and bn, this will slightly increase ' 37 | 'the inference speed') 38 | parser.add_argument( 39 | '--format-only', 40 | action='store_true', 41 | help='Format the output results without performing evaluation. It is ' 42 | 'useful when you want to format the result to a specific format and ' 43 | 'submit it to the test server') 44 | parser.add_argument( 45 | '--eval', 46 | type=str, 47 | nargs='+', 48 | help='evaluation metrics, which depend on the dataset, e.g., "bbox",' 49 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') 50 | parser.add_argument('--show', action='store_true', help='show results') 51 | parser.add_argument( 52 | '--show-dir', help='directory where painted images will be saved') 53 | parser.add_argument( 54 | '--show-score-thr', 55 | type=float, 56 | default=0.3, 57 | help='score threshold (default: 0.3)') 58 | parser.add_argument( 59 | '--gpu-collect', 60 | action='store_true', 61 | help='whether to use gpu to collect results.') 62 | parser.add_argument( 63 | '--tmpdir', 64 | help='tmp directory used for collecting results from multiple ' 65 | 'workers, available when gpu-collect is not specified') 66 | parser.add_argument( 67 | '--cfg-options', 68 | nargs='+', 69 | action=DictAction, 70 | help='override some settings in the used config, the key-value pair ' 71 | 'in xxx=yyy format will be merged into config file. If the value to ' 72 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 73 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 74 | 'Note that the quotation marks are necessary and that no white space ' 75 | 'is allowed.') 76 | parser.add_argument( 77 | '--options', 78 | nargs='+', 79 | action=DictAction, 80 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 81 | 'format will be kwargs for dataset.evaluate() function (deprecate), ' 82 | 'change to --eval-options instead.') 83 | parser.add_argument( 84 | '--eval-options', 85 | nargs='+', 86 | action=DictAction, 87 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 88 | 'format will be kwargs for dataset.evaluate() function') 89 | parser.add_argument( 90 | '--launcher', 91 | choices=['none', 'pytorch', 'slurm', 'mpi'], 92 | default='none', 93 | help='job launcher') 94 | parser.add_argument('--local_rank', type=int, default=0) 95 | args = parser.parse_args() 96 | if 'LOCAL_RANK' not in os.environ: 97 | os.environ['LOCAL_RANK'] = str(args.local_rank) 98 | 99 | if args.options and args.eval_options: 100 | raise ValueError( 101 | '--options and --eval-options cannot be both ' 102 | 'specified, --options is deprecated in favor of --eval-options') 103 | if args.options: 104 | warnings.warn('--options is deprecated in favor of --eval-options') 105 | args.eval_options = args.options 106 | return args 107 | 108 | 109 | def main(): 110 | args = parse_args() 111 | 112 | assert args.out or args.eval or args.format_only or args.show \ 113 | or args.show_dir, \ 114 | ('Please specify at least one operation (save/eval/format/show the ' 115 | 'results / save the results) with the argument "--out", "--eval"' 116 | ', "--format-only", "--show" or "--show-dir"') 117 | 118 | if args.eval and args.format_only: 119 | raise ValueError('--eval and --format_only cannot be both specified') 120 | 121 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 122 | raise ValueError('The output file must be a pkl file.') 123 | 124 | cfg = Config.fromfile(args.config) 125 | if args.cfg_options is not None: 126 | cfg.merge_from_dict(args.cfg_options) 127 | # import modules from string list. 
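# NOTE: `custom_imports` here follows the standard mmcv convention: when a config defines it,
# the listed modules are imported by name before the model is built. A hedged, illustrative
# value is shown below; this repository's configs may instead rely on the explicit
# `import lightvit` / `import lightvit_fpn` at the top of this script.
#   custom_imports = dict(imports=['lightvit', 'lightvit_fpn'], allow_failed_imports=False)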
128 | if cfg.get('custom_imports', None): 129 | from mmcv.utils import import_modules_from_strings 130 | import_modules_from_strings(**cfg['custom_imports']) 131 | # set cudnn_benchmark 132 | if cfg.get('cudnn_benchmark', False): 133 | torch.backends.cudnn.benchmark = True 134 | 135 | cfg.model.pretrained = None 136 | if cfg.model.get('neck'): 137 | if isinstance(cfg.model.neck, list): 138 | for neck_cfg in cfg.model.neck: 139 | if neck_cfg.get('rfp_backbone'): 140 | if neck_cfg.rfp_backbone.get('pretrained'): 141 | neck_cfg.rfp_backbone.pretrained = None 142 | elif cfg.model.neck.get('rfp_backbone'): 143 | if cfg.model.neck.rfp_backbone.get('pretrained'): 144 | cfg.model.neck.rfp_backbone.pretrained = None 145 | 146 | # in case the test dataset is concatenated 147 | samples_per_gpu = 1 148 | if isinstance(cfg.data.test, dict): 149 | cfg.data.test.test_mode = True 150 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) 151 | if samples_per_gpu > 1: 152 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 153 | cfg.data.test.pipeline = replace_ImageToTensor( 154 | cfg.data.test.pipeline) 155 | elif isinstance(cfg.data.test, list): 156 | for ds_cfg in cfg.data.test: 157 | ds_cfg.test_mode = True 158 | samples_per_gpu = max( 159 | [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) 160 | if samples_per_gpu > 1: 161 | for ds_cfg in cfg.data.test: 162 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) 163 | 164 | # init distributed env first, since logger depends on the dist info. 165 | if args.launcher == 'none': 166 | distributed = False 167 | else: 168 | distributed = True 169 | init_dist(args.launcher, **cfg.dist_params) 170 | 171 | rank, _ = get_dist_info() 172 | # allows not to create 173 | if args.work_dir is not None and rank == 0: 174 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 175 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 176 | json_file = osp.join(args.work_dir, f'eval_{timestamp}.json') 177 | 178 | # build the dataloader 179 | dataset = build_dataset(cfg.data.test) 180 | data_loader = build_dataloader( 181 | dataset, 182 | samples_per_gpu=samples_per_gpu, 183 | workers_per_gpu=cfg.data.workers_per_gpu, 184 | dist=distributed, 185 | shuffle=False) 186 | 187 | # build the model and load checkpoint 188 | cfg.model.train_cfg = None 189 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) 190 | fp16_cfg = cfg.get('fp16', None) 191 | if fp16_cfg is not None: 192 | wrap_fp16_model(model) 193 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 194 | if args.fuse_conv_bn: 195 | model = fuse_conv_bn(model) 196 | # old versions did not save class info in checkpoints, this walkaround is 197 | # for backward compatibility 198 | if 'CLASSES' in checkpoint.get('meta', {}): 199 | model.CLASSES = checkpoint['meta']['CLASSES'] 200 | else: 201 | model.CLASSES = dataset.CLASSES 202 | 203 | if not distributed: 204 | model = MMDataParallel(model, device_ids=[0]) 205 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, 206 | args.show_score_thr) 207 | else: 208 | model = MMDistributedDataParallel( 209 | model.cuda(), 210 | device_ids=[torch.cuda.current_device()], 211 | broadcast_buffers=False) 212 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, 213 | args.gpu_collect) 214 | 215 | rank, _ = get_dist_info() 216 | if rank == 0: 217 | if args.out: 218 | print(f'\nwriting results to {args.out}') 219 | mmcv.dump(outputs, args.out) 220 | kwargs = {} if args.eval_options is None else 
args.eval_options 221 | if args.format_only: 222 | dataset.format_results(outputs, **kwargs) 223 | if args.eval: 224 | eval_kwargs = cfg.get('evaluation', {}).copy() 225 | # hard-code way to remove EvalHook args 226 | for key in [ 227 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 228 | 'rule' 229 | ]: 230 | eval_kwargs.pop(key, None) 231 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 232 | metric = dataset.evaluate(outputs, **eval_kwargs) 233 | print(metric) 234 | metric_dict = dict(config=args.config, metric=metric) 235 | if args.work_dir is not None and rank == 0: 236 | mmcv.dump(metric_dict, json_file) 237 | 238 | 239 | if __name__ == '__main__': 240 | main() 241 | -------------------------------------------------------------------------------- /detection/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | import warnings 7 | 8 | import mmcv 9 | import torch 10 | from mmcv import Config, DictAction 11 | from mmcv.runner import get_dist_info, init_dist 12 | from mmcv.utils import get_git_hash 13 | 14 | from mmdet import __version__ 15 | from mmdet.apis import set_random_seed, train_detector 16 | from mmdet.datasets import build_dataset 17 | from mmdet.models import build_detector 18 | from mmdet.utils import collect_env, get_root_logger 19 | 20 | import lightvit 21 | import lightvit_fpn 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser(description='Train a detector') 26 | parser.add_argument('config', help='train config file path') 27 | parser.add_argument('--work-dir', help='the dir to save logs and models') 28 | parser.add_argument( 29 | '--resume-from', help='the checkpoint file to resume from') 30 | parser.add_argument( 31 | '--no-validate', 32 | action='store_true', 33 | help='whether not to evaluate the checkpoint during training') 34 | group_gpus = parser.add_mutually_exclusive_group() 35 | group_gpus.add_argument( 36 | '--gpus', 37 | type=int, 38 | help='number of gpus to use ' 39 | '(only applicable to non-distributed training)') 40 | group_gpus.add_argument( 41 | '--gpu-ids', 42 | type=int, 43 | nargs='+', 44 | help='ids of gpus to use ' 45 | '(only applicable to non-distributed training)') 46 | parser.add_argument('--seed', type=int, default=None, help='random seed') 47 | parser.add_argument( 48 | '--deterministic', 49 | action='store_true', 50 | help='whether to set deterministic options for CUDNN backend.') 51 | parser.add_argument( 52 | '--options', 53 | nargs='+', 54 | action=DictAction, 55 | help='override some settings in the used config, the key-value pair ' 56 | 'in xxx=yyy format will be merged into config file (deprecate), ' 57 | 'change to --cfg-options instead.') 58 | parser.add_argument( 59 | '--cfg-options', 60 | nargs='+', 61 | action=DictAction, 62 | help='override some settings in the used config, the key-value pair ' 63 | 'in xxx=yyy format will be merged into config file. If the value to ' 64 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 65 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 66 | 'Note that the quotation marks are necessary and that no white space ' 67 | 'is allowed.') 68 | parser.add_argument( 69 | '--launcher', 70 | choices=['none', 'pytorch', 'slurm', 'mpi'], 71 | default='none', 72 | help='job launcher') 73 | parser.add_argument('--local_rank', type=int, default=0) 74 | args = parser.parse_args() 75 | if 'LOCAL_RANK' not in os.environ: 76 | os.environ['LOCAL_RANK'] = str(args.local_rank) 77 | 78 | if args.options and args.cfg_options: 79 | raise ValueError( 80 | '--options and --cfg-options cannot be both ' 81 | 'specified, --options is deprecated in favor of --cfg-options') 82 | if args.options: 83 | warnings.warn('--options is deprecated in favor of --cfg-options') 84 | args.cfg_options = args.options 85 | 86 | return args 87 | 88 | 89 | def main(): 90 | args = parse_args() 91 | 92 | cfg = Config.fromfile(args.config) 93 | if args.cfg_options is not None: 94 | cfg.merge_from_dict(args.cfg_options) 95 | # import modules from string list. 96 | if cfg.get('custom_imports', None): 97 | from mmcv.utils import import_modules_from_strings 98 | import_modules_from_strings(**cfg['custom_imports']) 99 | # set cudnn_benchmark 100 | if cfg.get('cudnn_benchmark', False): 101 | torch.backends.cudnn.benchmark = True 102 | 103 | # work_dir is determined in this priority: CLI > segment in file > filename 104 | if args.work_dir is not None: 105 | # update configs according to CLI args if args.work_dir is not None 106 | cfg.work_dir = args.work_dir 107 | elif cfg.get('work_dir', None) is None: 108 | # use config filename as default work_dir if cfg.work_dir is None 109 | cfg.work_dir = osp.join('./work_dirs', 110 | osp.splitext(osp.basename(args.config))[0]) 111 | if args.resume_from is not None: 112 | cfg.resume_from = args.resume_from 113 | if args.gpu_ids is not None: 114 | cfg.gpu_ids = args.gpu_ids 115 | else: 116 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 117 | 118 | # init distributed env first, since logger depends on the dist info. 
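# For reference: distributed runs are normally started through the repository's dist_train.sh
# or slurm_train.sh wrappers rather than by invoking this script directly. A hedged sketch of
# a typical PyTorch launcher command (the config path is a placeholder and the exact flags in
# dist_train.sh may differ):
#   python -m torch.distributed.launch --nproc_per_node=8 \
#       train.py <your_lightvit_config.py> --launcher pytorch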
119 | if args.launcher == 'none': 120 | distributed = False 121 | else: 122 | distributed = True 123 | init_dist(args.launcher, **cfg.dist_params) 124 | # re-set gpu_ids with distributed training mode 125 | _, world_size = get_dist_info() 126 | cfg.gpu_ids = range(world_size) 127 | 128 | # create work_dir 129 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 130 | # dump config 131 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 132 | # init the logger before other steps 133 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 134 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 135 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 136 | 137 | # init the meta dict to record some important information such as 138 | # environment info and seed, which will be logged 139 | meta = dict() 140 | # log env info 141 | env_info_dict = collect_env() 142 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 143 | dash_line = '-' * 60 + '\n' 144 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 145 | dash_line) 146 | meta['env_info'] = env_info 147 | meta['config'] = cfg.pretty_text 148 | # log some basic info 149 | logger.info(f'Distributed training: {distributed}') 150 | logger.info(f'Config:\n{cfg.pretty_text}') 151 | 152 | # set random seeds 153 | if args.seed is not None: 154 | logger.info(f'Set random seed to {args.seed}, ' 155 | f'deterministic: {args.deterministic}') 156 | set_random_seed(args.seed, deterministic=args.deterministic) 157 | cfg.seed = args.seed 158 | meta['seed'] = args.seed 159 | meta['exp_name'] = osp.basename(args.config) 160 | 161 | model = build_detector( 162 | cfg.model, 163 | train_cfg=cfg.get('train_cfg'), 164 | test_cfg=cfg.get('test_cfg')) 165 | model.init_weights() 166 | 167 | datasets = [build_dataset(cfg.data.train)] 168 | if len(cfg.workflow) == 2: 169 | val_dataset = copy.deepcopy(cfg.data.val) 170 | val_dataset.pipeline = cfg.data.train.pipeline 171 | datasets.append(build_dataset(val_dataset)) 172 | if cfg.checkpoint_config is not None: 173 | # save mmdet version, config file content and class names in 174 | # checkpoints as meta data 175 | cfg.checkpoint_config.meta = dict( 176 | mmdet_version=__version__ + get_git_hash()[:7], 177 | CLASSES=datasets[0].CLASSES) 178 | # add an attribute for visualization convenience 179 | model.CLASSES = datasets[0].CLASSES 180 | train_detector( 181 | model, 182 | datasets, 183 | cfg, 184 | distributed=distributed, 185 | validate=(not args.no_validate), 186 | timestamp=timestamp, 187 | meta=meta) 188 | 189 | 190 | if __name__ == '__main__': 191 | main() 192 | --------------------------------------------------------------------------------
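Usage note: the modules dumped above register `lightvit_tiny` / `lightvit_small` / `lightvit_base` as mmdet backbones and `LightViTFPN` as a neck, so a config can refer to them by name once `lightvit` and `lightvit_fpn` are importable (train.py and test.py import them explicitly). The snippet below is a minimal, hedged sketch of how the model section of such a config might look; every field value (out_indices, num_outs, num_extra_trans_convs, channel widths) is an illustrative assumption consistent with the code above, not a copy of the repository's actual Mask R-CNN configs.

# Hypothetical mmdet config fragment (illustrative values only).
model = dict(
    type='MaskRCNN',
    backbone=dict(
        type='lightvit_tiny',           # registered via @BACKBONES.register_module()
        out_indices=(0, 1, 2),          # lightvit_tiny defines three stages
    ),
    neck=dict(
        type='LightViTFPN',             # registered via @NECKS.register_module()
        in_channels=[64, 128, 256],     # matches lightvit_tiny's embed_dims
        out_channels=256,
        num_outs=5,
        num_extra_trans_convs=1,        # one transposed-conv level above the finest backbone map
    ),
    # rpn_head / roi_head / train_cfg / test_cfg omitted here; see the provided configs.
)

With num_extra_trans_convs=1 and three backbone levels, LightViTFPN.forward would return five maps: one extra high-resolution level produced by TransposedConvModule (a stride-2 upsampling of the finest lateral, then refined by an extra 3x3 conv), the three regular FPN levels, and one max-pooled level on top. This appears to be how the neck recovers a finer pyramid level even though the backbone's finest output is at the stride set by its patch_size=8 stem.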