├── .gitmodules
├── LICENSE
├── README.md
├── assests
│   └── architecture.png
└── detection
    ├── README.md
    ├── configs
    │   ├── _base_
    │   │   ├── datasets
    │   │   │   ├── coco_detection.py
    │   │   │   ├── coco_instance.py
    │   │   │   └── coco_instance_ms.py
    │   │   ├── default_runtime.py
    │   │   ├── models
    │   │   │   └── mask_rcnn_r50_fpn.py
    │   │   └── schedules
    │   │       ├── schedule_1x.py
    │   │       ├── schedule_20e.py
    │   │       └── schedule_2x.py
    │   ├── mask_rcnn_lightvit_base_fpn_1x_coco.py
    │   ├── mask_rcnn_lightvit_base_fpn_3x_ms_coco.py
    │   ├── mask_rcnn_lightvit_small_fpn_1x_coco.py
    │   ├── mask_rcnn_lightvit_small_fpn_3x_ms_coco.py
    │   ├── mask_rcnn_lightvit_tiny_fpn_1x_coco.py
    │   └── mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.py
    ├── dist_test.sh
    ├── dist_train.sh
    ├── get_flops.py
    ├── lightvit.py
    ├── lightvit_fpn.py
    ├── slurm_test.sh
    ├── slurm_train.sh
    ├── test.py
    └── train.py
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "classification"]
2 | path = classification
3 | url = https://github.com/hunto/image_classification_sota.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2022 LightViT contributors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LightViT
2 | Official implementation of the paper "[LightViT: Towards Light-Weight Convolution-Free Vision Transformers](https://arxiv.org/abs/2207.05557)".
3 |
4 | By Tao Huang, Lang Huang, Shan You, Fei Wang, Chen Qian, Chang Xu.
5 |
6 |
7 | ## Updates
8 | ### July 26, 2022
9 | Code for COCO detection was released.
10 |
11 | ### July 14, 2022
12 | Code for ImageNet training was released.
13 |
14 |
15 | ## Introduction
16 |
17 | ![architecture](assests/architecture.png)
18 |
19 |
20 | ## Results on ImageNet-1K
21 |
22 | |model|resolution|acc@1|acc@5|#params|FLOPs|ckpt|log|
23 | |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
24 | |LightViT-T|224x224|78.7|94.4|9.4M|0.7G|[google drive](https://drive.google.com/file/d/1NasAnHYK6bkmj0-rzYvxQSEUiKAOW3j9/view?usp=sharing)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/log_lightvit_tiny.csv)|
25 | |LightViT-S|224x224|80.9|95.3|19.2M|1.7G|[google drive](https://drive.google.com/file/d/16EZUth-wZ7rKR6tI67_Smp6GGzaODp16/view?usp=sharing)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/log_lightvit_small.csv)|
26 | |LightViT-B|224x224|82.1|95.9|35.2M|3.9G|[google drive](https://drive.google.com/file/d/1MpVvfo8AiJMmz8CJZ7lBnDyoBJUEDxeL/view?usp=sharing)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/log_lightvit_base.csv)|
27 |
28 | ### Preparation
29 | 1. Clone training code
30 | ```shell
31 | git clone https://github.com/hunto/LightViT.git --recurse-submodules
32 | cd LightViT/classification
33 | ```
34 |
35 | **The code of the LightViT model can be found in [lib/models/lightvit.py](https://github.com/hunto/image_classification_sota/blob/main/lib/models/lightvit.py)**.
36 |
37 | 2. Requirements
38 | ```shell
39 | torch>=1.3.0
40 | # torch>=1.5.0 is required if you want to use torch.cuda.amp for mixed-precision training
41 | timm==0.5.4
42 | ```
43 | 3. Prepare your datasets following [this link](https://github.com/hunto/image_classification_sota#prepare-datasets).
44 |
45 | ### Evaluation
46 | You can evaluate our results using the provided checkpoints. First download the checkpoints to your machine, then run
47 | ```shell
48 | sh tools/dist_run.sh tools/test.py ${NUM_GPUS} configs/strategies/lightvit/config.yaml timm_lightvit_tiny --drop-path-rate 0.1 --experiment lightvit_tiny_test --resume ${ckpt_file_path}
49 | ```
50 |
51 | ### Train from scratch on ImageNet-1K
52 | ```shell
53 | sh tools/dist_train.sh 8 configs/strategies/lightvit/config.yaml ${MODEL} --drop-path-rate 0.1 --experiment lightvit_tiny
54 | ```
55 | `${MODEL}` can be `timm_lightvit_tiny`, `timm_lightvit_small`, or `timm_lightvit_base`.
56 |
57 | For `timm_lightvit_base`, we add the `--amp` option to enable mixed-precision training and **set `drop_path_rate` to 0.3**.
58 |
59 | ### Throughput
60 | ```shell
61 | sh tools/dist_run.sh tools/speed_test.py 1 configs/strategies/lightvit/config.yaml ${MODEL} --drop-path-rate 0.1 --batch-size 1024
62 | ```
63 | or
64 | ```shell
65 | python tools/speed_test.py -c configs/strategies/lightvit/config.yaml --model ${MODEL} --drop-path-rate 0.1 --batch-size 1024
66 | ```
67 |
68 | ## Results on COCO
69 | We conducted experiments on COCO object detection & instance segmentation; see [detection/README.md](detection/README.md) for details.
70 |
71 | ## License
72 | This project is released under the [Apache 2.0 license](LICENSE).
73 |
74 | ## Citation
75 | ```
76 | @article{huang2022lightvit,
77 | title = {LightViT: Towards Light-Weight Convolution-Free Vision Transformers},
78 | author = {Huang, Tao and Huang, Lang and You, Shan and Wang, Fei and Qian, Chen and Xu, Chang},
79 | journal = {arXiv preprint arXiv:2207.05557},
80 | year = {2022}
81 | }
82 | ```
--------------------------------------------------------------------------------
/assests/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hunto/LightViT/e2452a17dcc06d426eb8979c4e6b3a5678091368/assests/architecture.png
--------------------------------------------------------------------------------
/detection/README.md:
--------------------------------------------------------------------------------
1 | # Applying LightViT to Object Detection
2 |
3 | ## Usage
4 | ### Preparations
5 | * Install mmdetection `v2.14.0` and mmcv-full
6 | ```shell
7 | pip install mmdet==2.14.0
8 | ```
9 | For mmcv-full, please check [here](https://mmdetection.readthedocs.io/en/latest/get_started.html#install-mmdetection) and [compatibility](https://mmdetection.readthedocs.io/en/latest/compatibility.html).
10 |
11 | * Put the COCO dataset into the `./data` folder following [this url](https://mmdetection.readthedocs.io/en/latest/1_exist_data_model.html#prepare-datasets).
12 |
13 | **Note:** if you use a different mmdet version, replace `configs`, `train.py`, and `test.py` with the corresponding files from that version, then add `import lightvit` and `import lightvit_fpn` to `train.py` to register the models.
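
A minimal sketch of what that registration looks like (the same import pattern used by `get_flops.py` in this folder; apply it to `test.py` as well if you replace it):

```python
# top of train.py, after the stock mmdet imports
import lightvit       # registers the lightvit_tiny/small/base backbones with mmdet
import lightvit_fpn   # registers the LightViTFPN neck used in the configs
```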
14 |
15 | ## Train Mask R-CNN with LightViT backbones
16 | |backbone|params (M)|FLOPs (G)|setting|box AP|mask AP|config|log|ckpt|
17 | |:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|:--:|
18 | |LightViT-T|28|187|1x|37.8|35.9|[config](configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_tiny_fpn_1x_coco.log.json)|[google drive](https://drive.google.com/file/d/15sSit0VpcsmL3Pr-fdjM6nRdmSZre0zh/view?usp=sharing)|
19 | |LightViT-S|38|204|1x|40.0|37.4|[config](configs/mask_rcnn_lightvit_small_fpn_1x_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_small_fpn_1x_coco.log.json)|[google drive](https://drive.google.com/file/d/146RWYNShe5IrltvmgWoelj27LaRQBcXE/view?usp=sharing)|
20 | |LightViT-B|54|240|1x|41.7|38.8|[config](configs/mask_rcnn_lightvit_base_fpn_1x_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_base_fpn_1x_coco.log.json)|[google drive](https://drive.google.com/file/d/19IHkVO_Qy__NHe0syrUpIKrhGNoS9VCd/view?usp=sharing)|
21 | |LightViT-T|28|187|3x+ms|41.5|38.4|[config](configs/mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.log.json)|[google drive](https://drive.google.com/file/d/1EfhtuCVCoprbwb_6W-32z_po4hGvAzu8/view?usp=sharing)|
22 | |LightViT-S|38|204|3x+ms|43.2|39.9|[config](configs/mask_rcnn_lightvit_small_fpn_3x_ms_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_small_fpn_3x_ms_coco.log.json)|[google drive](https://drive.google.com/file/d/16l5lXYwz3Anx28ncqPfc3mir-nyZyRd6/view?usp=sharing)|
23 | |LightViT-B|54|240|3x+ms|45.0|41.2|[config](configs/mask_rcnn_lightvit_base_fpn_3x_ms_coco.py)|[log](https://github.com/hunto/LightViT/releases/download/v0.0.1/mask_rcnn_lightvit_base_fpn_3x_ms_coco.log.json)|[google drive](https://drive.google.com/file/d/1hLKdemruEKW2DO0227ZNm1zSZCqMWVfi/view?usp=sharing)|
24 |
25 |
26 | ### Training script:
27 |
28 | ```shell
29 | sh dist_train.sh ${CONFIG} 8 work_dirs/${EXP_NAME}
30 | ```
--------------------------------------------------------------------------------
/detection/configs/_base_/datasets/coco_detection.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'CocoDataset'
2 | data_root = 'data/coco/'
3 | img_norm_cfg = dict(
4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
5 | train_pipeline = [
6 | dict(type='LoadImageFromFile'),
7 | dict(type='LoadAnnotations', with_bbox=True),
8 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
9 | dict(type='RandomFlip', flip_ratio=0.5),
10 | dict(type='Normalize', **img_norm_cfg),
11 | dict(type='Pad', size_divisor=32),
12 | dict(type='DefaultFormatBundle'),
13 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
14 | ]
15 | test_pipeline = [
16 | dict(type='LoadImageFromFile'),
17 | dict(
18 | type='MultiScaleFlipAug',
19 | img_scale=(1333, 800),
20 | flip=False,
21 | transforms=[
22 | dict(type='Resize', keep_ratio=True),
23 | dict(type='RandomFlip'),
24 | dict(type='Normalize', **img_norm_cfg),
25 | dict(type='Pad', size_divisor=32),
26 | dict(type='ImageToTensor', keys=['img']),
27 | dict(type='Collect', keys=['img']),
28 | ])
29 | ]
30 | data = dict(
31 | samples_per_gpu=2,
32 | workers_per_gpu=6,
33 | train=dict(
34 | type=dataset_type,
35 | ann_file=data_root + 'annotations/instances_train2017.json',
36 | img_prefix=data_root + 'train2017/',
37 | pipeline=train_pipeline),
38 | val=dict(
39 | type=dataset_type,
40 | ann_file=data_root + 'annotations/instances_val2017.json',
41 | img_prefix=data_root + 'val2017/',
42 | pipeline=test_pipeline),
43 | test=dict(
44 | type=dataset_type,
45 | ann_file=data_root + 'annotations/instances_val2017.json',
46 | img_prefix=data_root + 'val2017/',
47 | pipeline=test_pipeline))
48 | evaluation = dict(interval=1, metric='bbox')
49 |
--------------------------------------------------------------------------------
/detection/configs/_base_/datasets/coco_instance.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'CocoDataset'
2 | data_root = 'data/coco/'
3 | img_norm_cfg = dict(
4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
5 | train_pipeline = [
6 | dict(type='LoadImageFromFile'),
7 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
8 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
9 | dict(type='RandomFlip', flip_ratio=0.5),
10 | dict(type='Normalize', **img_norm_cfg),
11 | dict(type='Pad', size_divisor=32),
12 | dict(type='DefaultFormatBundle'),
13 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
14 | ]
15 | test_pipeline = [
16 | dict(type='LoadImageFromFile'),
17 | dict(
18 | type='MultiScaleFlipAug',
19 | img_scale=(1333, 800),
20 | flip=False,
21 | transforms=[
22 | dict(type='Resize', keep_ratio=True),
23 | dict(type='RandomFlip'),
24 | dict(type='Normalize', **img_norm_cfg),
25 | dict(type='Pad', size_divisor=32),
26 | dict(type='ImageToTensor', keys=['img']),
27 | dict(type='Collect', keys=['img']),
28 | ])
29 | ]
30 | data = dict(
31 | samples_per_gpu=2,
32 | workers_per_gpu=6,
33 | train=dict(
34 | type=dataset_type,
35 | ann_file=data_root + 'annotations/instances_train2017.json',
36 | img_prefix=data_root + 'train2017/',
37 | pipeline=train_pipeline),
38 | val=dict(
39 | type=dataset_type,
40 | ann_file=data_root + 'annotations/instances_val2017.json',
41 | img_prefix=data_root + 'val2017/',
42 | pipeline=test_pipeline),
43 | test=dict(
44 | type=dataset_type,
45 | ann_file=data_root + 'annotations/instances_val2017.json',
46 | img_prefix=data_root + 'val2017/',
47 | pipeline=test_pipeline))
48 | evaluation = dict(metric=['bbox', 'segm'])
49 |
--------------------------------------------------------------------------------
/detection/configs/_base_/datasets/coco_instance_ms.py:
--------------------------------------------------------------------------------
1 |
2 | img_norm_cfg = dict(
3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
4 |
5 | # augmentation strategy originates from DETR / Sparse RCNN
6 | train_pipeline = [
7 | dict(type='LoadImageFromFile'),
8 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
9 | dict(type='RandomFlip', flip_ratio=0.5),
10 | dict(type='AutoAugment',
11 | policies=[
12 | [
13 | dict(type='Resize',
14 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
15 | (608, 1333), (640, 1333), (672, 1333), (704, 1333),
16 | (736, 1333), (768, 1333), (800, 1333)],
17 | multiscale_mode='value',
18 | keep_ratio=True)
19 | ],
20 | [
21 | dict(type='Resize',
22 | img_scale=[(400, 1333), (500, 1333), (600, 1333)],
23 | multiscale_mode='value',
24 | keep_ratio=True),
25 | dict(type='RandomCrop',
26 | crop_type='absolute_range',
27 | crop_size=(384, 600),
28 | allow_negative_crop=True),
29 | dict(type='Resize',
30 | img_scale=[(480, 1333), (512, 1333), (544, 1333),
31 | (576, 1333), (608, 1333), (640, 1333),
32 | (672, 1333), (704, 1333), (736, 1333),
33 | (768, 1333), (800, 1333)],
34 | multiscale_mode='value',
35 | override=True,
36 | keep_ratio=True)
37 | ]
38 | ]),
39 | dict(type='Normalize', **img_norm_cfg),
40 | dict(type='Pad', size_divisor=32),
41 | dict(type='DefaultFormatBundle'),
42 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
43 | ]
44 | #data = dict(train=dict(pipeline=train_pipeline))
45 | dataset_type = 'CocoDataset'
46 | data_root = 'data/coco/'
47 | test_pipeline = [
48 | dict(type='LoadImageFromFile'),
49 | dict(
50 | type='MultiScaleFlipAug',
51 | img_scale=(1333, 800),
52 | flip=False,
53 | transforms=[
54 | dict(type='Resize', keep_ratio=True),
55 | dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg),
56 | dict(type='Pad', size_divisor=32),
57 | dict(type='ImageToTensor', keys=['img']),
58 | dict(type='Collect', keys=['img']),
59 | ])
60 | ]
61 | data = dict(
62 | samples_per_gpu=2,
63 | workers_per_gpu=6,
64 | train=dict(
65 | type=dataset_type,
66 | ann_file=data_root + 'annotations/instances_train2017.json',
67 | img_prefix=data_root + 'train2017/',
68 | pipeline=train_pipeline),
69 | val=dict(
70 | type=dataset_type,
71 | ann_file=data_root + 'annotations/instances_val2017.json',
72 | img_prefix=data_root + 'val2017/',
73 | pipeline=test_pipeline),
74 | test=dict(
75 | type=dataset_type,
76 | ann_file=data_root + 'annotations/instances_val2017.json',
77 | img_prefix=data_root + 'val2017/',
78 | pipeline=test_pipeline))
79 | evaluation = dict(metric=['bbox', 'segm'])
80 |
81 |
82 |
--------------------------------------------------------------------------------
/detection/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | checkpoint_config = dict(interval=1)
2 | # yapf:disable
3 | log_config = dict(
4 | interval=50,
5 | hooks=[
6 | dict(type='TextLoggerHook'),
7 | # dict(type='TensorboardLoggerHook')
8 | ])
9 | # yapf:enable
10 | dist_params = dict(backend='nccl')
11 | log_level = 'INFO'
12 | load_from = None
13 | resume_from = None
14 | workflow = [('train', 1)]
15 | #fp16 = dict(loss_scale=512.)
16 | find_unused_parameters = True
17 |
--------------------------------------------------------------------------------
/detection/configs/_base_/models/mask_rcnn_r50_fpn.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | model = dict(
3 | type='MaskRCNN',
4 | pretrained='torchvision://resnet50',
5 | backbone=dict(
6 | type='ResNet',
7 | depth=50,
8 | num_stages=4,
9 | out_indices=(0, 1, 2, 3),
10 | frozen_stages=1,
11 | norm_cfg=dict(type='BN', requires_grad=True),
12 | norm_eval=True,
13 | style='pytorch'),
14 | neck=dict(
15 | type='FPN',
16 | in_channels=[256, 512, 1024, 2048],
17 | out_channels=256,
18 | num_outs=5),
19 | rpn_head=dict(
20 | type='RPNHead',
21 | in_channels=256,
22 | feat_channels=256,
23 | anchor_generator=dict(
24 | type='AnchorGenerator',
25 | scales=[8],
26 | ratios=[0.5, 1.0, 2.0],
27 | strides=[4, 8, 16, 32, 64]),
28 | bbox_coder=dict(
29 | type='DeltaXYWHBBoxCoder',
30 | target_means=[.0, .0, .0, .0],
31 | target_stds=[1.0, 1.0, 1.0, 1.0]),
32 | loss_cls=dict(
33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
35 | roi_head=dict(
36 | type='StandardRoIHead',
37 | bbox_roi_extractor=dict(
38 | type='SingleRoIExtractor',
39 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
40 | out_channels=256,
41 | featmap_strides=[4, 8, 16, 32]),
42 | bbox_head=dict(
43 | type='Shared2FCBBoxHead',
44 | in_channels=256,
45 | fc_out_channels=1024,
46 | roi_feat_size=7,
47 | num_classes=80,
48 | bbox_coder=dict(
49 | type='DeltaXYWHBBoxCoder',
50 | target_means=[0., 0., 0., 0.],
51 | target_stds=[0.1, 0.1, 0.2, 0.2]),
52 | reg_class_agnostic=False,
53 | loss_cls=dict(
54 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
55 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
56 | mask_roi_extractor=dict(
57 | type='SingleRoIExtractor',
58 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
59 | out_channels=256,
60 | featmap_strides=[4, 8, 16, 32]),
61 | mask_head=dict(
62 | type='FCNMaskHead',
63 | num_convs=4,
64 | in_channels=256,
65 | conv_out_channels=256,
66 | num_classes=80,
67 | loss_mask=dict(
68 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
69 | # model training and testing settings
70 | train_cfg=dict(
71 | rpn=dict(
72 | assigner=dict(
73 | type='MaxIoUAssigner',
74 | pos_iou_thr=0.7,
75 | neg_iou_thr=0.3,
76 | min_pos_iou=0.3,
77 | match_low_quality=True,
78 | ignore_iof_thr=-1),
79 | sampler=dict(
80 | type='RandomSampler',
81 | num=256,
82 | pos_fraction=0.5,
83 | neg_pos_ub=-1,
84 | add_gt_as_proposals=False),
85 | allowed_border=-1,
86 | pos_weight=-1,
87 | debug=False),
88 | rpn_proposal=dict(
89 | nms_pre=2000,
90 | max_per_img=1000,
91 | nms=dict(type='nms', iou_threshold=0.7),
92 | min_bbox_size=0),
93 | rcnn=dict(
94 | assigner=dict(
95 | type='MaxIoUAssigner',
96 | pos_iou_thr=0.5,
97 | neg_iou_thr=0.5,
98 | min_pos_iou=0.5,
99 | match_low_quality=True,
100 | ignore_iof_thr=-1),
101 | sampler=dict(
102 | type='RandomSampler',
103 | num=512,
104 | pos_fraction=0.25,
105 | neg_pos_ub=-1,
106 | add_gt_as_proposals=True),
107 | mask_size=28,
108 | pos_weight=-1,
109 | debug=False)),
110 | test_cfg=dict(
111 | rpn=dict(
112 | nms_pre=1000,
113 | max_per_img=1000,
114 | nms=dict(type='nms', iou_threshold=0.7),
115 | min_bbox_size=0),
116 | rcnn=dict(
117 | score_thr=0.05,
118 | nms=dict(type='nms', iou_threshold=0.5),
119 | max_per_img=100,
120 | mask_thr_binary=0.5)))
121 |
--------------------------------------------------------------------------------
/detection/configs/_base_/schedules/schedule_1x.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.02, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(
6 | policy='step',
7 | warmup='linear',
8 | warmup_iters=500,
9 | warmup_ratio=0.001,
10 | step=[8, 11])
11 | runner = dict(type='EpochBasedRunner', max_epochs=12)
12 |
--------------------------------------------------------------------------------
/detection/configs/_base_/schedules/schedule_20e.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(
6 | policy='step',
7 | warmup='linear',
8 | warmup_iters=500,
9 | warmup_ratio=0.001,
10 | step=[16, 19])
11 | runner = dict(type='EpochBasedRunner', max_epochs=20)
12 |
--------------------------------------------------------------------------------
/detection/configs/_base_/schedules/schedule_2x.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(
6 | policy='step',
7 | warmup='linear',
8 | warmup_iters=500,
9 | warmup_ratio=0.001,
10 | step=[16, 22])
11 | runner = dict(type='EpochBasedRunner', max_epochs=24)
12 |
--------------------------------------------------------------------------------
/detection/configs/mask_rcnn_lightvit_base_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '_base_/models/mask_rcnn_r50_fpn.py',
3 | '_base_/datasets/coco_instance.py',
4 | '_base_/default_runtime.py'
5 | ]
6 | model = dict(
7 | pretrained=None,
8 | backbone=dict(
9 | _delete_=True,
10 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_base_82.1.ckpt'),
11 | type='lightvit_base',
12 | drop_path_rate=0.1,
13 | out_indices=(0, 1, 2),
14 | stem_norm_eval=True, # fix the BN running stats of the stem layer
15 | ),
16 | neck=dict(
17 | type='LightViTFPN',
18 | in_channels=[128, 256, 512],
19 | out_channels=256,
20 | num_outs=5,
21 | num_extra_trans_convs=1,
22 | ))
23 | # data
24 | data = dict(samples_per_gpu=2) # 2 x 8 = 16
25 | # optimizer
26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04,
27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
28 | 'relative_position_bias_table': dict(decay_mult=0.),
29 | 'norm': dict(decay_mult=0.),
30 | 'stem.1': dict(decay_mult=0.),
31 | 'stem.4': dict(decay_mult=0.),
32 | 'stem.7': dict(decay_mult=0.),
33 | 'stem.10': dict(decay_mult=0.),
34 | 'global_token': dict(decay_mult=0.)
35 | }))
36 | # optimizer_config = dict(grad_clip=None)
37 | # do not use mmdet version fp16
38 | # fp16 = None
39 | optimizer_config = dict(grad_clip=None)
40 | # learning policy
41 | lr_config = dict(
42 | policy='step',
43 | warmup='linear',
44 | warmup_iters=500,
45 | warmup_ratio=0.001,
46 | step=[8, 11])
47 | total_epochs = 12
48 |
49 |
--------------------------------------------------------------------------------
/detection/configs/mask_rcnn_lightvit_base_fpn_3x_ms_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '_base_/models/mask_rcnn_r50_fpn.py',
3 | '_base_/datasets/coco_instance_ms.py',
4 | '_base_/schedules/schedule_1x.py',
5 | '_base_/default_runtime.py'
6 | ]
7 |
8 | model = dict(
9 | pretrained=None,
10 | backbone=dict(
11 | _delete_=True,
12 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_base_82.1.ckpt'),
13 | type='lightvit_base',
14 | drop_path_rate=0.1,
15 | out_indices=(0, 1, 2),
16 | stem_norm_eval=True, # fix the BN running stats of the stem layer
17 | ),
18 | neck=dict(
19 | type='LightViTFPN',
20 | in_channels=[128, 256, 512],
21 | out_channels=256,
22 | num_outs=5,
23 | num_extra_trans_convs=1,
24 | ))
25 |
26 | data = dict(samples_per_gpu=2)
27 |
28 | # optimizer
29 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04,
30 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
31 | 'relative_position_bias_table': dict(decay_mult=0.),
32 | 'norm': dict(decay_mult=0.),
33 | 'stem.1': dict(decay_mult=0.),
34 | 'stem.4': dict(decay_mult=0.),
35 | 'stem.7': dict(decay_mult=0.),
36 | 'stem.10': dict(decay_mult=0.),
37 | 'global_token': dict(decay_mult=0.)
38 | }))
39 |
40 | # optimizer_config = dict(grad_clip=None)
41 | # do not use mmdet version fp16
42 | # fp16 = None
43 | # optimizer
44 | # learning policy
45 | lr_config = dict(step=[27, 33])
46 | runner = dict(type='EpochBasedRunner', max_epochs=36)
47 | total_epochs = 36
48 | # learning policy
49 | lr_config = dict(
50 | policy='step',
51 | warmup='linear',
52 | warmup_iters=500,
53 | warmup_ratio=0.001,
54 | step=[27, 33])
55 | total_epochs = 36
56 |
57 |
--------------------------------------------------------------------------------
/detection/configs/mask_rcnn_lightvit_small_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '_base_/models/mask_rcnn_r50_fpn.py',
3 | '_base_/datasets/coco_instance.py',
4 | '_base_/default_runtime.py'
5 | ]
6 | model = dict(
7 | pretrained=None,
8 | backbone=dict(
9 | _delete_=True,
10 | type='lightvit_small',
11 | drop_path_rate=0.1,
12 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_small_80.9.ckpt'),
13 | out_indices=(0, 1, 2),
14 | stem_norm_eval=True, # fix the BN running stats of the stem layer
15 | ),
16 | neck=dict(
17 | type='LightViTFPN',
18 | in_channels=[96, 192, 384],
19 | out_channels=256,
20 | num_outs=5,
21 | num_extra_trans_convs=1,
22 | ))
23 | # data
24 | data = dict(samples_per_gpu=2) # 2 x 8 = 16
25 | # optimizer
26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04,
27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
28 | 'relative_position_bias_table': dict(decay_mult=0.),
29 | 'norm': dict(decay_mult=0.),
30 | 'stem.1': dict(decay_mult=0.),
31 | 'stem.4': dict(decay_mult=0.),
32 | 'stem.7': dict(decay_mult=0.),
33 | 'stem.10': dict(decay_mult=0.),
34 | 'global_token': dict(decay_mult=0.)
35 | }))
36 | # optimizer_config = dict(grad_clip=None)
37 | # do not use mmdet version fp16
38 | # fp16 = None
39 | optimizer_config = dict(grad_clip=None)
40 | # learning policy
41 | lr_config = dict(
42 | policy='step',
43 | warmup='linear',
44 | warmup_iters=500,
45 | warmup_ratio=0.001,
46 | step=[8, 11])
47 | total_epochs = 12
48 |
49 |
--------------------------------------------------------------------------------
/detection/configs/mask_rcnn_lightvit_small_fpn_3x_ms_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '_base_/models/mask_rcnn_r50_fpn.py',
3 | '_base_/datasets/coco_instance_ms.py',
4 | '_base_/schedules/schedule_1x.py',
5 | '_base_/default_runtime.py'
6 | ]
7 |
8 | model = dict(
9 | pretrained=None,
10 | backbone=dict(
11 | _delete_=True,
12 | type='lightvit_small',
13 | drop_path_rate=0.1,
14 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_small_80.9.ckpt'),
15 | out_indices=(0, 1, 2),
16 | stem_norm_eval=True, # fix the BN running stats of the stem layer
17 | ),
18 | neck=dict(
19 | type='LightViTFPN',
20 | in_channels=[96, 192, 384],
21 | out_channels=256,
22 | num_outs=5,
23 | num_extra_trans_convs=1,
24 | ))
25 |
26 | data = dict(samples_per_gpu=2)
27 |
28 | # optimizer
29 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04,
30 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
31 | 'relative_position_bias_table': dict(decay_mult=0.),
32 | 'norm': dict(decay_mult=0.),
33 | 'stem.1': dict(decay_mult=0.),
34 | 'stem.4': dict(decay_mult=0.),
35 | 'stem.7': dict(decay_mult=0.),
36 | 'stem.10': dict(decay_mult=0.),
37 | 'global_token': dict(decay_mult=0.)
38 | }))
39 | # optimizer_config = dict(grad_clip=None)
40 | # do not use mmdet version fp16
41 | # fp16 = None
42 | # optimizer
43 | # learning policy
44 | lr_config = dict(step=[27, 33])
45 | runner = dict(type='EpochBasedRunner', max_epochs=36)
46 | total_epochs = 36
47 | # learning policy
48 | lr_config = dict(
49 | policy='step',
50 | warmup='linear',
51 | warmup_iters=500,
52 | warmup_ratio=0.001,
53 | step=[27, 33])
54 | total_epochs = 36
55 |
56 |
--------------------------------------------------------------------------------
/detection/configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '_base_/models/mask_rcnn_r50_fpn.py',
3 | '_base_/datasets/coco_instance.py',
4 | '_base_/default_runtime.py'
5 | ]
6 | model = dict(
7 | pretrained=None,
8 | backbone=dict(
9 | _delete_=True,
10 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_tiny_78.7.ckpt'),
11 | type='lightvit_tiny',
12 | drop_path_rate=0.1,
13 | out_indices=(0, 1, 2),
14 | stem_norm_eval=True, # fix the BN running stats of the stem layer
15 | ),
16 | neck=dict(
17 | type='LightViTFPN',
18 | in_channels=[64, 128, 256],
19 | out_channels=256,
20 | num_outs=5,
21 | num_extra_trans_convs=1,
22 | ))
23 | # data
24 | data = dict(samples_per_gpu=2) # 2 x 8 = 16
25 | # optimizer
26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04,
27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
28 | 'relative_position_bias_table': dict(decay_mult=0.),
29 | 'norm': dict(decay_mult=0.),
30 | 'stem.1': dict(decay_mult=0.),
31 | 'stem.4': dict(decay_mult=0.),
32 | 'stem.7': dict(decay_mult=0.),
33 | 'stem.10': dict(decay_mult=0.),
34 | 'global_token': dict(decay_mult=0.)
35 | }))
36 | # optimizer_config = dict(grad_clip=None)
37 | # do not use mmdet version fp16
38 | # fp16 = None
39 | optimizer_config = dict(grad_clip=None)
40 | # learning policy
41 | lr_config = dict(
42 | policy='step',
43 | warmup='linear',
44 | warmup_iters=500,
45 | warmup_ratio=0.001,
46 | step=[8, 11])
47 | total_epochs = 12
48 |
49 |
--------------------------------------------------------------------------------
/detection/configs/mask_rcnn_lightvit_tiny_fpn_3x_ms_coco.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '_base_/models/mask_rcnn_r50_fpn.py',
3 | '_base_/datasets/coco_instance_ms.py',
4 | '_base_/schedules/schedule_1x.py',
5 | '_base_/default_runtime.py'
6 | ]
7 | model = dict(
8 | pretrained=None,
9 | backbone=dict(
10 | _delete_=True,
11 | init_cfg=dict(type='Pretrained', checkpoint='https://github.com/hunto/LightViT/releases/download/v0.0.1/lightvit_tiny_78.7.ckpt'),
12 | type='lightvit_tiny',
13 | drop_path_rate=0.1,
14 | out_indices=(0, 1, 2),
15 | stem_norm_eval=True, # fix the BN running stats of the stem layer
16 | ),
17 | neck=dict(
18 | type='LightViTFPN',
19 | in_channels=[64, 128, 256],
20 | out_channels=256,
21 | num_outs=5,
22 | num_extra_trans_convs=1,
23 | ))
24 |
25 | # optimizer
26 | optimizer = dict(type='AdamW', lr=0.0002, betas=(0.9, 0.999), weight_decay=0.04,
27 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
28 | 'relative_position_bias_table': dict(decay_mult=0.),
29 | 'norm': dict(decay_mult=0.),
30 | 'stem.1': dict(decay_mult=0.),
31 | 'stem.4': dict(decay_mult=0.),
32 | 'stem.7': dict(decay_mult=0.),
33 | 'stem.10': dict(decay_mult=0.),
34 | 'global_token': dict(decay_mult=0.)
35 | }))
36 | # optimizer_config = dict(grad_clip=None)
37 | # do not use mmdet version fp16
38 | # fp16 = None
39 | # optimizer
40 | # learning policy
41 | lr_config = dict(step=[27, 33])
42 | runner = dict(type='EpochBasedRunner', max_epochs=36)
43 | total_epochs = 36
44 | # learning policy
45 | lr_config = dict(
46 | policy='step',
47 | warmup='linear',
48 | warmup_iters=500,
49 | warmup_ratio=0.001,
50 | step=[27, 33])
51 | total_epochs = 36
52 |
53 |
--------------------------------------------------------------------------------
/detection/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 |
8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
11 |
--------------------------------------------------------------------------------
/detection/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 |
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
10 |
--------------------------------------------------------------------------------
/detection/get_flops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 |
4 | import numpy as np
5 | import torch
6 | from mmcv import Config, DictAction
7 |
8 | from mmdet.models import build_detector
9 |
10 | import lightvit
11 | import lightvit_fpn
12 |
13 | try:
14 | from mmcv.cnn import get_model_complexity_info
15 | except ImportError:
16 | raise ImportError('Please upgrade mmcv to >0.6.2')
17 |
18 |
19 | def parse_args():
20 |     parser = argparse.ArgumentParser(description='Get the FLOPs and params of a detector')
21 | parser.add_argument('config', help='train config file path')
22 | parser.add_argument(
23 | '--shape',
24 | type=int,
25 | nargs='+',
26 | default=[1280, 800],
27 | help='input image size')
28 | parser.add_argument(
29 | '--cfg-options',
30 | nargs='+',
31 | action=DictAction,
32 | help='override some settings in the used config, the key-value pair '
33 | 'in xxx=yyy format will be merged into config file. If the value to '
34 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
35 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
36 | 'Note that the quotation marks are necessary and that no white space '
37 | 'is allowed.')
38 | parser.add_argument(
39 | '--size-divisor',
40 | type=int,
41 | default=32,
42 |         help='Pad the input image to the minimum size that is divisible '
43 |         'by size_divisor; -1 means do not pad the image.')
44 | args = parser.parse_args()
45 | return args
46 |
47 |
48 | def main():
49 |
50 | args = parse_args()
51 |
52 | if len(args.shape) == 1:
53 | h = w = args.shape[0]
54 | elif len(args.shape) == 2:
55 | h, w = args.shape
56 | else:
57 | raise ValueError('invalid input shape')
58 | ori_shape = (3, h, w)
59 | divisor = args.size_divisor
60 | if divisor > 0:
61 | h = int(np.ceil(h / divisor)) * divisor
62 | w = int(np.ceil(w / divisor)) * divisor
63 |
64 | input_shape = (3, h, w)
65 |
66 | cfg = Config.fromfile(args.config)
67 | if args.cfg_options is not None:
68 | cfg.merge_from_dict(args.cfg_options)
69 |
70 | model = build_detector(
71 | cfg.model,
72 | train_cfg=cfg.get('train_cfg'),
73 | test_cfg=cfg.get('test_cfg'))
74 | if torch.cuda.is_available():
75 | model.cuda()
76 | model.eval()
77 |
78 | if hasattr(model, 'forward_dummy'):
79 | model.forward = model.forward_dummy
80 | else:
81 | raise NotImplementedError(
82 |             'FLOPs counter is currently not supported with {}'.
83 | format(model.__class__.__name__))
84 |
85 | flops, params = get_model_complexity_info(model, input_shape)
86 | split_line = '=' * 30
87 |
88 | if divisor > 0 and \
89 | input_shape != ori_shape:
90 |         print(f'{split_line}\nUse size divisor to set input shape '
91 | f'from {ori_shape} to {input_shape}\n')
92 | print(f'{split_line}\nInput shape: {input_shape}\n'
93 | f'Flops: {flops}\nParams: {params}\n{split_line}')
94 | print('!!!Please be cautious if you use the results in papers. '
95 | 'You may need to check if all ops are supported and verify that the '
96 | 'flops computation is correct.')
97 |
98 |
99 | if __name__ == '__main__':
100 | main()
101 |
--------------------------------------------------------------------------------
/detection/lightvit.py:
--------------------------------------------------------------------------------
1 | import math
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | from timm.models.layers import DropPath, trunc_normal_, drop_path
10 |
11 | from mmdet.models.builder import BACKBONES
12 | from mmcv.runner import (auto_fp16, force_fp32,)
13 | from mmcv.runner import BaseModule
14 |
15 |
16 | class ConvStem(nn.Module):
17 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_eval=False):
18 | super().__init__()
19 |
20 | self.patch_size = patch_size
21 | self.norm_eval = norm_eval
22 |
23 | stem_dim = embed_dim // 2
24 | self.stem = nn.Sequential(
25 | nn.Conv2d(in_chans, stem_dim, kernel_size=3,
26 | stride=2, padding=1, bias=False),
27 | nn.BatchNorm2d(stem_dim),
28 | nn.GELU(),
29 | nn.Conv2d(stem_dim, stem_dim, kernel_size=3,
30 | groups=stem_dim, stride=1, padding=1, bias=False),
31 | nn.BatchNorm2d(stem_dim),
32 | nn.GELU(),
33 | nn.Conv2d(stem_dim, stem_dim, kernel_size=3,
34 | groups=stem_dim, stride=1, padding=1, bias=False),
35 | nn.BatchNorm2d(stem_dim),
36 | nn.GELU(),
37 | nn.Conv2d(stem_dim, stem_dim, kernel_size=3,
38 | groups=stem_dim, stride=2, padding=1, bias=False),
39 | nn.BatchNorm2d(stem_dim),
40 | nn.GELU(),
41 | )
42 | self.proj = nn.Conv2d(stem_dim, embed_dim,
43 | kernel_size=3,
44 | stride=2, padding=1)
45 | self.norm = nn.LayerNorm(embed_dim)
46 |
47 | def forward(self, x):
48 | stem = self.stem(x)
49 | x = self.proj(stem)
50 | _, _, H, W = x.shape
51 | x = x.flatten(2).transpose(1, 2)
52 | x = self.norm(x)
53 | return x, (H, W)
54 |
55 | def train(self, mode=True):
56 |         """Convert the model into training mode while keeping the
57 |         normalization layers frozen."""
58 | super().train(mode)
59 | if mode and self.norm_eval:
60 | for m in self.modules():
61 | if isinstance(m, _BatchNorm):
62 | m.eval()
63 |
64 |
65 | class BiAttn(nn.Module):
66 | def __init__(self, in_channels, act_ratio=0.25, act_fn=nn.GELU, gate_fn=nn.Sigmoid):
67 | super().__init__()
68 | reduce_channels = int(in_channels * act_ratio)
69 | self.norm = nn.LayerNorm(in_channels)
70 | self.global_reduce = nn.Linear(in_channels, reduce_channels)
71 | self.local_reduce = nn.Linear(in_channels, reduce_channels)
72 | self.act_fn = act_fn()
73 | self.channel_select = nn.Linear(reduce_channels, in_channels)
74 | self.spatial_select = nn.Linear(reduce_channels * 2, 1)
75 | self.gate_fn = gate_fn()
76 |
77 | def forward(self, x):
78 | ori_x = x
79 | x = self.norm(x)
80 | x_global = x.mean(1, keepdim=True)
81 | x_global = self.act_fn(self.global_reduce(x_global))
82 | x_local = self.act_fn(self.local_reduce(x))
83 |
84 | c_attn = self.channel_select(x_global)
85 | c_attn = self.gate_fn(c_attn) # [B, 1, C]
86 | s_attn = self.spatial_select(torch.cat([x_local, x_global.expand(-1, x.shape[1], -1)], dim=-1))
87 | s_attn = self.gate_fn(s_attn) # [B, N, 1]
88 |
89 | attn = c_attn * s_attn # [B, N, C]
90 | return ori_x * attn
91 |
92 |
93 | class BiAttnMlp(nn.Module):
94 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
95 | super().__init__()
96 | out_features = out_features or in_features
97 | hidden_features = hidden_features or in_features
98 | self.fc1 = nn.Linear(in_features, hidden_features)
99 | self.act = act_layer()
100 | self.fc2 = nn.Linear(hidden_features, out_features)
101 | self.attn = BiAttn(out_features)
102 | self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity()
103 |
104 | def forward(self, x):
105 | x = self.fc1(x)
106 | x = self.act(x)
107 | x = self.drop(x)
108 | x = self.fc2(x)
109 | x = self.attn(x)
110 | x = self.drop(x)
111 | return x
112 |
113 |
114 | def window_reverse(
115 | windows: torch.Tensor,
116 | original_size,
117 | window_size=(7, 7)
118 | ) -> torch.Tensor:
119 | """ Reverses the window partition.
120 | Args:
121 | windows (torch.Tensor): Window tensor of the shape [B * windows, window_size[0] * window_size[1], C].
122 | original_size (Tuple[int, int]): Original shape.
123 |             window_size (Tuple[int, int], optional): Window size that has been applied. Default (7, 7)
124 | Returns:
125 | output (torch.Tensor): Folded output tensor of the shape [B, original_size[0] * original_size[1], C].
126 | """
127 | # Get height and width
128 | H, W = original_size
129 | # Compute original batch size
130 | B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
131 | # Fold grid tensor
132 | output = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
133 | output = output.permute(0, 1, 3, 2, 4, 5).reshape(B, H * W, -1)
134 | return output
135 |
136 |
137 | def get_relative_position_index(
138 | win_h: int,
139 | win_w: int
140 | ) -> torch.Tensor:
141 | """ Function to generate pair-wise relative position index for each token inside the window.
142 |     Taken from timm's Swin V1 implementation.
143 | Args:
144 | win_h (int): Window/Grid height.
145 | win_w (int): Window/Grid width.
146 | Returns:
147 | relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width].
148 | """
149 | coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))
150 | coords_flatten = torch.flatten(coords, 1)
151 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
152 | relative_coords = relative_coords.permute(1, 2, 0).contiguous()
153 | relative_coords[:, :, 0] += win_h - 1
154 | relative_coords[:, :, 1] += win_w - 1
155 | relative_coords[:, :, 0] *= 2 * win_w - 1
156 | return relative_coords.sum(-1)
157 |
158 |
159 | class Attention(nn.Module):
160 | def __init__(self, dim, num_tokens=1, num_heads=8, window_size=7, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
161 | super().__init__()
162 | self.num_heads = num_heads
163 | head_dim = dim // num_heads
164 | self.num_tokens = num_tokens
165 | self.window_size = window_size
166 | self.attn_area = window_size * window_size
167 | self.scale = qk_scale or head_dim ** -0.5
168 |
169 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
170 | self.kv_global = nn.Linear(dim, dim * 2, bias=qkv_bias)
171 | self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0 else nn.Identity()
172 | self.proj = nn.Linear(dim, dim)
173 | self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0 else nn.Identity()
174 |
175 | # positional embedding
176 | # Define a parameter table of relative position bias, shape: 2*Wh-1 * 2*Ww-1, nH
177 | self.relative_position_bias_table = nn.Parameter(
178 | torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))
179 |
180 | # Get pair-wise relative position index for each token inside the window
181 | self.register_buffer("relative_position_index", get_relative_position_index(window_size,
182 | window_size).view(-1))
183 | # Init relative positional bias
184 | trunc_normal_(self.relative_position_bias_table, std=.02)
185 |
186 | def _get_relative_positional_bias(
187 | self
188 | ) -> torch.Tensor:
189 | """ Returns the relative positional bias.
190 | Returns:
191 | relative_position_bias (torch.Tensor): Relative positional bias.
192 | """
193 | relative_position_bias = self.relative_position_bias_table[
194 | self.relative_position_index].view(self.attn_area, self.attn_area, -1)
195 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
196 | return relative_position_bias.unsqueeze(0)
197 |
198 | def forward_global_aggregation(self, q, k, v):
199 | """
200 | q: global tokens
201 | k: image tokens
202 | v: image tokens
203 | """
204 | B, _, N, _ = q.shape
205 | attn = (q @ k.transpose(-2, -1)) * self.scale
206 | attn = attn.softmax(dim=-1)
207 | attn = self.attn_drop(attn)
208 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
209 | return x
210 |
211 | def forward_local(self, q, k, v, H, W):
212 | """
213 | q: image tokens
214 | k: image tokens
215 | v: image tokens
216 | """
217 | B, num_heads, N, C = q.shape
218 | ws = self.window_size
219 | h_group, w_group = H // ws, W // ws
220 |
221 | # partition to windows
222 | q = q.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous()
223 | q = q.view(-1, num_heads, ws*ws, C)
224 | k = k.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous()
225 | k = k.view(-1, num_heads, ws*ws, C)
226 | v = v.view(B, num_heads, h_group, ws, w_group, ws, -1).permute(0, 2, 4, 1, 3, 5, 6).contiguous()
227 | v = v.view(-1, num_heads, ws*ws, v.shape[-1])
228 |
229 | attn = (q @ k.transpose(-2, -1)) * self.scale
230 | pos_bias = self._get_relative_positional_bias()
231 | attn = (attn + pos_bias).softmax(dim=-1)
232 | attn = self.attn_drop(attn)
233 | x = (attn @ v).transpose(1, 2).reshape(v.shape[0], ws*ws, -1)
234 |
235 | # reverse
236 | x = window_reverse(x, (H, W), (ws, ws))
237 | return x
238 |
239 | def forward_global_broadcast(self, q, k, v):
240 | """
241 | q: image tokens
242 | k: global tokens
243 | v: global tokens
244 | """
245 | B, num_heads, N, _ = q.shape
246 | attn = (q @ k.transpose(-2, -1)) * self.scale
247 | attn = attn.softmax(dim=-1)
248 | attn = self.attn_drop(attn)
249 | x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
250 | return x
251 |
252 | def forward(self, x, H, W):
253 | B, N, C = x.shape
254 | NC = self.num_tokens
255 |         # split image / global tokens, pad image tokens to a multiple of the window size
256 | x_img, x_global = x[:, NC:], x[:, :NC]
257 | x_img = x_img.view(B, H, W, C)
258 | pad_l = pad_t = 0
259 | ws = self.window_size
260 | pad_r = (ws - W % ws) % ws
261 | pad_b = (ws - H % ws) % ws
262 | x_img = F.pad(x_img, (0, 0, pad_l, pad_r, pad_t, pad_b))
263 | Hp, Wp = x_img.shape[1], x_img.shape[2]
264 | x_img = x_img.view(B, -1, C)
265 | x = torch.cat([x_global, x_img], dim=1)
266 |
267 | # qkv
268 | qkv = self.qkv(x)
269 | q, k, v = qkv.view(B, -1, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0)
270 |
271 | # split img tokens & global tokens
272 | q_img, k_img, v_img = q[:, :, NC:], k[:, :, NC:], v[:, :, NC:]
273 |         q_cls = q[:, :, :NC]  # only global-token queries are needed; their k/v come from kv_global below
274 |
275 | # local window attention
276 | x_img = self.forward_local(q_img, k_img, v_img, Hp, Wp)
277 | # restore to the original size
278 | x_img = x_img.view(B, Hp, Wp, -1)[:, :H, :W].reshape(B, H*W, -1)
279 | q_img = q_img.reshape(B, self.num_heads, Hp, Wp, -1)[:, :, :H, :W].reshape(B, self.num_heads, H*W, -1)
280 | k_img = k_img.reshape(B, self.num_heads, Hp, Wp, -1)[:, :, :H, :W].reshape(B, self.num_heads, H*W, -1)
281 | v_img = v_img.reshape(B, self.num_heads, Hp, Wp, -1)[:, :, :H, :W].reshape(B, self.num_heads, H*W, -1)
282 |
283 | # global aggregation
284 | x_cls = self.forward_global_aggregation(q_cls, k_img, v_img)
285 | k_cls, v_cls = self.kv_global(x_cls).view(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).unbind(0)
286 |
287 |         # global broadcast
288 | x_img = x_img + self.forward_global_broadcast(q_img, k_cls, v_cls)
289 |
290 | x = torch.cat([x_cls, x_img], dim=1)
291 |         x = self.proj_drop(self.proj(x))
292 | return x
293 |
294 |
295 | class Block(nn.Module):
296 |
297 | def __init__(self, dim, num_heads, num_tokens=1, window_size=7, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
298 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attention=Attention, last_block=False):
299 | super().__init__()
300 | self.last_block = last_block
301 | self.norm1 = norm_layer(dim)
302 | self.attn = attention(dim, num_heads=num_heads, num_tokens=num_tokens, window_size=window_size, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
303 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
304 | self.norm2 = norm_layer(dim)
305 | mlp_hidden_dim = int(dim * mlp_ratio)
306 | self.mlp = BiAttnMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
307 |
308 | def forward(self, x, H, W):
309 | x = x + self.drop_path(self.attn(self.norm1(x), H, W))
310 | if self.last_block:
311 | # ignore unused global tokens in downstream tasks
312 | x = x[:, -H*W:]
313 | x = x + self.drop_path(self.mlp(self.norm2(x)))
314 | return x
315 |
316 |
317 | class ResidualMergePatch(nn.Module):
318 | def __init__(self, dim, out_dim, num_tokens=1):
319 | super().__init__()
320 | self.num_tokens = num_tokens
321 | self.norm = nn.LayerNorm(4 * dim)
322 | self.reduction = nn.Linear(4 * dim, out_dim, bias=False)
323 | self.norm2 = nn.LayerNorm(dim)
324 | self.proj = nn.Linear(dim, out_dim, bias=False)
325 | # use MaxPool3d to avoid permutations
326 | self.maxp = nn.MaxPool3d((2, 2, 1), (2, 2, 1))
327 | self.res_proj = nn.Linear(dim, out_dim, bias=False)
328 |
329 | def forward(self, x, H, W):
330 | global_token, x = x[:, :self.num_tokens].contiguous(), x[:, self.num_tokens:].contiguous()
331 | B, L, C = x.shape
332 |
333 | x = x.view(B, H, W, C)
334 | # pad
335 | pad_l = pad_t = 0
336 | pad_r = (2 - W % 2) % 2
337 | pad_b = (2 - H % 2) % 2
338 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
339 |
340 | res = self.res_proj(self.maxp(x).view(B, -1, C))
341 |
342 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
343 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
344 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
345 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
346 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
347 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
348 |
349 | x = self.norm(x)
350 | x = self.reduction(x)
351 | x = x + res
352 | global_token = self.proj(self.norm2(global_token))
353 | x = torch.cat([global_token, x], 1)
354 | return x, (math.ceil(H / 2), math.ceil(W / 2))
355 |
356 |
357 | class LightViT(BaseModule):
358 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[32, 64, 160, 256], num_layers=[2, 2, 2, 2],
359 | num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], num_tokens=8, window_size=7, neck_dim=1280, qkv_bias=True,
360 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=ConvStem, norm_layer=None,
361 | act_layer=None, weight_init='', out_indices=(0, 1, 2, 3), stem_norm_eval=False, init_cfg=None):
362 | super().__init__(init_cfg)
363 | self.num_classes = num_classes
364 | self.embed_dims = embed_dims
365 | self.num_tokens = num_tokens
366 | self.mlp_ratios = mlp_ratios
367 | self.patch_size = patch_size
368 | self.num_layers = num_layers
369 | self.out_indices = out_indices
370 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
371 | act_layer = act_layer or nn.GELU
372 |
373 | self.patch_embed = embed_layer(
374 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], norm_eval=stem_norm_eval)
375 |
376 | self.global_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dims[0]))
377 |
378 | stages = []
379 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(num_layers))] # stochastic depth decay rule
380 | for stage, (embed_dim, num_layer, num_head, mlp_ratio) in enumerate(zip(embed_dims, num_layers, num_heads, mlp_ratios)):
381 | blocks = []
382 | if stage > 0:
383 | # downsample
384 | blocks.append(ResidualMergePatch(embed_dims[stage-1], embed_dim, num_tokens=num_tokens))
385 | blocks += [
386 | Block(
387 | dim=embed_dim, num_heads=num_head, num_tokens=num_tokens, window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
388 | attn_drop=attn_drop_rate, drop_path=dpr[sum(num_layers[:stage]) + i], norm_layer=norm_layer, act_layer=act_layer, attention=Attention,
389 | last_block=(stage==len(embed_dims)-1 and i == num_layer - 1))
390 | for i in range(num_layer)]
391 | blocks = nn.Sequential(*blocks)
392 | stages.append(blocks)
393 | self.stages = nn.Sequential(*stages)
394 |
395 | # add a norm layer for each output
396 | for i_layer in out_indices:
397 | layer = norm_layer(embed_dims[i_layer])
398 | layer_name = f'norm{i_layer}'
399 | self.add_module(layer_name, layer)
400 |
401 | def forward_features(self, x):
402 | outputs = []
403 | x, (H, W) = self.patch_embed(x)
404 | x_out = x.view(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
405 | global_token = self.global_token.expand(x.shape[0], -1, -1)
406 | x = torch.cat((global_token, x), dim=1)
407 |
408 | for i_stage, stage in enumerate(self.stages):
409 | for block in stage:
410 | if isinstance(block, ResidualMergePatch):
411 | x, (H, W) = block(x, H, W)
412 | elif isinstance(block, Block):
413 | x = block(x, H, W)
414 | else:
415 | x = block(x)
416 |
417 | if i_stage in self.out_indices:
418 | norm_layer = getattr(self, f'norm{i_stage}')
419 | x_out = norm_layer(x[:, -H*W:])
420 | x_out = x_out.view(-1, H, W, self.embed_dims[i_stage]).permute(0, 3, 1, 2).contiguous()
421 | outputs.append(x_out)
422 | return tuple(outputs)
423 |
424 | @auto_fp16()
425 | def forward(self, x):
426 | x = self.forward_features(x)
427 | return x
428 |
429 |
430 | @BACKBONES.register_module()
431 | class lightvit_tiny(LightViT):
432 |
433 | def __init__(self, **kwargs):
434 | model_kwargs = dict(patch_size=8, embed_dims=[64, 128, 256], num_layers=[2, 6, 6],
435 | num_heads=[2, 4, 8, ], mlp_ratios=[8, 4, 4], num_tokens=512, drop_path_rate=0.1)
436 | model_kwargs.update(kwargs)
437 | super().__init__(**model_kwargs)
438 |
439 |
440 | @BACKBONES.register_module()
441 | class lightvit_small(LightViT):
442 |
443 | def __init__(self, **kwargs):
444 | model_kwargs = dict(patch_size=8, embed_dims=[96, 192, 384], num_layers=[2, 6, 6],
445 | num_heads=[3, 6, 12, ], mlp_ratios=[8, 4, 4], num_tokens=16, drop_path_rate=0.1)
446 | model_kwargs.update(kwargs)
447 | super().__init__(**model_kwargs)
448 |
449 |
450 | @BACKBONES.register_module()
451 | class lightvit_base(LightViT):
452 |
453 | def __init__(self, **kwargs):
454 | model_kwargs = dict(patch_size=8, embed_dims=[128, 256, 512], num_layers=[3, 8, 6],
455 | num_heads=[4, 8, 16, ], mlp_ratios=[8, 4, 4], num_tokens=24, drop_path_rate=0.1)
456 | model_kwargs.update(kwargs)
457 | super().__init__(**model_kwargs)
458 |
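459 | # Usage sketch (illustrative; not executed on import): the registered backbones can
460 | # be built directly. `out_indices` must index the stages actually defined (three for
461 | # the variants above), and the output is a tuple of feature maps, one per index.
462 | #
463 | #   import torch
464 | #   model = lightvit_tiny(out_indices=(0, 1, 2)).eval()
465 | #   with torch.no_grad():
466 | #       feats = model(torch.randn(1, 3, 224, 224))
467 | #   print([f.shape for f in feats])  # strides 8 / 16 / 32 w.r.t. the input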
--------------------------------------------------------------------------------
/detection/lightvit_fpn.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from mmcv.cnn import ConvModule, xavier_init
6 | from mmcv.runner import auto_fp16
7 |
8 | from mmdet.models.builder import NECKS
9 |
10 |
11 | @NECKS.register_module()
12 | class LightViTFPN(nn.Module):
13 | r"""Feature Pyramid Network for LightViT.
14 |
15 | Args:
16 | in_channels (List[int]): Number of input channels per scale.
17 | out_channels (int): Number of output channels (used at each scale)
18 | num_outs (int): Number of output scales.
19 | start_level (int): Index of the start input backbone level used to
20 | build the feature pyramid. Default: 0.
21 | end_level (int): Index of the end input backbone level (exclusive) to
22 | build the feature pyramid. Default: -1, which means the last level.
23 | add_extra_convs (bool | str): If bool, it decides whether to add conv
24 | layers on top of the original feature maps. Default to False.
25 | If True, its actual mode is specified by `extra_convs_on_inputs`.
26 | If str, it specifies the source feature map of the extra convs.
27 | Only the following options are allowed
28 |
29 | - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
30 | - 'on_lateral': Last feature map after lateral convs.
31 | - 'on_output': The last output feature map after fpn convs.
32 | extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
33 | on the original feature from the backbone. If True,
34 | it is equivalent to `add_extra_convs='on_input'`. If False, it is
35 | equivalent to set `add_extra_convs='on_output'`. Default to True.
36 | relu_before_extra_convs (bool): Whether to apply relu before the extra
37 | conv. Default: False.
38 | no_norm_on_lateral (bool): Whether to apply norm on lateral.
39 | Default: False.
40 |         num_extra_trans_convs (int): Number of extra transposed convs applied to the
41 |             highest-resolution output to produce additional, larger-scale levels. Default: 0.
42 | conv_cfg (dict): Config dict for convolution layer. Default: None.
43 | norm_cfg (dict): Config dict for normalization layer. Default: None.
44 | act_cfg (str): Config dict for activation layer in ConvModule.
45 | Default: None.
46 | upsample_cfg (dict): Config dict for interpolate layer.
47 | Default: `dict(mode='nearest')`
48 |
49 | Example:
50 | >>> import torch
51 | >>> in_channels = [2, 3, 5, 7]
52 | >>> scales = [340, 170, 84, 43]
53 | >>> inputs = [torch.rand(1, c, s, s)
54 | ... for c, s in zip(in_channels, scales)]
55 |         >>> self = LightViTFPN(in_channels, 11, len(in_channels)).eval()
56 | >>> outputs = self.forward(inputs)
57 | >>> for i in range(len(outputs)):
58 | ... print(f'outputs[{i}].shape = {outputs[i].shape}')
59 | outputs[0].shape = torch.Size([1, 11, 340, 340])
60 | outputs[1].shape = torch.Size([1, 11, 170, 170])
61 | outputs[2].shape = torch.Size([1, 11, 84, 84])
62 | outputs[3].shape = torch.Size([1, 11, 43, 43])
63 | """
64 |
65 | def __init__(self,
66 | in_channels,
67 | out_channels,
68 | num_outs,
69 | start_level=0,
70 | end_level=-1,
71 | add_extra_convs=False,
72 | extra_convs_on_inputs=True,
73 | relu_before_extra_convs=False,
74 | no_norm_on_lateral=False,
75 | num_extra_trans_convs=0,
76 | conv_cfg=None,
77 | norm_cfg=None,
78 | act_cfg=None,
79 | upsample_cfg=dict(mode='nearest')):
80 | super(LightViTFPN, self).__init__()
81 | assert isinstance(in_channels, list)
82 | self.in_channels = in_channels
83 | self.out_channels = out_channels
84 | self.num_ins = len(in_channels)
85 | self.num_outs = num_outs
86 | self.relu_before_extra_convs = relu_before_extra_convs
87 | self.no_norm_on_lateral = no_norm_on_lateral
88 | self.num_extra_trans_convs = num_extra_trans_convs
89 | self.fp16_enabled = False
90 | self.upsample_cfg = upsample_cfg.copy()
91 |
92 | if end_level == -1:
93 | self.backbone_end_level = self.num_ins
94 | assert num_outs >= self.num_ins - start_level
95 | else:
96 | # if end_level < inputs, no extra level is allowed
97 | self.backbone_end_level = end_level
98 | assert end_level <= len(in_channels)
99 | assert num_outs == end_level - start_level
100 | self.start_level = start_level
101 | self.end_level = end_level
102 | self.add_extra_convs = add_extra_convs
103 | assert isinstance(add_extra_convs, (str, bool))
104 | if isinstance(add_extra_convs, str):
105 | # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
106 | assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
107 | elif add_extra_convs: # True
108 | if extra_convs_on_inputs:
109 | # TODO: deprecate `extra_convs_on_inputs`
110 | warnings.simplefilter('once')
111 | warnings.warn(
112 | '"extra_convs_on_inputs" will be deprecated in v2.9.0,'
113 | 'Please use "add_extra_convs"', DeprecationWarning)
114 | self.add_extra_convs = 'on_input'
115 | else:
116 | self.add_extra_convs = 'on_output'
117 |
118 | self.lateral_convs = nn.ModuleList()
119 | self.fpn_convs = nn.ModuleList()
120 |
121 | for i in range(self.start_level, self.backbone_end_level):
122 | l_conv = ConvModule(
123 | in_channels[i],
124 | out_channels,
125 | 1,
126 | conv_cfg=conv_cfg,
127 | norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
128 | act_cfg=act_cfg,
129 | inplace=False)
130 | fpn_conv = ConvModule(
131 | out_channels,
132 | out_channels,
133 | 3,
134 | padding=1,
135 | conv_cfg=conv_cfg,
136 | norm_cfg=norm_cfg,
137 | act_cfg=act_cfg,
138 | inplace=False)
139 |
140 | self.lateral_convs.append(l_conv)
141 | self.fpn_convs.append(fpn_conv)
142 |
143 | # add extra conv layers (e.g., RetinaNet)
144 | extra_levels = num_outs - self.backbone_end_level + self.start_level
145 | assert extra_levels >= num_extra_trans_convs
146 | extra_levels -= num_extra_trans_convs
147 | if self.add_extra_convs and extra_levels >= 1:
148 | for i in range(extra_levels):
149 | if i == 0 and self.add_extra_convs == 'on_input':
150 | in_channels = self.in_channels[self.backbone_end_level - 1]
151 | else:
152 | in_channels = out_channels
153 | extra_fpn_conv = ConvModule(
154 | in_channels,
155 | out_channels,
156 | 3,
157 | stride=2,
158 | padding=1,
159 | conv_cfg=conv_cfg,
160 | norm_cfg=norm_cfg,
161 | act_cfg=act_cfg,
162 | inplace=False)
163 | self.fpn_convs.append(extra_fpn_conv)
164 |
165 | # add extra transposed convs
166 | self.extra_trans_convs = nn.ModuleList()
167 | self.extra_fpn_convs = nn.ModuleList()
168 | for i in range(num_extra_trans_convs):
169 | extra_trans_conv = TransposedConvModule(
170 | out_channels,
171 | out_channels,
172 | 2,
173 | stride=2,
174 | padding=0,
175 | conv_cfg=conv_cfg,
176 | norm_cfg=norm_cfg if not no_norm_on_lateral else None,
177 | act_cfg=act_cfg,
178 | inplace=False)
179 | self.extra_trans_convs.append(extra_trans_conv)
180 | extra_fpn_conv = ConvModule(
181 | out_channels,
182 | out_channels,
183 | 3,
184 | padding=1,
185 | conv_cfg=conv_cfg,
186 | norm_cfg=norm_cfg,
187 | act_cfg=act_cfg,
188 | inplace=False)
189 | self.extra_fpn_convs.append(extra_fpn_conv)
190 |
191 | # default init_weights for conv(msra) and norm in ConvModule
192 | def init_weights(self):
193 | """Initialize the weights of FPN module."""
194 | for m in self.modules():
195 | if isinstance(m, nn.Conv2d):
196 | xavier_init(m, distribution='uniform')
197 |
198 | @auto_fp16()
199 | def forward(self, inputs):
200 | """Forward function."""
201 | assert len(inputs) == len(self.in_channels)
202 |
203 | # build laterals
204 | laterals = [
205 | lateral_conv(inputs[i + self.start_level])
206 | for i, lateral_conv in enumerate(self.lateral_convs)
207 | ]
208 |
209 | # build top-down path
210 | used_backbone_levels = len(laterals)
211 | for i in range(used_backbone_levels - 1, 0, -1):
212 | # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
213 | # it cannot co-exist with `size` in `F.interpolate`.
214 | if 'scale_factor' in self.upsample_cfg:
215 | laterals[i - 1] += F.interpolate(laterals[i],
216 | **self.upsample_cfg)
217 | else:
218 | prev_shape = laterals[i - 1].shape[2:]
219 | laterals[i - 1] += F.interpolate(
220 | laterals[i], size=prev_shape, **self.upsample_cfg)
221 |
222 | # extra transposed convs for outputs with extra scales
223 | extra_laterals = []
224 | if self.num_extra_trans_convs > 0:
225 | prev_lateral = laterals[0]
226 | for i in range(self.num_extra_trans_convs):
227 | extra_lateral = self.extra_trans_convs[i](prev_lateral)
228 | extra_laterals.insert(0, extra_lateral)
229 | prev_lateral = extra_lateral
230 |
231 | # build outputs
232 | # part 1: from original levels
233 | outs = [
234 | self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
235 | ]
236 | # part 2: add extra levels
237 | if self.num_outs > len(outs) + len(extra_laterals):
238 | # use max pool to get more levels on top of outputs
239 | # (e.g., Faster R-CNN, Mask R-CNN)
240 | if not self.add_extra_convs:
241 | for i in range(self.num_outs - len(extra_laterals) - used_backbone_levels):
242 | outs.append(F.max_pool2d(outs[-1], 1, stride=2))
243 | # add conv layers on top of original feature maps (RetinaNet)
244 | else:
245 | if self.add_extra_convs == 'on_input':
246 | extra_source = inputs[self.backbone_end_level - 1]
247 | elif self.add_extra_convs == 'on_lateral':
248 | extra_source = laterals[-1]
249 | elif self.add_extra_convs == 'on_output':
250 | extra_source = outs[-1]
251 | else:
252 | raise NotImplementedError
253 | outs.append(self.fpn_convs[used_backbone_levels](extra_source))
254 | for i in range(used_backbone_levels + 1, self.num_outs - len(extra_laterals)):
255 | if self.relu_before_extra_convs:
256 | outs.append(self.fpn_convs[i](F.relu(outs[-1])))
257 | else:
258 | outs.append(self.fpn_convs[i](outs[-1]))
259 |
260 | # part 3: add extra transposed convs
261 | if self.num_extra_trans_convs > 0:
262 | extra_outs = [
263 | self.extra_fpn_convs[i](extra_laterals[i])
264 | for i in range(self.num_extra_trans_convs)
265 | ]
266 | assert (len(extra_outs) + len(outs)) == self.num_outs, f"{len(extra_outs)} + {len(outs)} != {self.num_outs}"
267 | return tuple(extra_outs + outs)
268 |
269 |
270 | class TransposedConvModule(ConvModule):
271 | def __init__(self, in_channels, out_channels, kernel_size, stride=1,
272 | padding=0, dilation=1, groups=1, bias='auto', conv_cfg=None,
273 |                  norm_cfg=None, act_cfg=dict(type='ReLU'), inplace=True,
274 | **kwargs):
275 | super(TransposedConvModule, self).__init__(in_channels, out_channels, kernel_size,
276 | stride, padding, dilation, groups, bias, conv_cfg,
277 | norm_cfg, act_cfg, inplace, **kwargs)
278 |
279 | self.conv = nn.ConvTranspose2d(
280 | in_channels,
281 | out_channels,
282 | kernel_size,
283 | stride=stride,
284 | padding=padding,
285 | dilation=dilation,
286 | groups=groups,
287 | bias=self.with_bias
288 | )
289 |
290 | # Use msra init by default
291 | self.init_weights()
292 |
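293 | # Config sketch (illustrative): how this neck is typically declared in an mmdet model
294 | # config. The channel list below matches lightvit_tiny and is only an example;
295 | # `num_extra_trans_convs=1` adds one upsampled (higher-resolution) pyramid level.
296 | #
297 | #   neck=dict(
298 | #       type='LightViTFPN',
299 | #       in_channels=[64, 128, 256],
300 | #       out_channels=256,
301 | #       num_outs=5,
302 | #       num_extra_trans_convs=1)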
--------------------------------------------------------------------------------
/detection/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CHECKPOINT=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
25 |
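26 | # Example invocation (assumes a Slurm cluster and an existing checkpoint):
27 | #   GPUS=8 ./slurm_test.sh <partition> <job_name> \
28 | #     configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py <checkpoint.pth> --eval bbox segm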
--------------------------------------------------------------------------------
/detection/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | WORK_DIR=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | SRUN_ARGS=${SRUN_ARGS:-""}
13 | PY_ARGS=${@:5}
14 |
15 | PYTHONPATH="$(dirname $0)":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
25 |
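26 | # Example invocation (assumes a Slurm cluster; <work_dir> is where logs and checkpoints go):
27 | #   GPUS=8 ./slurm_train.sh <partition> <job_name> \
28 | #     configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py <work_dir>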
--------------------------------------------------------------------------------
/detection/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import time
5 | import warnings
6 |
7 | import mmcv
8 | import torch
9 | from mmcv import Config, DictAction
10 | from mmcv.cnn import fuse_conv_bn
11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
12 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
13 | wrap_fp16_model)
14 |
15 | from mmdet.apis import multi_gpu_test, single_gpu_test
16 | from mmdet.datasets import (build_dataloader, build_dataset,
17 | replace_ImageToTensor)
18 | from mmdet.models import build_detector
19 |
20 | import lightvit
21 | import lightvit_fpn
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(
26 | description='MMDet test (and eval) a model')
27 | parser.add_argument('config', help='test config file path')
28 | parser.add_argument('checkpoint', help='checkpoint file')
29 | parser.add_argument(
30 | '--work-dir',
31 | help='the directory to save the file containing evaluation metrics')
32 | parser.add_argument('--out', help='output result file in pickle format')
33 | parser.add_argument(
34 | '--fuse-conv-bn',
35 | action='store_true',
36 |         help='Whether to fuse conv and bn, this will slightly increase '
37 | 'the inference speed')
38 | parser.add_argument(
39 | '--format-only',
40 | action='store_true',
41 |         help='Format the output results without performing evaluation. It is '
42 | 'useful when you want to format the result to a specific format and '
43 | 'submit it to the test server')
44 | parser.add_argument(
45 | '--eval',
46 | type=str,
47 | nargs='+',
48 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
49 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
50 | parser.add_argument('--show', action='store_true', help='show results')
51 | parser.add_argument(
52 | '--show-dir', help='directory where painted images will be saved')
53 | parser.add_argument(
54 | '--show-score-thr',
55 | type=float,
56 | default=0.3,
57 | help='score threshold (default: 0.3)')
58 | parser.add_argument(
59 | '--gpu-collect',
60 | action='store_true',
61 | help='whether to use gpu to collect results.')
62 | parser.add_argument(
63 | '--tmpdir',
64 | help='tmp directory used for collecting results from multiple '
65 | 'workers, available when gpu-collect is not specified')
66 | parser.add_argument(
67 | '--cfg-options',
68 | nargs='+',
69 | action=DictAction,
70 | help='override some settings in the used config, the key-value pair '
71 | 'in xxx=yyy format will be merged into config file. If the value to '
72 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
73 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
74 | 'Note that the quotation marks are necessary and that no white space '
75 | 'is allowed.')
76 | parser.add_argument(
77 | '--options',
78 | nargs='+',
79 | action=DictAction,
80 | help='custom options for evaluation, the key-value pair in xxx=yyy '
81 |         'format will be kwargs for dataset.evaluate() function (deprecated); '
82 | 'change to --eval-options instead.')
83 | parser.add_argument(
84 | '--eval-options',
85 | nargs='+',
86 | action=DictAction,
87 | help='custom options for evaluation, the key-value pair in xxx=yyy '
88 | 'format will be kwargs for dataset.evaluate() function')
89 | parser.add_argument(
90 | '--launcher',
91 | choices=['none', 'pytorch', 'slurm', 'mpi'],
92 | default='none',
93 | help='job launcher')
94 | parser.add_argument('--local_rank', type=int, default=0)
95 | args = parser.parse_args()
96 | if 'LOCAL_RANK' not in os.environ:
97 | os.environ['LOCAL_RANK'] = str(args.local_rank)
98 |
99 | if args.options and args.eval_options:
100 | raise ValueError(
101 | '--options and --eval-options cannot be both '
102 | 'specified, --options is deprecated in favor of --eval-options')
103 | if args.options:
104 | warnings.warn('--options is deprecated in favor of --eval-options')
105 | args.eval_options = args.options
106 | return args
107 |
108 |
109 | def main():
110 | args = parse_args()
111 |
112 | assert args.out or args.eval or args.format_only or args.show \
113 | or args.show_dir, \
114 |         ('Please specify at least one operation (save, eval, format, or show '
115 |          'the results) with the argument "--out", "--eval", '
116 |          '"--format-only", "--show" or "--show-dir"')
117 |
118 | if args.eval and args.format_only:
119 | raise ValueError('--eval and --format_only cannot be both specified')
120 |
121 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
122 | raise ValueError('The output file must be a pkl file.')
123 |
124 | cfg = Config.fromfile(args.config)
125 | if args.cfg_options is not None:
126 | cfg.merge_from_dict(args.cfg_options)
127 | # import modules from string list.
128 | if cfg.get('custom_imports', None):
129 | from mmcv.utils import import_modules_from_strings
130 | import_modules_from_strings(**cfg['custom_imports'])
131 | # set cudnn_benchmark
132 | if cfg.get('cudnn_benchmark', False):
133 | torch.backends.cudnn.benchmark = True
134 |
135 | cfg.model.pretrained = None
136 | if cfg.model.get('neck'):
137 | if isinstance(cfg.model.neck, list):
138 | for neck_cfg in cfg.model.neck:
139 | if neck_cfg.get('rfp_backbone'):
140 | if neck_cfg.rfp_backbone.get('pretrained'):
141 | neck_cfg.rfp_backbone.pretrained = None
142 | elif cfg.model.neck.get('rfp_backbone'):
143 | if cfg.model.neck.rfp_backbone.get('pretrained'):
144 | cfg.model.neck.rfp_backbone.pretrained = None
145 |
146 | # in case the test dataset is concatenated
147 | samples_per_gpu = 1
148 | if isinstance(cfg.data.test, dict):
149 | cfg.data.test.test_mode = True
150 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
151 | if samples_per_gpu > 1:
152 | # Replace 'ImageToTensor' to 'DefaultFormatBundle'
153 | cfg.data.test.pipeline = replace_ImageToTensor(
154 | cfg.data.test.pipeline)
155 | elif isinstance(cfg.data.test, list):
156 | for ds_cfg in cfg.data.test:
157 | ds_cfg.test_mode = True
158 | samples_per_gpu = max(
159 | [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
160 | if samples_per_gpu > 1:
161 | for ds_cfg in cfg.data.test:
162 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
163 |
164 | # init distributed env first, since logger depends on the dist info.
165 | if args.launcher == 'none':
166 | distributed = False
167 | else:
168 | distributed = True
169 | init_dist(args.launcher, **cfg.dist_params)
170 |
171 | rank, _ = get_dist_info()
172 |     # the work-dir is optional; create it only on the main process
173 | if args.work_dir is not None and rank == 0:
174 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
175 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
176 | json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
177 |
178 | # build the dataloader
179 | dataset = build_dataset(cfg.data.test)
180 | data_loader = build_dataloader(
181 | dataset,
182 | samples_per_gpu=samples_per_gpu,
183 | workers_per_gpu=cfg.data.workers_per_gpu,
184 | dist=distributed,
185 | shuffle=False)
186 |
187 | # build the model and load checkpoint
188 | cfg.model.train_cfg = None
189 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
190 | fp16_cfg = cfg.get('fp16', None)
191 | if fp16_cfg is not None:
192 | wrap_fp16_model(model)
193 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
194 | if args.fuse_conv_bn:
195 | model = fuse_conv_bn(model)
196 |     # old versions did not save class info in checkpoints, this workaround is
197 | # for backward compatibility
198 | if 'CLASSES' in checkpoint.get('meta', {}):
199 | model.CLASSES = checkpoint['meta']['CLASSES']
200 | else:
201 | model.CLASSES = dataset.CLASSES
202 |
203 | if not distributed:
204 | model = MMDataParallel(model, device_ids=[0])
205 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
206 | args.show_score_thr)
207 | else:
208 | model = MMDistributedDataParallel(
209 | model.cuda(),
210 | device_ids=[torch.cuda.current_device()],
211 | broadcast_buffers=False)
212 | outputs = multi_gpu_test(model, data_loader, args.tmpdir,
213 | args.gpu_collect)
214 |
215 | rank, _ = get_dist_info()
216 | if rank == 0:
217 | if args.out:
218 | print(f'\nwriting results to {args.out}')
219 | mmcv.dump(outputs, args.out)
220 | kwargs = {} if args.eval_options is None else args.eval_options
221 | if args.format_only:
222 | dataset.format_results(outputs, **kwargs)
223 | if args.eval:
224 | eval_kwargs = cfg.get('evaluation', {}).copy()
225 | # hard-code way to remove EvalHook args
226 | for key in [
227 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
228 | 'rule'
229 | ]:
230 | eval_kwargs.pop(key, None)
231 | eval_kwargs.update(dict(metric=args.eval, **kwargs))
232 | metric = dataset.evaluate(outputs, **eval_kwargs)
233 | print(metric)
234 | metric_dict = dict(config=args.config, metric=metric)
235 | if args.work_dir is not None and rank == 0:
236 | mmcv.dump(metric_dict, json_file)
237 |
238 |
239 | if __name__ == '__main__':
240 | main()
241 |
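242 | # Example (sketch): local single-GPU evaluation; the checkpoint path below is illustrative.
243 | #   python test.py configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py \
244 | #       work_dirs/mask_rcnn_lightvit_tiny_fpn_1x_coco/latest.pth --eval bbox segm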
--------------------------------------------------------------------------------
/detection/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 | import os.path as osp
5 | import time
6 | import warnings
7 |
8 | import mmcv
9 | import torch
10 | from mmcv import Config, DictAction
11 | from mmcv.runner import get_dist_info, init_dist
12 | from mmcv.utils import get_git_hash
13 |
14 | from mmdet import __version__
15 | from mmdet.apis import set_random_seed, train_detector
16 | from mmdet.datasets import build_dataset
17 | from mmdet.models import build_detector
18 | from mmdet.utils import collect_env, get_root_logger
19 |
20 | import lightvit
21 | import lightvit_fpn
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser(description='Train a detector')
26 | parser.add_argument('config', help='train config file path')
27 | parser.add_argument('--work-dir', help='the dir to save logs and models')
28 | parser.add_argument(
29 | '--resume-from', help='the checkpoint file to resume from')
30 | parser.add_argument(
31 | '--no-validate',
32 | action='store_true',
33 |         help='skip checkpoint evaluation during training')
34 | group_gpus = parser.add_mutually_exclusive_group()
35 | group_gpus.add_argument(
36 | '--gpus',
37 | type=int,
38 | help='number of gpus to use '
39 | '(only applicable to non-distributed training)')
40 | group_gpus.add_argument(
41 | '--gpu-ids',
42 | type=int,
43 | nargs='+',
44 | help='ids of gpus to use '
45 | '(only applicable to non-distributed training)')
46 | parser.add_argument('--seed', type=int, default=None, help='random seed')
47 | parser.add_argument(
48 | '--deterministic',
49 | action='store_true',
50 | help='whether to set deterministic options for CUDNN backend.')
51 | parser.add_argument(
52 | '--options',
53 | nargs='+',
54 | action=DictAction,
55 | help='override some settings in the used config, the key-value pair '
56 |         'in xxx=yyy format will be merged into config file (deprecated); '
57 | 'change to --cfg-options instead.')
58 | parser.add_argument(
59 | '--cfg-options',
60 | nargs='+',
61 | action=DictAction,
62 | help='override some settings in the used config, the key-value pair '
63 | 'in xxx=yyy format will be merged into config file. If the value to '
64 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
65 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
66 | 'Note that the quotation marks are necessary and that no white space '
67 | 'is allowed.')
68 | parser.add_argument(
69 | '--launcher',
70 | choices=['none', 'pytorch', 'slurm', 'mpi'],
71 | default='none',
72 | help='job launcher')
73 | parser.add_argument('--local_rank', type=int, default=0)
74 | args = parser.parse_args()
75 | if 'LOCAL_RANK' not in os.environ:
76 | os.environ['LOCAL_RANK'] = str(args.local_rank)
77 |
78 | if args.options and args.cfg_options:
79 | raise ValueError(
80 | '--options and --cfg-options cannot be both '
81 | 'specified, --options is deprecated in favor of --cfg-options')
82 | if args.options:
83 | warnings.warn('--options is deprecated in favor of --cfg-options')
84 | args.cfg_options = args.options
85 |
86 | return args
87 |
88 |
89 | def main():
90 | args = parse_args()
91 |
92 | cfg = Config.fromfile(args.config)
93 | if args.cfg_options is not None:
94 | cfg.merge_from_dict(args.cfg_options)
95 | # import modules from string list.
96 | if cfg.get('custom_imports', None):
97 | from mmcv.utils import import_modules_from_strings
98 | import_modules_from_strings(**cfg['custom_imports'])
99 | # set cudnn_benchmark
100 | if cfg.get('cudnn_benchmark', False):
101 | torch.backends.cudnn.benchmark = True
102 |
103 | # work_dir is determined in this priority: CLI > segment in file > filename
104 | if args.work_dir is not None:
105 | # update configs according to CLI args if args.work_dir is not None
106 | cfg.work_dir = args.work_dir
107 | elif cfg.get('work_dir', None) is None:
108 | # use config filename as default work_dir if cfg.work_dir is None
109 | cfg.work_dir = osp.join('./work_dirs',
110 | osp.splitext(osp.basename(args.config))[0])
111 | if args.resume_from is not None:
112 | cfg.resume_from = args.resume_from
113 | if args.gpu_ids is not None:
114 | cfg.gpu_ids = args.gpu_ids
115 | else:
116 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
117 |
118 | # init distributed env first, since logger depends on the dist info.
119 | if args.launcher == 'none':
120 | distributed = False
121 | else:
122 | distributed = True
123 | init_dist(args.launcher, **cfg.dist_params)
124 | # re-set gpu_ids with distributed training mode
125 | _, world_size = get_dist_info()
126 | cfg.gpu_ids = range(world_size)
127 |
128 | # create work_dir
129 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
130 | # dump config
131 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
132 | # init the logger before other steps
133 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
134 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
135 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
136 |
137 | # init the meta dict to record some important information such as
138 | # environment info and seed, which will be logged
139 | meta = dict()
140 | # log env info
141 | env_info_dict = collect_env()
142 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
143 | dash_line = '-' * 60 + '\n'
144 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
145 | dash_line)
146 | meta['env_info'] = env_info
147 | meta['config'] = cfg.pretty_text
148 | # log some basic info
149 | logger.info(f'Distributed training: {distributed}')
150 | logger.info(f'Config:\n{cfg.pretty_text}')
151 |
152 | # set random seeds
153 | if args.seed is not None:
154 | logger.info(f'Set random seed to {args.seed}, '
155 | f'deterministic: {args.deterministic}')
156 | set_random_seed(args.seed, deterministic=args.deterministic)
157 | cfg.seed = args.seed
158 | meta['seed'] = args.seed
159 | meta['exp_name'] = osp.basename(args.config)
160 |
161 | model = build_detector(
162 | cfg.model,
163 | train_cfg=cfg.get('train_cfg'),
164 | test_cfg=cfg.get('test_cfg'))
165 | model.init_weights()
166 |
167 | datasets = [build_dataset(cfg.data.train)]
168 | if len(cfg.workflow) == 2:
169 | val_dataset = copy.deepcopy(cfg.data.val)
170 | val_dataset.pipeline = cfg.data.train.pipeline
171 | datasets.append(build_dataset(val_dataset))
172 | if cfg.checkpoint_config is not None:
173 | # save mmdet version, config file content and class names in
174 | # checkpoints as meta data
175 | cfg.checkpoint_config.meta = dict(
176 | mmdet_version=__version__ + get_git_hash()[:7],
177 | CLASSES=datasets[0].CLASSES)
178 | # add an attribute for visualization convenience
179 | model.CLASSES = datasets[0].CLASSES
180 | train_detector(
181 | model,
182 | datasets,
183 | cfg,
184 | distributed=distributed,
185 | validate=(not args.no_validate),
186 | timestamp=timestamp,
187 | meta=meta)
188 |
189 |
190 | if __name__ == '__main__':
191 | main()
192 |
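193 | # Example (sketch): local single-GPU training; use dist_train.sh or slurm_train.sh for
194 | # multi-GPU runs. The work-dir below is illustrative.
195 | #   python train.py configs/mask_rcnn_lightvit_tiny_fpn_1x_coco.py \
196 | #       --work-dir work_dirs/mask_rcnn_lightvit_tiny_fpn_1x_coco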
--------------------------------------------------------------------------------