├── LICENSE
├── README.md
├── detection
│   ├── README.md
│   ├── backbone
│   │   └── vit_SelfPatch.py
│   ├── configs
│   │   ├── _base_
│   │   │   ├── datasets
│   │   │   │   ├── coco_detection.py
│   │   │   │   ├── coco_instance.py
│   │   │   │   └── coco_instance_semantic.py
│   │   │   ├── default_runtime.py
│   │   │   ├── models
│   │   │   │   └── mask_rcnn_selfpatch_p16.py
│   │   │   └── schedules
│   │   │       ├── schedule_1x.py
│   │   │       ├── schedule_20e.py
│   │   │       └── schedule_2x.py
│   │   └── selfpatch
│   │       └── mask_rcnn_vit_small_12_p16_1x_coco.py
│   └── tools
│       ├── dist_test.sh
│       ├── dist_train.sh
│       ├── slurm_test.sh
│       ├── slurm_train.sh
│       ├── test.py
│       └── train.py
├── eval_video_segmentation.py
├── main_selfpatch.py
├── segmentation
│   ├── README.md
│   ├── backbones
│   │   ├── __init__.py
│   │   ├── bisenetv1.py
│   │   ├── bisenetv2.py
│   │   ├── cgnet.py
│   │   ├── fast_scnn.py
│   │   ├── hrnet.py
│   │   ├── mit.py
│   │   ├── mobilenet_v2.py
│   │   ├── mobilenet_v3.py
│   │   ├── resnest.py
│   │   ├── resnet.py
│   │   ├── resnext.py
│   │   ├── swin.py
│   │   ├── unet.py
│   │   └── vit_SelfPatch.py
│   ├── configs
│   │   ├── _base_
│   │   │   ├── datasets
│   │   │   │   └── ade20k.py
│   │   │   ├── default_runtime.py
│   │   │   └── schedules
│   │   │       └── schedule_40k.py
│   │   └── semfpn_vit-s16_512x512_40k_ade20k.py
│   └── tools
│       ├── analyze_logs.py
│       ├── benchmark.py
│       ├── browse_dataset.py
│       ├── convert_datasets
│       │   ├── chase_db1.py
│       │   ├── cityscapes.py
│       │   ├── coco_stuff10k.py
│       │   ├── coco_stuff164k.py
│       │   ├── drive.py
│       │   ├── hrf.py
│       │   ├── pascal_context.py
│       │   ├── stare.py
│       │   └── voc_aug.py
│       ├── deploy_test.py
│       ├── dist_test.sh
│       ├── dist_train.sh
│       ├── get_flops.py
│       ├── model_converters
│       │   ├── mit2mmseg.py
│       │   ├── swin2mmseg.py
│       │   └── vit2mmseg.py
│       ├── onnx2tensorrt.py
│       ├── print_config.py
│       ├── publish_model.py
│       ├── pytorch2onnx.py
│       ├── pytorch2torchscript.py
│       ├── slurm_test.sh
│       ├── slurm_train.sh
│       ├── test.py
│       ├── torchserve
│       │   ├── mmseg2torchserve.py
│       │   ├── mmseg_handler.py
│       │   └── test_torchserve.py
│       └── train.py
├── selfpatch_vision_transformer.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Patch-level Representation Learning for Self-supervised Vision Transformers (SelfPatch)
2 |
3 | PyTorch implementation of "Patch-level Representation Learning for Self-supervised Vision Transformers" (accepted as an oral presentation at CVPR 2022).
4 |
5 |
6 |
7 |
8 |
9 | ## Requirements
10 | - `torch==1.7.0`
11 | - `torchvision==0.8.1`
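
A matching environment can be installed with pip, for example:

```
pip install torch==1.7.0 torchvision==0.8.1
```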
12 |
13 | ## Pretraining on ImageNet
14 | ```
15 | python -m torch.distributed.launch --nproc_per_node=8 main_selfpatch.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir --local_crops_number 8 --patch_size 16 --batch_size_per_gpu 128 --out_dim_selfpatch 4096 --k_num 4
16 | ```
17 |
18 | ## Pretrained weights on ImageNet
19 | You can download the weights of the models pretrained on ImageNet. All models use the `ViT-S/16` architecture. For the detection and segmentation downstream tasks, please check `SelfPatch/detection` and `SelfPatch/segmentation`.
20 |
21 | | backbone | arch | checkpoint |
22 | | ------------- | ------------- | ------------- |
23 | | DINO | ViT-S/16 | download (pretrained model from VISSL) |
24 | | DINO + SelfPatch | ViT-S/16 | download |
25 |
26 | ## Evaluating video object segmentation on the DAVIS 2017 dataset
27 | Step 1. Prepare DAVIS 2017 data
28 |
29 | ```
30 | cd $HOME
31 | git clone https://github.com/davisvideochallenge/davis-2017
32 | cd davis-2017
33 | ./data/get_davis.sh
34 | ```
35 |
36 | Step 2. Run Video object segmentation
37 |
38 | ```
39 | python eval_video_segmentation.py --data_path /path/to/davis-2017/DAVIS/ --output_dir /path/to/saving_dir --pretrained_weights /path/to/model_dir --arch vit_small --patch_size 16
40 | ```
41 |
42 | Step 3. Evaluate the obtained segmentation
43 |
44 | ```
45 | git clone https://github.com/davisvideochallenge/davis2017-evaluation
46 | cd $HOME/davis2017-evaluation
47 | python /path/to/davis2017-evaluation/evaluation_method.py --task semi-supervised --davis_path /path/to/davis-2017/DAVIS --results_path /path/to/saving_dir
48 | ```
49 |
50 | ### Video object segmentation examples on the DAVIS 2017 dataset
51 |
52 | Video (left), DINO (middle) and our SelfPatch (right)
53 |
54 |
55 |
56 |
57 |
58 | ## Acknowledgement
59 | Our codebase is built partly upon the following packages:
60 | DINO, mmdetection, mmsegmentation, and XCiT.
61 |
62 | ## Citation
63 | If you use this code for your research, please cite our paper.
64 | ```
65 | @InProceedings{Yun_2022_CVPR,
66 | author = {Yun, Sukmin and Lee, Hankook and Kim, Jaehyung and Shin, Jinwoo},
67 | title = {Patch-Level Representation Learning for Self-Supervised Vision Transformers},
68 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
69 | month = {June},
70 | year = {2022},
71 | pages = {8354-8363}
72 | }
73 | ```
74 |
--------------------------------------------------------------------------------
/detection/README.md:
--------------------------------------------------------------------------------
1 | ## Evaluating object detection and instance segmentation on the COCO dataset
2 | Step 1. Prepare COCO dataset
3 |
4 | The dataset can be downloaded at `https://cocodataset.org/#download`
5 |
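The `_base_` dataset configs expect the standard COCO layout shown below under `data_root` (set to `/data/COCO/` in `coco_instance.py` and `data/coco/` in `coco_detection.py`); adjust `data_root` to your local path:

```
/data/COCO/
├── annotations/
│   ├── instances_train2017.json
│   └── instances_val2017.json
├── train2017/
└── val2017/
```
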
6 | Step 2. Install mmdetection
7 |
8 | ```
9 | git clone https://github.com/open-mmlab/mmdetection.git
10 | ```
11 |
12 | Step 3. Fine-tune on the COCO dataset
13 |
14 | ```
15 | tools/dist_train.sh configs/selfpatch/mask_rcnn_vit_small_12_p16_1x_coco.py [number of gpu] --work-dir /path/to/saving_dir --seed 0 --deterministic --options model.pretrained=/path/to/model_dir
16 | ```
17 |
18 | ## Pretrained weights on MS-COCO
19 | You can download the weights of the models fine-tuned for the object detection and instance segmentation tasks. All models are fine-tuned with `Mask R-CNN`.
20 |
21 | | backbone | arch | bbox mAP | mask mAP | checkpoint |
22 | | ------------- | ------------- | ------------- | ------------- | ------------- |
23 | | DINO | ViT-S/16 + Mask R-CNN | 40.8 | 37.3 | download |
24 | | DINO + SelfPatch | ViT-S/16 + Mask R-CNN | 42.1 | 38.5 | download |
25 |
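After fine-tuning, a checkpoint can be evaluated with the bundled `tools/dist_test.sh`, which forwards extra arguments to `tools/test.py`; a minimal sketch with placeholder paths:

```
tools/dist_test.sh configs/selfpatch/mask_rcnn_vit_small_12_p16_1x_coco.py /path/to/checkpoint [number of gpu] --eval bbox segm
```
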
26 | ## Acknowledgement
27 | This code is built using the mmdetection library.
28 |
--------------------------------------------------------------------------------
/detection/configs/_base_/datasets/coco_detection.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | dataset_type = 'CocoDataset'
4 | data_root = 'data/coco/'
5 | img_norm_cfg = dict(
6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
7 | train_pipeline = [
8 | dict(type='LoadImageFromFile'),
9 | dict(type='LoadAnnotations', with_bbox=True),
10 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
11 | dict(type='RandomFlip', flip_ratio=0.5),
12 | dict(type='Normalize', **img_norm_cfg),
13 | dict(type='Pad', size_divisor=32),
14 | dict(type='DefaultFormatBundle'),
15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
16 | ]
17 | test_pipeline = [
18 | dict(type='LoadImageFromFile'),
19 | dict(
20 | type='MultiScaleFlipAug',
21 | img_scale=(1333, 800),
22 | flip=False,
23 | transforms=[
24 | dict(type='Resize', keep_ratio=True),
25 | dict(type='RandomFlip'),
26 | dict(type='Normalize', **img_norm_cfg),
27 | dict(type='Pad', size_divisor=32),
28 | dict(type='ImageToTensor', keys=['img']),
29 | dict(type='Collect', keys=['img']),
30 | ])
31 | ]
32 | data = dict(
33 | samples_per_gpu=2,
34 | workers_per_gpu=2,
35 | train=dict(
36 | type=dataset_type,
37 | ann_file=data_root + 'annotations/instances_train2017.json',
38 | img_prefix=data_root + 'train2017/',
39 | pipeline=train_pipeline),
40 | val=dict(
41 | type=dataset_type,
42 | ann_file=data_root + 'annotations/instances_val2017.json',
43 | img_prefix=data_root + 'val2017/',
44 | pipeline=test_pipeline),
45 | test=dict(
46 | type=dataset_type,
47 | ann_file=data_root + 'annotations/instances_val2017.json',
48 | img_prefix=data_root + 'val2017/',
49 | pipeline=test_pipeline))
50 | evaluation = dict(interval=1, metric='bbox')
51 |
--------------------------------------------------------------------------------
/detection/configs/_base_/datasets/coco_instance.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | dataset_type = 'CocoDataset'
4 | data_root = '/data/COCO/'
5 | img_norm_cfg = dict(
6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
7 | train_pipeline = [
8 | dict(type='LoadImageFromFile'),
9 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
10 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
11 | dict(type='RandomFlip', flip_ratio=0.5),
12 | dict(type='Normalize', **img_norm_cfg),
13 | dict(type='Pad', size_divisor=32),
14 | dict(type='DefaultFormatBundle'),
15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
16 | ]
17 | test_pipeline = [
18 | dict(type='LoadImageFromFile'),
19 | dict(
20 | type='MultiScaleFlipAug',
21 | img_scale=(1333, 800),
22 | flip=False,
23 | transforms=[
24 | dict(type='Resize', keep_ratio=True),
25 | dict(type='RandomFlip'),
26 | dict(type='Normalize', **img_norm_cfg),
27 | dict(type='Pad', size_divisor=32),
28 | dict(type='ImageToTensor', keys=['img']),
29 | dict(type='Collect', keys=['img']),
30 | ])
31 | ]
32 | data = dict(
33 | samples_per_gpu=2,
34 | workers_per_gpu=2,
35 | train=dict(
36 | type=dataset_type,
37 | ann_file=data_root + 'annotations/instances_train2017.json',
38 | img_prefix=data_root + 'train2017/',
39 | pipeline=train_pipeline),
40 | val=dict(
41 | type=dataset_type,
42 | ann_file=data_root + 'annotations/instances_val2017.json',
43 | img_prefix=data_root + 'val2017/',
44 | pipeline=test_pipeline),
45 | test=dict(
46 | type=dataset_type,
47 | ann_file=data_root + 'annotations/instances_val2017.json',
48 | img_prefix=data_root + 'val2017/',
49 | pipeline=test_pipeline))
50 | evaluation = dict(metric=['bbox', 'segm'])
51 |
--------------------------------------------------------------------------------
/detection/configs/_base_/datasets/coco_instance_semantic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | dataset_type = 'CocoDataset'
4 | data_root = 'data/coco/'
5 | img_norm_cfg = dict(
6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
7 | train_pipeline = [
8 | dict(type='LoadImageFromFile'),
9 | dict(
10 | type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True),
11 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
12 | dict(type='RandomFlip', flip_ratio=0.5),
13 | dict(type='Normalize', **img_norm_cfg),
14 | dict(type='Pad', size_divisor=32),
15 | dict(type='SegRescale', scale_factor=1 / 8),
16 | dict(type='DefaultFormatBundle'),
17 | dict(
18 | type='Collect',
19 | keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']),
20 | ]
21 | test_pipeline = [
22 | dict(type='LoadImageFromFile'),
23 | dict(
24 | type='MultiScaleFlipAug',
25 | img_scale=(1333, 800),
26 | flip=False,
27 | transforms=[
28 | dict(type='Resize', keep_ratio=True),
29 | dict(type='RandomFlip', flip_ratio=0.5),
30 | dict(type='Normalize', **img_norm_cfg),
31 | dict(type='Pad', size_divisor=32),
32 | dict(type='ImageToTensor', keys=['img']),
33 | dict(type='Collect', keys=['img']),
34 | ])
35 | ]
36 | data = dict(
37 | samples_per_gpu=2,
38 | workers_per_gpu=2,
39 | train=dict(
40 | type=dataset_type,
41 | ann_file=data_root + 'annotations/instances_train2017.json',
42 | img_prefix=data_root + 'train2017/',
43 | seg_prefix=data_root + 'stuffthingmaps/train2017/',
44 | pipeline=train_pipeline),
45 | val=dict(
46 | type=dataset_type,
47 | ann_file=data_root + 'annotations/instances_val2017.json',
48 | img_prefix=data_root + 'val2017/',
49 | pipeline=test_pipeline),
50 | test=dict(
51 | type=dataset_type,
52 | ann_file=data_root + 'annotations/instances_val2017.json',
53 | img_prefix=data_root + 'val2017/',
54 | pipeline=test_pipeline))
55 | evaluation = dict(metric=['bbox', 'segm'])
56 |
--------------------------------------------------------------------------------
/detection/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | checkpoint_config = dict(interval=1)
4 | # yapf:disable
5 | log_config = dict(
6 | interval=50,
7 | hooks=[
8 | dict(type='TextLoggerHook'),
9 | # dict(type='TensorboardLoggerHook')
10 | ])
11 | # yapf:enable
12 | custom_hooks = [dict(type='NumClassCheckHook')]
13 |
14 | dist_params = dict(backend='nccl')
15 | log_level = 'INFO'
16 | load_from = None
17 | resume_from = None
18 | workflow = [('train', 1)]
19 |
--------------------------------------------------------------------------------
/detection/configs/_base_/models/mask_rcnn_selfpatch_p16.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # model settings
4 | model = dict(
5 | type='MaskRCNN',
6 | pretrained=None,
7 | backbone=dict(
8 | type='SelfPatch_ViT',
9 | patch_size=16,
10 | embed_dim=384,
11 | depth=12,
12 | num_heads=8,
13 | mlp_ratio=4,
14 | qkv_bias=True,
15 | ),
16 | neck=dict(
17 | type='FPN',
18 | in_channels=[384, 384, 384, 384],
19 | out_channels=256,
20 | num_outs=5),
21 | rpn_head=dict(
22 | type='RPNHead',
23 | in_channels=256,
24 | feat_channels=256,
25 | anchor_generator=dict(
26 | type='AnchorGenerator',
27 | scales=[8],
28 | ratios=[0.5, 1.0, 2.0],
29 | strides=[4, 8, 16, 32, 64]),
30 | bbox_coder=dict(
31 | type='DeltaXYWHBBoxCoder',
32 | target_means=[.0, .0, .0, .0],
33 | target_stds=[1.0, 1.0, 1.0, 1.0]),
34 | loss_cls=dict(
35 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
36 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
37 | roi_head=dict(
38 | type='StandardRoIHead',
39 | bbox_roi_extractor=dict(
40 | type='SingleRoIExtractor',
41 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
42 | out_channels=256,
43 | featmap_strides=[16, 16, 16, 16]),
44 | bbox_head=dict(
45 | type='Shared2FCBBoxHead',
46 | in_channels=256,
47 | fc_out_channels=1024,
48 | roi_feat_size=7,
49 | num_classes=80,
50 | bbox_coder=dict(
51 | type='DeltaXYWHBBoxCoder',
52 | target_means=[0., 0., 0., 0.],
53 | target_stds=[0.1, 0.1, 0.2, 0.2]),
54 | reg_class_agnostic=False,
55 | loss_cls=dict(
56 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
57 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
58 | mask_roi_extractor=dict(
59 | type='SingleRoIExtractor',
60 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
61 | out_channels=256,
62 | featmap_strides=[16, 16, 16, 16]),
63 | mask_head=dict(
64 | type='FCNMaskHead',
65 | num_convs=4,
66 | in_channels=256,
67 | conv_out_channels=256,
68 | num_classes=80,
69 | loss_mask=dict(
70 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
71 | # model training and testing settings
72 | train_cfg=dict(
73 | rpn=dict(
74 | assigner=dict(
75 | type='MaxIoUAssigner',
76 | pos_iou_thr=0.7,
77 | neg_iou_thr=0.3,
78 | min_pos_iou=0.3,
79 | match_low_quality=True,
80 | ignore_iof_thr=-1),
81 | sampler=dict(
82 | type='RandomSampler',
83 | num=256,
84 | pos_fraction=0.5,
85 | neg_pos_ub=-1,
86 | add_gt_as_proposals=False),
87 | allowed_border=-1,
88 | pos_weight=-1,
89 | debug=False),
90 | rpn_proposal=dict(
91 | nms_pre=2000,
92 | max_per_img=1000,
93 | nms=dict(type='nms', iou_threshold=0.7),
94 | min_bbox_size=0),
95 | rcnn=dict(
96 | assigner=dict(
97 | type='MaxIoUAssigner',
98 | pos_iou_thr=0.5,
99 | neg_iou_thr=0.5,
100 | min_pos_iou=0.5,
101 | match_low_quality=True,
102 | ignore_iof_thr=-1),
103 | sampler=dict(
104 | type='RandomSampler',
105 | num=512,
106 | pos_fraction=0.25,
107 | neg_pos_ub=-1,
108 | add_gt_as_proposals=True),
109 | mask_size=28,
110 | pos_weight=-1,
111 | debug=False)),
112 | test_cfg=dict(
113 | rpn=dict(
114 | nms_pre=1000,
115 | max_per_img=1000,
116 | nms=dict(type='nms', iou_threshold=0.7),
117 | min_bbox_size=0),
118 | rcnn=dict(
119 | score_thr=0.05,
120 | nms=dict(type='nms', iou_threshold=0.5),
121 | max_per_img=100,
122 | mask_thr_binary=0.5)))
123 |
--------------------------------------------------------------------------------
/detection/configs/_base_/schedules/schedule_1x.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # optimizer
4 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
5 | optimizer_config = dict(grad_clip=None)
6 | # learning policy
7 | lr_config = dict(
8 | policy='step',
9 | warmup='linear',
10 | warmup_iters=500,
11 | warmup_ratio=0.001,
12 | step=[8, 11])
13 | runner = dict(type='EpochBasedRunner', max_epochs=12)
14 |
--------------------------------------------------------------------------------
/detection/configs/_base_/schedules/schedule_20e.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # optimizer
4 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
5 | optimizer_config = dict(grad_clip=None)
6 | # learning policy
7 | lr_config = dict(
8 | policy='step',
9 | warmup='linear',
10 | warmup_iters=500,
11 | warmup_ratio=0.001,
12 | step=[16, 19])
13 | runner = dict(type='EpochBasedRunner', max_epochs=20)
14 |
--------------------------------------------------------------------------------
/detection/configs/_base_/schedules/schedule_2x.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | # optimizer
4 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
5 | optimizer_config = dict(grad_clip=None)
6 | # learning policy
7 | lr_config = dict(
8 | policy='step',
9 | warmup='linear',
10 | warmup_iters=500,
11 | warmup_ratio=0.001,
12 | step=[16, 22])
13 | runner = dict(type='EpochBasedRunner', max_epochs=24)
14 |
--------------------------------------------------------------------------------
/detection/configs/selfpatch/mask_rcnn_vit_small_12_p16_1x_coco.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Hyperparameters modified from
5 | https://github.com/SwinTransformer/Swin-Transformer-Object-Detection
6 | """
7 |
8 | _base_ = [
9 | '../_base_/models/mask_rcnn_selfpatch_p16.py',
10 | '../_base_/datasets/coco_instance.py',
11 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
12 | ]
13 |
14 | model = dict(
15 | backbone=dict(
16 | type='SelfPatch_ViT',
17 | patch_size=16,
18 | embed_dim=384,
19 | depth=12,
20 | num_heads=8,
21 | mlp_ratio=4,
22 | qkv_bias=True,
23 | eta=1.0,
24 | drop_path_rate=0.05,
25 | out_indices=[3, 5, 7, 11]
26 | ),
27 | neck=dict(in_channels=[384, 384, 384, 384]),
28 | roi_head=dict(
29 | type='StandardRoIHead',
30 | bbox_roi_extractor=dict(
31 | type='SingleRoIExtractor',
32 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
33 | out_channels=256,
34 | featmap_strides=[4, 8, 16, 32]),
35 | mask_roi_extractor=dict(
36 | type='SingleRoIExtractor',
37 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
38 | out_channels=256,
39 | featmap_strides=[4, 8, 16, 32])
40 | ),
41 | )
42 |
43 | img_norm_cfg = dict(
44 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
45 |
46 | # augmentation strategy originates from DETR / Sparse RCNN
47 | train_pipeline = [
48 | dict(type='LoadImageFromFile'),
49 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
50 | dict(type='RandomFlip', flip_ratio=0.5),
51 | dict(type='AutoAugment',
52 | policies=[
53 | [
54 | dict(type='Resize',
55 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
56 | (608, 1333), (640, 1333), (672, 1333), (704, 1333),
57 | (736, 1333), (768, 1333), (800, 1333)],
58 | multiscale_mode='value',
59 | keep_ratio=True)
60 | ],
61 | [
62 | dict(type='Resize',
63 | img_scale=[(400, 1333), (500, 1333), (600, 1333)],
64 | multiscale_mode='value',
65 | keep_ratio=True),
66 | dict(type='RandomCrop',
67 | crop_type='absolute_range',
68 | crop_size=(384, 600),
69 | allow_negative_crop=True),
70 | dict(type='Resize',
71 | img_scale=[(480, 1333), (512, 1333), (544, 1333),
72 | (576, 1333), (608, 1333), (640, 1333),
73 | (672, 1333), (704, 1333), (736, 1333),
74 | (768, 1333), (800, 1333)],
75 | multiscale_mode='value',
76 | override=True,
77 | keep_ratio=True)
78 | ]
79 | ]),
80 | dict(type='Normalize', **img_norm_cfg),
81 | dict(type='Pad', size_divisor=32),
82 | dict(type='DefaultFormatBundle'),
83 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
84 | ]
85 | data = dict(samples_per_gpu=1,train=dict(pipeline=train_pipeline))
86 |
87 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0001/2, betas=(0.9, 0.999), weight_decay=0.05,
88 | paramwise_cfg=dict(custom_keys={'pos_embed': dict(decay_mult=0.),
89 | 'cls_token': dict(decay_mult=0.),
90 | 'norm': dict(decay_mult=0.)}))
91 | lr_config = dict(warmup_iters=500*2,step=[8, 11])
92 | runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
93 |
94 | # do not use mmdet version fp16
95 | fp16 = None
96 | optimizer_config = dict(
97 | type="DistOptimizerHook",
98 | update_interval=1,
99 | grad_clip=None,
100 | coalesce=True,
101 | bucket_size_mb=-1,
102 | use_fp16=True,
103 | )
104 |
--------------------------------------------------------------------------------
/detection/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 |
8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
11 |
--------------------------------------------------------------------------------
/detection/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 |
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
10 |
--------------------------------------------------------------------------------
/detection/tools/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CHECKPOINT=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/detection/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | WORK_DIR=$4
9 | GPUS=${GPUS:-8}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | SRUN_ARGS=${SRUN_ARGS:-""}
13 | PY_ARGS=${@:5}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --mem 350GB \
23 | --kill-on-bad-exit=0 \
24 | ${SRUN_ARGS} \
25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS}
--------------------------------------------------------------------------------
/detection/tools/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Testing script modified from
5 | https://github.com/open-mmlab/mmdetection
6 | """
7 | import argparse
8 | import os
9 | import warnings
10 |
11 | import mmcv
12 | import torch
13 | from mmcv import Config, DictAction
14 | from mmcv.cnn import fuse_conv_bn
15 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
16 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
17 | wrap_fp16_model)
18 |
19 | from mmdet.apis import multi_gpu_test, single_gpu_test
20 | from mmdet.datasets import (build_dataloader, build_dataset,
21 | replace_ImageToTensor)
22 | from mmdet.models import build_detector, build_backbone
23 |
24 | from backbone import vit_SelfPatch  # registers the SelfPatch_ViT backbone with mmdet
25 |
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser(
29 | description='MMDet test (and eval) a model')
30 | parser.add_argument('config', help='test config file path')
31 | parser.add_argument('checkpoint', help='checkpoint file')
32 | parser.add_argument('--out', help='output result file in pickle format')
33 | parser.add_argument(
34 | '--fuse-conv-bn',
35 | action='store_true',
36 | help='Whether to fuse conv and bn, this will slightly increase '
37 | 'the inference speed')
38 | parser.add_argument(
39 | '--format-only',
40 | action='store_true',
41 | help='Format the output results without performing evaluation. It is '
42 | 'useful when you want to format the result to a specific format and '
43 | 'submit it to the test server')
44 | parser.add_argument(
45 | '--eval',
46 | type=str,
47 | nargs='+',
48 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
49 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
50 | parser.add_argument('--show', action='store_true', help='show results')
51 | parser.add_argument(
52 | '--show-dir', help='directory where painted images will be saved')
53 | parser.add_argument(
54 | '--show-score-thr',
55 | type=float,
56 | default=0.3,
57 | help='score threshold (default: 0.3)')
58 | parser.add_argument(
59 | '--gpu-collect',
60 | action='store_true',
61 | help='whether to use gpu to collect results.')
62 | parser.add_argument(
63 | '--tmpdir',
64 | help='tmp directory used for collecting results from multiple '
65 | 'workers, available when gpu-collect is not specified')
66 | parser.add_argument(
67 | '--cfg-options',
68 | nargs='+',
69 | action=DictAction,
70 | help='override some settings in the used config, the key-value pair '
71 | 'in xxx=yyy format will be merged into config file. If the value to '
72 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
73 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
74 | 'Note that the quotation marks are necessary and that no white space '
75 | 'is allowed.')
76 | parser.add_argument(
77 | '--options',
78 | nargs='+',
79 | action=DictAction,
80 | help='custom options for evaluation, the key-value pair in xxx=yyy '
81 | 'format will be kwargs for dataset.evaluate() function (deprecated), '
82 | 'change to --eval-options instead.')
83 | parser.add_argument(
84 | '--eval-options',
85 | nargs='+',
86 | action=DictAction,
87 | help='custom options for evaluation, the key-value pair in xxx=yyy '
88 | 'format will be kwargs for dataset.evaluate() function')
89 | parser.add_argument(
90 | '--launcher',
91 | choices=['none', 'pytorch', 'slurm', 'mpi'],
92 | default='none',
93 | help='job launcher')
94 | parser.add_argument('--local_rank', type=int, default=0)
95 | args = parser.parse_args()
96 | if 'LOCAL_RANK' not in os.environ:
97 | os.environ['LOCAL_RANK'] = str(args.local_rank)
98 |
99 | if args.options and args.eval_options:
100 | raise ValueError(
101 | '--options and --eval-options cannot be both '
102 | 'specified, --options is deprecated in favor of --eval-options')
103 | if args.options:
104 | warnings.warn('--options is deprecated in favor of --eval-options')
105 | args.eval_options = args.options
106 | return args
107 |
108 |
109 | def main():
110 | args = parse_args()
111 |
112 | assert args.out or args.eval or args.format_only or args.show \
113 | or args.show_dir, \
114 | ('Please specify at least one operation (save/eval/format/show the '
115 | 'results / save the results) with the argument "--out", "--eval"'
116 | ', "--format-only", "--show" or "--show-dir"')
117 |
118 | if args.eval and args.format_only:
119 | raise ValueError('--eval and --format_only cannot be both specified')
120 |
121 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
122 | raise ValueError('The output file must be a pkl file.')
123 |
124 | cfg = Config.fromfile(args.config)
125 | if args.cfg_options is not None:
126 | cfg.merge_from_dict(args.cfg_options)
127 | # import modules from string list.
128 | if cfg.get('custom_imports', None):
129 | from mmcv.utils import import_modules_from_strings
130 | import_modules_from_strings(**cfg['custom_imports'])
131 | # set cudnn_benchmark
132 | if cfg.get('cudnn_benchmark', False):
133 | torch.backends.cudnn.benchmark = True
134 | cfg.model.pretrained = None
135 | if cfg.model.get('neck'):
136 | if isinstance(cfg.model.neck, list):
137 | for neck_cfg in cfg.model.neck:
138 | if neck_cfg.get('rfp_backbone'):
139 | if neck_cfg.rfp_backbone.get('pretrained'):
140 | neck_cfg.rfp_backbone.pretrained = None
141 | elif cfg.model.neck.get('rfp_backbone'):
142 | if cfg.model.neck.rfp_backbone.get('pretrained'):
143 | cfg.model.neck.rfp_backbone.pretrained = None
144 |
145 | # in case the test dataset is concatenated
146 | samples_per_gpu = 1
147 | if isinstance(cfg.data.test, dict):
148 | cfg.data.test.test_mode = True
149 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
150 | if samples_per_gpu > 1:
151 | # Replace 'ImageToTensor' to 'DefaultFormatBundle'
152 | cfg.data.test.pipeline = replace_ImageToTensor(
153 | cfg.data.test.pipeline)
154 | elif isinstance(cfg.data.test, list):
155 | for ds_cfg in cfg.data.test:
156 | ds_cfg.test_mode = True
157 | samples_per_gpu = max(
158 | [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
159 | if samples_per_gpu > 1:
160 | for ds_cfg in cfg.data.test:
161 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
162 |
163 | # init distributed env first, since logger depends on the dist info.
164 | if args.launcher == 'none':
165 | distributed = False
166 | else:
167 | distributed = True
168 | init_dist(args.launcher, **cfg.dist_params)
169 |
170 | # build the dataloader
171 | dataset = build_dataset(cfg.data.test)
172 | data_loader = build_dataloader(
173 | dataset,
174 | samples_per_gpu=samples_per_gpu,
175 | workers_per_gpu=cfg.data.workers_per_gpu,
176 | dist=distributed,
177 | shuffle=False)
178 |
179 | # build the model and load checkpoint
180 | cfg.model.train_cfg = None
181 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
182 | fp16_cfg = cfg.get('fp16', None)
183 | if fp16_cfg is not None:
184 | wrap_fp16_model(model)
185 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
186 | if args.fuse_conv_bn:
187 | model = fuse_conv_bn(model)
188 | # old versions did not save class info in checkpoints, this workaround is
189 | # for backward compatibility
190 | if 'CLASSES' in checkpoint.get('meta', {}):
191 | model.CLASSES = checkpoint['meta']['CLASSES']
192 | else:
193 | model.CLASSES = dataset.CLASSES
194 |
195 | if not distributed:
196 | model = MMDataParallel(model, device_ids=[0])
197 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
198 | args.show_score_thr)
199 | else:
200 | model = MMDistributedDataParallel(
201 | model.cuda(),
202 | device_ids=[torch.cuda.current_device()],
203 | broadcast_buffers=False)
204 | outputs = multi_gpu_test(model, data_loader, args.tmpdir,
205 | args.gpu_collect)
206 |
207 | rank, _ = get_dist_info()
208 | if rank == 0:
209 | if args.out:
210 | print(f'\nwriting results to {args.out}')
211 | mmcv.dump(outputs, args.out)
212 | kwargs = {} if args.eval_options is None else args.eval_options
213 | if args.format_only:
214 | dataset.format_results(outputs, **kwargs)
215 | if args.eval:
216 | eval_kwargs = cfg.get('evaluation', {}).copy()
217 | # hard-code way to remove EvalHook args
218 | for key in [
219 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
220 | 'rule'
221 | ]:
222 | eval_kwargs.pop(key, None)
223 | eval_kwargs.update(dict(metric=args.eval, **kwargs))
224 | print(dataset.evaluate(outputs, **eval_kwargs))
225 |
226 |
227 | if __name__ == '__main__':
228 | main()
229 |
--------------------------------------------------------------------------------
/detection/tools/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2015-present, Facebook, Inc.
2 | # All rights reserved.
3 | """
4 | Training script modified from
5 | https://github.com/open-mmlab/mmdetection
6 | """
7 | import argparse
8 | import copy
9 | import os
10 | import os.path as osp
11 | import time
12 | import warnings
13 |
14 | import mmcv
15 | import torch
16 | from mmcv import Config, DictAction
17 | from mmcv.runner import get_dist_info, init_dist
18 | from mmcv.utils import get_git_hash
19 |
20 | from mmdet import __version__
21 | from mmdet.apis import set_random_seed, train_detector
22 | from mmdet.datasets import build_dataset
23 | from mmdet.models import build_detector
24 | from mmdet.utils import collect_env, get_root_logger
25 |
26 | from backbone import vit_SelfPatch  # registers the SelfPatch_ViT backbone with mmdet
27 |
28 |
29 | def parse_args():
30 | parser = argparse.ArgumentParser(description='Train a detector')
31 | parser.add_argument('config', help='train config file path')
32 | parser.add_argument('--work-dir', help='the dir to save logs and models')
33 | parser.add_argument(
34 | '--resume-from', help='the checkpoint file to resume from')
35 | parser.add_argument(
36 | '--no-validate',
37 | action='store_true',
38 | help='whether not to evaluate the checkpoint during training')
39 | group_gpus = parser.add_mutually_exclusive_group()
40 | group_gpus.add_argument(
41 | '--gpus',
42 | type=int,
43 | help='number of gpus to use '
44 | '(only applicable to non-distributed training)')
45 | group_gpus.add_argument(
46 | '--gpu-ids',
47 | type=int,
48 | nargs='+',
49 | help='ids of gpus to use '
50 | '(only applicable to non-distributed training)')
51 | parser.add_argument('--seed', type=int, default=None, help='random seed')
52 | parser.add_argument(
53 | '--deterministic',
54 | action='store_true',
55 | help='whether to set deterministic options for CUDNN backend.')
56 | parser.add_argument(
57 | '--options',
58 | nargs='+',
59 | action=DictAction,
60 | help='override some settings in the used config, the key-value pair '
61 | 'in xxx=yyy format will be merged into config file (deprecated), '
62 | 'change to --cfg-options instead.')
63 | parser.add_argument(
64 | '--cfg-options',
65 | nargs='+',
66 | action=DictAction,
67 | help='override some settings in the used config, the key-value pair '
68 | 'in xxx=yyy format will be merged into config file. If the value to '
69 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
70 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
71 | 'Note that the quotation marks are necessary and that no white space '
72 | 'is allowed.')
73 | parser.add_argument(
74 | '--launcher',
75 | choices=['none', 'pytorch', 'slurm', 'mpi'],
76 | default='none',
77 | help='job launcher')
78 | parser.add_argument('--local_rank', type=int, default=0)
79 | args = parser.parse_args()
80 | if 'LOCAL_RANK' not in os.environ:
81 | os.environ['LOCAL_RANK'] = str(args.local_rank)
82 |
83 | if args.options and args.cfg_options:
84 | raise ValueError(
85 | '--options and --cfg-options cannot be both '
86 | 'specified, --options is deprecated in favor of --cfg-options')
87 | if args.options:
88 | warnings.warn('--options is deprecated in favor of --cfg-options')
89 | args.cfg_options = args.options
90 |
91 | return args
92 |
93 |
94 | def main():
95 | args = parse_args()
96 |
97 | cfg = Config.fromfile(args.config)
98 | if args.cfg_options is not None:
99 | cfg.merge_from_dict(args.cfg_options)
100 | # import modules from string list.
101 | if cfg.get('custom_imports', None):
102 | from mmcv.utils import import_modules_from_strings
103 | import_modules_from_strings(**cfg['custom_imports'])
104 | # set cudnn_benchmark
105 | if cfg.get('cudnn_benchmark', False):
106 | torch.backends.cudnn.benchmark = True
107 |
108 | # work_dir is determined in this priority: CLI > segment in file > filename
109 | if args.work_dir is not None:
110 | # update configs according to CLI args if args.work_dir is not None
111 | cfg.work_dir = args.work_dir
112 | elif cfg.get('work_dir', None) is None:
113 | # use config filename as default work_dir if cfg.work_dir is None
114 | cfg.work_dir = osp.join('./work_dirs',
115 | osp.splitext(osp.basename(args.config))[0])
116 | if args.resume_from is not None:
117 | cfg.resume_from = args.resume_from
118 | if args.gpu_ids is not None:
119 | cfg.gpu_ids = args.gpu_ids
120 | else:
121 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
122 |
123 | # init distributed env first, since logger depends on the dist info.
124 | if args.launcher == 'none':
125 | distributed = False
126 | else:
127 | distributed = True
128 | init_dist(args.launcher, **cfg.dist_params)
129 | # re-set gpu_ids with distributed training mode
130 | _, world_size = get_dist_info()
131 | cfg.gpu_ids = range(world_size)
132 |
133 | # create work_dir
134 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
135 | # dump config
136 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
137 | # init the logger before other steps
138 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
139 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
140 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
141 |
142 | # init the meta dict to record some important information such as
143 | # environment info and seed, which will be logged
144 | meta = dict()
145 | # log env info
146 | env_info_dict = collect_env()
147 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
148 | dash_line = '-' * 60 + '\n'
149 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
150 | dash_line)
151 | meta['env_info'] = env_info
152 | meta['config'] = cfg.pretty_text
153 | # log some basic info
154 | logger.info(f'Distributed training: {distributed}')
155 | logger.info(f'Config:\n{cfg.pretty_text}')
156 |
157 | # set random seeds
158 | if args.seed is not None:
159 | logger.info(f'Set random seed to {args.seed}, '
160 | f'deterministic: {args.deterministic}')
161 | set_random_seed(args.seed, deterministic=args.deterministic)
162 | cfg.seed = args.seed
163 | meta['seed'] = args.seed
164 | meta['exp_name'] = osp.basename(args.config)
165 |
166 | model = build_detector(
167 | cfg.model,
168 | train_cfg=cfg.get('train_cfg'),
169 | test_cfg=cfg.get('test_cfg'))
170 |
171 | datasets = [build_dataset(cfg.data.train)]
172 | if len(cfg.workflow) == 2:
173 | val_dataset = copy.deepcopy(cfg.data.val)
174 | val_dataset.pipeline = cfg.data.train.pipeline
175 | datasets.append(build_dataset(val_dataset))
176 | if cfg.checkpoint_config is not None:
177 | # save mmdet version, config file content and class names in
178 | # checkpoints as meta data
179 | cfg.checkpoint_config.meta = dict(
180 | mmdet_version=__version__ + get_git_hash()[:7],
181 | CLASSES=datasets[0].CLASSES)
182 | # add an attribute for visualization convenience
183 | model.CLASSES = datasets[0].CLASSES
184 | train_detector(
185 | model,
186 | datasets,
187 | cfg,
188 | distributed=distributed,
189 | validate=(not args.no_validate),
190 | timestamp=timestamp,
191 | meta=meta)
192 |
193 |
194 | if __name__ == '__main__':
195 | main()
196 |
--------------------------------------------------------------------------------
/segmentation/README.md:
--------------------------------------------------------------------------------
1 | ## Evaluating semantic segmentation on the ADE20K dataset
2 |
3 | Step 1. Prepare ADE20K dataset
4 |
5 | The dataset can be downloaded at
6 | `http://groups.csail.mit.edu/vision/datasets/ADE20K/toolkit/index_ade20k.pkl`
7 |
8 | or by following the instructions at `https://github.com/CSAILVision/ADE20K`
9 |
10 | Step 2. Install mmsegmentation
11 |
12 | ```
13 | git clone https://github.com/open-mmlab/mmsegmentation.git
14 | ```
15 |
16 | Step 3. Convert your model
17 |
18 | ```
19 | python tools/model_converters/vit2mmseg.py /path/to/model_dir /path/to/saving_dir
20 | ```
21 |
22 | Step 4. Fine-tune on the ADE20K dataset
23 |
24 | ```
25 | tools/dist_train.sh configs/selfpatch/semfpn_vit-s16_512x512_40k_ade20k.py [number of gpu] --work-dir /path/to/saving_dir --seed 0 --deterministic --options model.pretrained=/path/to/model_dir
26 | ```
27 |
28 | ## Pretrained weights on ADE20K
29 | You can download the weights of the models fine-tuned for the semantic segmentation task. We provide fine-tuned models with `Semantic FPN` (40k iterations) and `UperNet` (160k iterations).
30 |
31 | | backbone | arch | iterations | mIoU | checkpoint |
32 | | ------------- | ------------- | ------------- | ------------- | ------------- |
33 | | DINO | ViT-S/16 + Semantic FPN | 40k | 38.3 | download |
34 | | DINO + SelfPatch | ViT-S/16 + Semantic FPN | 40k | 41.2 | download |
35 | | DINO | ViT-S/16 + UperNet | 160k | 42.3 | download |
36 | | DINO + SelfPatch | ViT-S/16 + UperNet | 160k | 43.2 | download |
37 |
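A fine-tuned checkpoint can likewise be evaluated with the bundled `tools/dist_test.sh` (assuming the standard mmsegmentation `--eval mIoU` option); paths below are placeholders:

```
tools/dist_test.sh configs/selfpatch/semfpn_vit-s16_512x512_40k_ade20k.py /path/to/checkpoint [number of gpu] --eval mIoU
```
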
38 | ## Acknowledgement
39 | This code is built using the mmsegmentation library. The optimization hyperparameters are adopted from XCiT.
40 |
--------------------------------------------------------------------------------
/segmentation/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .bisenetv1 import BiSeNetV1
3 | from .bisenetv2 import BiSeNetV2
4 | from .cgnet import CGNet
5 | from .fast_scnn import FastSCNN
6 | from .hrnet import HRNet
7 | from .mit import MixVisionTransformer
8 | from .mobilenet_v2 import MobileNetV2
9 | from .mobilenet_v3 import MobileNetV3
10 | from .resnest import ResNeSt
11 | from .resnet import ResNet, ResNetV1c, ResNetV1d
12 | from .resnext import ResNeXt
13 | from .swin import SwinTransformer
14 | from .unet import UNet
15 | from .vit_SelfPatch import SelfPatch_ViT
16 |
17 | __all__ = [
18 | 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
19 | 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
20 | 'SelfPatch_ViT', 'SwinTransformer', 'MixVisionTransformer',
21 | 'BiSeNetV1', 'BiSeNetV2'
22 | ]
23 |
--------------------------------------------------------------------------------
/segmentation/backbones/mobilenet_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import torch.nn as nn
5 | from mmcv.cnn import ConvModule
6 | from mmcv.runner import BaseModule
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 | from ..builder import BACKBONES
10 | from ..utils import InvertedResidual, make_divisible
11 |
12 |
13 | @BACKBONES.register_module()
14 | class MobileNetV2(BaseModule):
15 | """MobileNetV2 backbone.
16 |
17 | This backbone is the implementation of
18 | `MobileNetV2: Inverted Residuals and Linear Bottlenecks
19 | <https://arxiv.org/abs/1801.04381>`_.
20 |
21 | Args:
22 | widen_factor (float): Width multiplier, multiply number of
23 | channels in each layer by this amount. Default: 1.0.
24 | strides (Sequence[int], optional): Strides of the first block of each
25 | layer. If not specified, default config in ``arch_setting`` will
26 | be used.
27 | dilations (Sequence[int]): Dilation of each layer.
28 | out_indices (None or Sequence[int]): Output from which stages.
29 | Default: (7, ).
30 | frozen_stages (int): Stages to be frozen (all param fixed).
31 | Default: -1, which means not freezing any parameters.
32 | conv_cfg (dict): Config dict for convolution layer.
33 | Default: None, which means using conv2d.
34 | norm_cfg (dict): Config dict for normalization layer.
35 | Default: dict(type='BN').
36 | act_cfg (dict): Config dict for activation layer.
37 | Default: dict(type='ReLU6').
38 | norm_eval (bool): Whether to set norm layers to eval mode, namely,
39 | freeze running stats (mean and var). Note: Effect on Batch Norm
40 | and its variants only. Default: False.
41 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some
42 | memory while slowing down the training speed. Default: False.
43 | pretrained (str, optional): model pretrained path. Default: None
44 | init_cfg (dict or list[dict], optional): Initialization config dict.
45 | Default: None
46 | """
47 |
48 | # Parameters to build layers. 3 parameters are needed to construct a
49 | # layer, from left to right: expand_ratio, channel, num_blocks.
50 | arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4],
51 | [6, 96, 3], [6, 160, 3], [6, 320, 1]]
52 |
53 | def __init__(self,
54 | widen_factor=1.,
55 | strides=(1, 2, 2, 2, 1, 2, 1),
56 | dilations=(1, 1, 1, 1, 1, 1, 1),
57 | out_indices=(1, 2, 4, 6),
58 | frozen_stages=-1,
59 | conv_cfg=None,
60 | norm_cfg=dict(type='BN'),
61 | act_cfg=dict(type='ReLU6'),
62 | norm_eval=False,
63 | with_cp=False,
64 | pretrained=None,
65 | init_cfg=None):
66 | super(MobileNetV2, self).__init__(init_cfg)
67 |
68 | self.pretrained = pretrained
69 | assert not (init_cfg and pretrained), \
70 |             'init_cfg and pretrained cannot be set at the same time'
71 | if isinstance(pretrained, str):
72 |             warnings.warn('DeprecationWarning: pretrained is deprecated, '
73 | 'please use "init_cfg" instead')
74 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
75 | elif pretrained is None:
76 | if init_cfg is None:
77 | self.init_cfg = [
78 | dict(type='Kaiming', layer='Conv2d'),
79 | dict(
80 | type='Constant',
81 | val=1,
82 | layer=['_BatchNorm', 'GroupNorm'])
83 | ]
84 | else:
85 | raise TypeError('pretrained must be a str or None')
86 |
87 | self.widen_factor = widen_factor
88 | self.strides = strides
89 | self.dilations = dilations
90 | assert len(strides) == len(dilations) == len(self.arch_settings)
91 | self.out_indices = out_indices
92 | for index in out_indices:
93 | if index not in range(0, 7):
94 |                 raise ValueError('the item in out_indices must be in '
95 | f'range(0, 7). But received {index}')
96 |
97 | if frozen_stages not in range(-1, 7):
98 | raise ValueError('frozen_stages must be in range(-1, 7). '
99 | f'But received {frozen_stages}')
100 | self.out_indices = out_indices
101 | self.frozen_stages = frozen_stages
102 | self.conv_cfg = conv_cfg
103 | self.norm_cfg = norm_cfg
104 | self.act_cfg = act_cfg
105 | self.norm_eval = norm_eval
106 | self.with_cp = with_cp
107 |
108 | self.in_channels = make_divisible(32 * widen_factor, 8)
109 |
110 | self.conv1 = ConvModule(
111 | in_channels=3,
112 | out_channels=self.in_channels,
113 | kernel_size=3,
114 | stride=2,
115 | padding=1,
116 | conv_cfg=self.conv_cfg,
117 | norm_cfg=self.norm_cfg,
118 | act_cfg=self.act_cfg)
119 |
120 | self.layers = []
121 |
122 | for i, layer_cfg in enumerate(self.arch_settings):
123 | expand_ratio, channel, num_blocks = layer_cfg
124 | stride = self.strides[i]
125 | dilation = self.dilations[i]
126 | out_channels = make_divisible(channel * widen_factor, 8)
127 | inverted_res_layer = self.make_layer(
128 | out_channels=out_channels,
129 | num_blocks=num_blocks,
130 | stride=stride,
131 | dilation=dilation,
132 | expand_ratio=expand_ratio)
133 | layer_name = f'layer{i + 1}'
134 | self.add_module(layer_name, inverted_res_layer)
135 | self.layers.append(layer_name)
136 |
137 | def make_layer(self, out_channels, num_blocks, stride, dilation,
138 | expand_ratio):
139 | """Stack InvertedResidual blocks to build a layer for MobileNetV2.
140 |
141 | Args:
142 | out_channels (int): out_channels of block.
143 | num_blocks (int): Number of blocks.
144 | stride (int): Stride of the first block.
145 | dilation (int): Dilation of the first block.
146 | expand_ratio (int): Expand the number of channels of the
147 | hidden layer in InvertedResidual by this ratio.
148 | """
149 | layers = []
150 | for i in range(num_blocks):
151 | layers.append(
152 | InvertedResidual(
153 | self.in_channels,
154 | out_channels,
155 | stride if i == 0 else 1,
156 | expand_ratio=expand_ratio,
157 | dilation=dilation if i == 0 else 1,
158 | conv_cfg=self.conv_cfg,
159 | norm_cfg=self.norm_cfg,
160 | act_cfg=self.act_cfg,
161 | with_cp=self.with_cp))
162 | self.in_channels = out_channels
163 |
164 | return nn.Sequential(*layers)
165 |
166 | def forward(self, x):
167 | x = self.conv1(x)
168 |
169 | outs = []
170 | for i, layer_name in enumerate(self.layers):
171 | layer = getattr(self, layer_name)
172 | x = layer(x)
173 | if i in self.out_indices:
174 | outs.append(x)
175 |
176 | if len(outs) == 1:
177 | return outs[0]
178 | else:
179 | return tuple(outs)
180 |
181 | def _freeze_stages(self):
182 | if self.frozen_stages >= 0:
183 | for param in self.conv1.parameters():
184 | param.requires_grad = False
185 | for i in range(1, self.frozen_stages + 1):
186 | layer = getattr(self, f'layer{i}')
187 | layer.eval()
188 | for param in layer.parameters():
189 | param.requires_grad = False
190 |
191 | def train(self, mode=True):
192 | super(MobileNetV2, self).train(mode)
193 | self._freeze_stages()
194 | if mode and self.norm_eval:
195 | for m in self.modules():
196 | if isinstance(m, _BatchNorm):
197 | m.eval()
198 |
--------------------------------------------------------------------------------
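A minimal sketch of what the default `out_indices=(1, 2, 4, 6)` produces: four feature maps at strides 4, 8, 16 and 32. This assumes the class is importable as `mmseg.models.backbones.MobileNetV2`, which the file above mirrors.

```python
import torch
from mmseg.models.backbones import MobileNetV2  # assumed import path for the class above

net = MobileNetV2().eval()
with torch.no_grad():
    feats = net(torch.rand(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 24, 56, 56), (1, 32, 28, 28), (1, 96, 14, 14), (1, 320, 7, 7)]
```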
/segmentation/backbones/mobilenet_v3.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import mmcv
5 | from mmcv.cnn import ConvModule
6 | from mmcv.cnn.bricks import Conv2dAdaptivePadding
7 | from mmcv.runner import BaseModule
8 | from torch.nn.modules.batchnorm import _BatchNorm
9 |
10 | from ..builder import BACKBONES
11 | from ..utils import InvertedResidualV3 as InvertedResidual
12 |
13 |
14 | @BACKBONES.register_module()
15 | class MobileNetV3(BaseModule):
16 | """MobileNetV3 backbone.
17 |
18 | This backbone is the improved implementation of `Searching for MobileNetV3
19 |     <https://arxiv.org/abs/1905.02244>`_.
20 |
21 | Args:
22 |         arch (str): Architecture of MobileNetV3, from {'small', 'large'}.
23 | Default: 'small'.
24 | conv_cfg (dict): Config dict for convolution layer.
25 | Default: None, which means using conv2d.
26 | norm_cfg (dict): Config dict for normalization layer.
27 | Default: dict(type='BN').
28 | out_indices (tuple[int]): Output from which layer.
29 | Default: (0, 1, 12).
30 | frozen_stages (int): Stages to be frozen (all param fixed).
31 | Default: -1, which means not freezing any parameters.
32 | norm_eval (bool): Whether to set norm layers to eval mode, namely,
33 | freeze running stats (mean and var). Note: Effect on Batch Norm
34 | and its variants only. Default: False.
35 | with_cp (bool): Use checkpoint or not. Using checkpoint will save
36 | some memory while slowing down the training speed.
37 | Default: False.
38 | pretrained (str, optional): model pretrained path. Default: None
39 | init_cfg (dict or list[dict], optional): Initialization config dict.
40 | Default: None
41 | """
42 | # Parameters to build each block:
43 | # [kernel size, mid channels, out channels, with_se, act type, stride]
44 | arch_settings = {
45 | 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4
46 | [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8
47 | [3, 88, 24, False, 'ReLU', 1],
48 | [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16
49 | [5, 240, 40, True, 'HSwish', 1],
50 | [5, 240, 40, True, 'HSwish', 1],
51 | [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16
52 | [5, 144, 48, True, 'HSwish', 1],
53 | [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32
54 | [5, 576, 96, True, 'HSwish', 1],
55 | [5, 576, 96, True, 'HSwish', 1]],
56 | 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2
57 | [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4
58 | [3, 72, 24, False, 'ReLU', 1],
59 | [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8
60 | [5, 120, 40, True, 'ReLU', 1],
61 | [5, 120, 40, True, 'ReLU', 1],
62 | [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16
63 | [3, 200, 80, False, 'HSwish', 1],
64 | [3, 184, 80, False, 'HSwish', 1],
65 | [3, 184, 80, False, 'HSwish', 1],
66 | [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16
67 | [3, 672, 112, True, 'HSwish', 1],
68 | [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32
69 | [5, 960, 160, True, 'HSwish', 1],
70 | [5, 960, 160, True, 'HSwish', 1]]
71 | } # yapf: disable
72 |
73 | def __init__(self,
74 | arch='small',
75 | conv_cfg=None,
76 | norm_cfg=dict(type='BN'),
77 | out_indices=(0, 1, 12),
78 | frozen_stages=-1,
79 | reduction_factor=1,
80 | norm_eval=False,
81 | with_cp=False,
82 | pretrained=None,
83 | init_cfg=None):
84 | super(MobileNetV3, self).__init__(init_cfg)
85 |
86 | self.pretrained = pretrained
87 | assert not (init_cfg and pretrained), \
88 |             'init_cfg and pretrained cannot be set at the same time'
89 | if isinstance(pretrained, str):
90 |             warnings.warn('DeprecationWarning: pretrained is deprecated, '
91 | 'please use "init_cfg" instead')
92 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
93 | elif pretrained is None:
94 | if init_cfg is None:
95 | self.init_cfg = [
96 | dict(type='Kaiming', layer='Conv2d'),
97 | dict(
98 | type='Constant',
99 | val=1,
100 | layer=['_BatchNorm', 'GroupNorm'])
101 | ]
102 | else:
103 | raise TypeError('pretrained must be a str or None')
104 |
105 | assert arch in self.arch_settings
106 | assert isinstance(reduction_factor, int) and reduction_factor > 0
107 | assert mmcv.is_tuple_of(out_indices, int)
108 | for index in out_indices:
109 | if index not in range(0, len(self.arch_settings[arch]) + 2):
110 | raise ValueError(
111 |                     'the item in out_indices must be in '
112 | f'range(0, {len(self.arch_settings[arch])+2}). '
113 | f'But received {index}')
114 |
115 | if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2):
116 | raise ValueError('frozen_stages must be in range(-1, '
117 | f'{len(self.arch_settings[arch])+2}). '
118 | f'But received {frozen_stages}')
119 | self.arch = arch
120 | self.conv_cfg = conv_cfg
121 | self.norm_cfg = norm_cfg
122 | self.out_indices = out_indices
123 | self.frozen_stages = frozen_stages
124 | self.reduction_factor = reduction_factor
125 | self.norm_eval = norm_eval
126 | self.with_cp = with_cp
127 | self.layers = self._make_layer()
128 |
129 | def _make_layer(self):
130 | layers = []
131 |
132 | # build the first layer (layer0)
133 | in_channels = 16
134 | layer = ConvModule(
135 | in_channels=3,
136 | out_channels=in_channels,
137 | kernel_size=3,
138 | stride=2,
139 | padding=1,
140 | conv_cfg=dict(type='Conv2dAdaptivePadding'),
141 | norm_cfg=self.norm_cfg,
142 | act_cfg=dict(type='HSwish'))
143 | self.add_module('layer0', layer)
144 | layers.append('layer0')
145 |
146 | layer_setting = self.arch_settings[self.arch]
147 | for i, params in enumerate(layer_setting):
148 | (kernel_size, mid_channels, out_channels, with_se, act,
149 | stride) = params
150 |
151 | if self.arch == 'large' and i >= 12 or self.arch == 'small' and \
152 | i >= 8:
153 | mid_channels = mid_channels // self.reduction_factor
154 | out_channels = out_channels // self.reduction_factor
155 |
156 | if with_se:
157 | se_cfg = dict(
158 | channels=mid_channels,
159 | ratio=4,
160 | act_cfg=(dict(type='ReLU'),
161 | dict(type='HSigmoid', bias=3.0, divisor=6.0)))
162 | else:
163 | se_cfg = None
164 |
165 | layer = InvertedResidual(
166 | in_channels=in_channels,
167 | out_channels=out_channels,
168 | mid_channels=mid_channels,
169 | kernel_size=kernel_size,
170 | stride=stride,
171 | se_cfg=se_cfg,
172 | with_expand_conv=(in_channels != mid_channels),
173 | conv_cfg=self.conv_cfg,
174 | norm_cfg=self.norm_cfg,
175 | act_cfg=dict(type=act),
176 | with_cp=self.with_cp)
177 | in_channels = out_channels
178 | layer_name = 'layer{}'.format(i + 1)
179 | self.add_module(layer_name, layer)
180 | layers.append(layer_name)
181 |
182 | # build the last layer
183 | # block5 layer12 os=32 for small model
184 | # block6 layer16 os=32 for large model
185 | layer = ConvModule(
186 | in_channels=in_channels,
187 | out_channels=576 if self.arch == 'small' else 960,
188 | kernel_size=1,
189 | stride=1,
190 | dilation=4,
191 | padding=0,
192 | conv_cfg=self.conv_cfg,
193 | norm_cfg=self.norm_cfg,
194 | act_cfg=dict(type='HSwish'))
195 | layer_name = 'layer{}'.format(len(layer_setting) + 1)
196 | self.add_module(layer_name, layer)
197 | layers.append(layer_name)
198 |
199 | # next, convert backbone MobileNetV3 to a semantic segmentation version
200 | if self.arch == 'small':
201 | self.layer4.depthwise_conv.conv.stride = (1, 1)
202 | self.layer9.depthwise_conv.conv.stride = (1, 1)
203 | for i in range(4, len(layers)):
204 | layer = getattr(self, layers[i])
205 | if isinstance(layer, InvertedResidual):
206 | modified_module = layer.depthwise_conv.conv
207 | else:
208 | modified_module = layer.conv
209 |
210 | if i < 9:
211 | modified_module.dilation = (2, 2)
212 | pad = 2
213 | else:
214 | modified_module.dilation = (4, 4)
215 | pad = 4
216 |
217 | if not isinstance(modified_module, Conv2dAdaptivePadding):
218 | # Adjust padding
219 | pad *= (modified_module.kernel_size[0] - 1) // 2
220 | modified_module.padding = (pad, pad)
221 | else:
222 | self.layer7.depthwise_conv.conv.stride = (1, 1)
223 | self.layer13.depthwise_conv.conv.stride = (1, 1)
224 | for i in range(7, len(layers)):
225 | layer = getattr(self, layers[i])
226 | if isinstance(layer, InvertedResidual):
227 | modified_module = layer.depthwise_conv.conv
228 | else:
229 | modified_module = layer.conv
230 |
231 | if i < 13:
232 | modified_module.dilation = (2, 2)
233 | pad = 2
234 | else:
235 | modified_module.dilation = (4, 4)
236 | pad = 4
237 |
238 | if not isinstance(modified_module, Conv2dAdaptivePadding):
239 | # Adjust padding
240 | pad *= (modified_module.kernel_size[0] - 1) // 2
241 | modified_module.padding = (pad, pad)
242 |
243 | return layers
244 |
245 | def forward(self, x):
246 | outs = []
247 | for i, layer_name in enumerate(self.layers):
248 | layer = getattr(self, layer_name)
249 | x = layer(x)
250 | if i in self.out_indices:
251 | outs.append(x)
252 | return outs
253 |
254 | def _freeze_stages(self):
255 | for i in range(self.frozen_stages + 1):
256 | layer = getattr(self, f'layer{i}')
257 | layer.eval()
258 | for param in layer.parameters():
259 | param.requires_grad = False
260 |
261 | def train(self, mode=True):
262 | super(MobileNetV3, self).train(mode)
263 | self._freeze_stages()
264 | if mode and self.norm_eval:
265 | for m in self.modules():
266 | if isinstance(m, _BatchNorm):
267 | m.eval()
268 |
--------------------------------------------------------------------------------
/segmentation/backbones/resnest.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.utils.checkpoint as cp
8 | from mmcv.cnn import build_conv_layer, build_norm_layer
9 |
10 | from ..builder import BACKBONES
11 | from ..utils import ResLayer
12 | from .resnet import Bottleneck as _Bottleneck
13 | from .resnet import ResNetV1d
14 |
15 |
16 | class RSoftmax(nn.Module):
17 | """Radix Softmax module in ``SplitAttentionConv2d``.
18 |
19 | Args:
20 | radix (int): Radix of input.
21 | groups (int): Groups of input.
22 | """
23 |
24 | def __init__(self, radix, groups):
25 | super().__init__()
26 | self.radix = radix
27 | self.groups = groups
28 |
29 | def forward(self, x):
30 | batch = x.size(0)
31 | if self.radix > 1:
32 | x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
33 | x = F.softmax(x, dim=1)
34 | x = x.reshape(batch, -1)
35 | else:
36 | x = torch.sigmoid(x)
37 | return x
38 |
39 |
40 | class SplitAttentionConv2d(nn.Module):
41 | """Split-Attention Conv2d in ResNeSt.
42 |
43 | Args:
44 | in_channels (int): Same as nn.Conv2d.
45 | out_channels (int): Same as nn.Conv2d.
46 | kernel_size (int | tuple[int]): Same as nn.Conv2d.
47 | stride (int | tuple[int]): Same as nn.Conv2d.
48 | padding (int | tuple[int]): Same as nn.Conv2d.
49 | dilation (int | tuple[int]): Same as nn.Conv2d.
50 | groups (int): Same as nn.Conv2d.
51 |         radix (int): Radix of SplitAttentionConv2d. Default: 2.
52 | reduction_factor (int): Reduction factor of inter_channels. Default: 4.
53 | conv_cfg (dict): Config dict for convolution layer. Default: None,
54 | which means using conv2d.
55 | norm_cfg (dict): Config dict for normalization layer. Default: None.
56 | dcn (dict): Config dict for DCN. Default: None.
57 | """
58 |
59 | def __init__(self,
60 | in_channels,
61 | channels,
62 | kernel_size,
63 | stride=1,
64 | padding=0,
65 | dilation=1,
66 | groups=1,
67 | radix=2,
68 | reduction_factor=4,
69 | conv_cfg=None,
70 | norm_cfg=dict(type='BN'),
71 | dcn=None):
72 | super(SplitAttentionConv2d, self).__init__()
73 | inter_channels = max(in_channels * radix // reduction_factor, 32)
74 | self.radix = radix
75 | self.groups = groups
76 | self.channels = channels
77 | self.with_dcn = dcn is not None
78 | self.dcn = dcn
79 | fallback_on_stride = False
80 | if self.with_dcn:
81 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
82 | if self.with_dcn and not fallback_on_stride:
83 | assert conv_cfg is None, 'conv_cfg must be None for DCN'
84 | conv_cfg = dcn
85 | self.conv = build_conv_layer(
86 | conv_cfg,
87 | in_channels,
88 | channels * radix,
89 | kernel_size,
90 | stride=stride,
91 | padding=padding,
92 | dilation=dilation,
93 | groups=groups * radix,
94 | bias=False)
95 | self.norm0_name, norm0 = build_norm_layer(
96 | norm_cfg, channels * radix, postfix=0)
97 | self.add_module(self.norm0_name, norm0)
98 | self.relu = nn.ReLU(inplace=True)
99 | self.fc1 = build_conv_layer(
100 | None, channels, inter_channels, 1, groups=self.groups)
101 | self.norm1_name, norm1 = build_norm_layer(
102 | norm_cfg, inter_channels, postfix=1)
103 | self.add_module(self.norm1_name, norm1)
104 | self.fc2 = build_conv_layer(
105 | None, inter_channels, channels * radix, 1, groups=self.groups)
106 | self.rsoftmax = RSoftmax(radix, groups)
107 |
108 | @property
109 | def norm0(self):
110 | """nn.Module: the normalization layer named "norm0" """
111 | return getattr(self, self.norm0_name)
112 |
113 | @property
114 | def norm1(self):
115 | """nn.Module: the normalization layer named "norm1" """
116 | return getattr(self, self.norm1_name)
117 |
118 | def forward(self, x):
119 | x = self.conv(x)
120 | x = self.norm0(x)
121 | x = self.relu(x)
122 |
123 | batch, rchannel = x.shape[:2]
124 | batch = x.size(0)
125 | if self.radix > 1:
126 | splits = x.view(batch, self.radix, -1, *x.shape[2:])
127 | gap = splits.sum(dim=1)
128 | else:
129 | gap = x
130 | gap = F.adaptive_avg_pool2d(gap, 1)
131 | gap = self.fc1(gap)
132 |
133 | gap = self.norm1(gap)
134 | gap = self.relu(gap)
135 |
136 | atten = self.fc2(gap)
137 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
138 |
139 | if self.radix > 1:
140 | attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
141 | out = torch.sum(attens * splits, dim=1)
142 | else:
143 | out = atten * x
144 | return out.contiguous()
145 |
146 |
147 | class Bottleneck(_Bottleneck):
148 | """Bottleneck block for ResNeSt.
149 |
150 | Args:
151 |         inplanes (int): Input planes of this block.
152 | planes (int): Middle planes of this block.
153 | groups (int): Groups of conv2.
154 | width_per_group (int): Width per group of conv2. 64x4d indicates
155 | ``groups=64, width_per_group=4`` and 32x8d indicates
156 | ``groups=32, width_per_group=8``.
157 |         radix (int): Radix of SplitAttentionConv2d. Default: 2.
158 | reduction_factor (int): Reduction factor of inter_channels in
159 | SplitAttentionConv2d. Default: 4.
160 | avg_down_stride (bool): Whether to use average pool for stride in
161 | Bottleneck. Default: True.
162 | kwargs (dict): Key word arguments for base class.
163 | """
164 | expansion = 4
165 |
166 | def __init__(self,
167 | inplanes,
168 | planes,
169 | groups=1,
170 | base_width=4,
171 | base_channels=64,
172 | radix=2,
173 | reduction_factor=4,
174 | avg_down_stride=True,
175 | **kwargs):
176 | """Bottleneck block for ResNeSt."""
177 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
178 |
179 | if groups == 1:
180 | width = self.planes
181 | else:
182 | width = math.floor(self.planes *
183 | (base_width / base_channels)) * groups
184 |
185 | self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
186 |
187 | self.norm1_name, norm1 = build_norm_layer(
188 | self.norm_cfg, width, postfix=1)
189 | self.norm3_name, norm3 = build_norm_layer(
190 | self.norm_cfg, self.planes * self.expansion, postfix=3)
191 |
192 | self.conv1 = build_conv_layer(
193 | self.conv_cfg,
194 | self.inplanes,
195 | width,
196 | kernel_size=1,
197 | stride=self.conv1_stride,
198 | bias=False)
199 | self.add_module(self.norm1_name, norm1)
200 | self.with_modulated_dcn = False
201 | self.conv2 = SplitAttentionConv2d(
202 | width,
203 | width,
204 | kernel_size=3,
205 | stride=1 if self.avg_down_stride else self.conv2_stride,
206 | padding=self.dilation,
207 | dilation=self.dilation,
208 | groups=groups,
209 | radix=radix,
210 | reduction_factor=reduction_factor,
211 | conv_cfg=self.conv_cfg,
212 | norm_cfg=self.norm_cfg,
213 | dcn=self.dcn)
214 | delattr(self, self.norm2_name)
215 |
216 | if self.avg_down_stride:
217 | self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
218 |
219 | self.conv3 = build_conv_layer(
220 | self.conv_cfg,
221 | width,
222 | self.planes * self.expansion,
223 | kernel_size=1,
224 | bias=False)
225 | self.add_module(self.norm3_name, norm3)
226 |
227 | def forward(self, x):
228 |
229 | def _inner_forward(x):
230 | identity = x
231 |
232 | out = self.conv1(x)
233 | out = self.norm1(out)
234 | out = self.relu(out)
235 |
236 | if self.with_plugins:
237 | out = self.forward_plugin(out, self.after_conv1_plugin_names)
238 |
239 | out = self.conv2(out)
240 |
241 | if self.avg_down_stride:
242 | out = self.avd_layer(out)
243 |
244 | if self.with_plugins:
245 | out = self.forward_plugin(out, self.after_conv2_plugin_names)
246 |
247 | out = self.conv3(out)
248 | out = self.norm3(out)
249 |
250 | if self.with_plugins:
251 | out = self.forward_plugin(out, self.after_conv3_plugin_names)
252 |
253 | if self.downsample is not None:
254 | identity = self.downsample(x)
255 |
256 | out += identity
257 |
258 | return out
259 |
260 | if self.with_cp and x.requires_grad:
261 | out = cp.checkpoint(_inner_forward, x)
262 | else:
263 | out = _inner_forward(x)
264 |
265 | out = self.relu(out)
266 |
267 | return out
268 |
269 |
270 | @BACKBONES.register_module()
271 | class ResNeSt(ResNetV1d):
272 | """ResNeSt backbone.
273 |
274 | This backbone is the implementation of `ResNeSt:
275 |     Split-Attention Networks <https://arxiv.org/abs/2004.08955>`_.
276 |
277 | Args:
278 | groups (int): Number of groups of Bottleneck. Default: 1
279 | base_width (int): Base width of Bottleneck. Default: 4
280 |         radix (int): Radix of SplitAttentionConv2d. Default: 2.
281 | reduction_factor (int): Reduction factor of inter_channels in
282 | SplitAttentionConv2d. Default: 4.
283 | avg_down_stride (bool): Whether to use average pool for stride in
284 | Bottleneck. Default: True.
285 | kwargs (dict): Keyword arguments for ResNet.
286 | """
287 |
288 | arch_settings = {
289 | 50: (Bottleneck, (3, 4, 6, 3)),
290 | 101: (Bottleneck, (3, 4, 23, 3)),
291 | 152: (Bottleneck, (3, 8, 36, 3)),
292 | 200: (Bottleneck, (3, 24, 36, 3))
293 | }
294 |
295 | def __init__(self,
296 | groups=1,
297 | base_width=4,
298 | radix=2,
299 | reduction_factor=4,
300 | avg_down_stride=True,
301 | **kwargs):
302 | self.groups = groups
303 | self.base_width = base_width
304 | self.radix = radix
305 | self.reduction_factor = reduction_factor
306 | self.avg_down_stride = avg_down_stride
307 | super(ResNeSt, self).__init__(**kwargs)
308 |
309 | def make_res_layer(self, **kwargs):
310 | """Pack all blocks in a stage into a ``ResLayer``."""
311 | return ResLayer(
312 | groups=self.groups,
313 | base_width=self.base_width,
314 | base_channels=self.base_channels,
315 | radix=self.radix,
316 | reduction_factor=self.reduction_factor,
317 | avg_down_stride=self.avg_down_stride,
318 | **kwargs)
319 |
--------------------------------------------------------------------------------
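The `RSoftmax` reshaping above is the core of split attention: for `radix > 1`, the attention logits are normalised across the radix splits of each channel group. A tiny self-contained check in plain PyTorch (radix=2, groups=1), mirroring the same view/transpose/softmax steps:

```python
import torch
import torch.nn.functional as F

batch, radix, groups, channels = 1, 2, 1, 4
x = torch.randn(batch, channels * radix)              # pooled attention logits
y = x.view(batch, groups, radix, -1).transpose(1, 2)  # same reshape as RSoftmax
y = F.softmax(y, dim=1).reshape(batch, -1)            # softmax over the radix axis
print(y.view(batch, radix, -1).sum(dim=1))            # tensor([[1., 1., 1., 1.]])
```

Each channel's two radix weights sum to one, which is exactly what the split-attention weighting in `SplitAttentionConv2d.forward` relies on.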
/segmentation/backbones/resnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import math
3 |
4 | from mmcv.cnn import build_conv_layer, build_norm_layer
5 |
6 | from ..builder import BACKBONES
7 | from ..utils import ResLayer
8 | from .resnet import Bottleneck as _Bottleneck
9 | from .resnet import ResNet
10 |
11 |
12 | class Bottleneck(_Bottleneck):
13 | """Bottleneck block for ResNeXt.
14 |
15 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is
16 | "caffe", the stride-two layer is the first 1x1 conv layer.
17 | """
18 |
19 | def __init__(self,
20 | inplanes,
21 | planes,
22 | groups=1,
23 | base_width=4,
24 | base_channels=64,
25 | **kwargs):
26 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
27 |
28 | if groups == 1:
29 | width = self.planes
30 | else:
31 | width = math.floor(self.planes *
32 | (base_width / base_channels)) * groups
33 |
34 | self.norm1_name, norm1 = build_norm_layer(
35 | self.norm_cfg, width, postfix=1)
36 | self.norm2_name, norm2 = build_norm_layer(
37 | self.norm_cfg, width, postfix=2)
38 | self.norm3_name, norm3 = build_norm_layer(
39 | self.norm_cfg, self.planes * self.expansion, postfix=3)
40 |
41 | self.conv1 = build_conv_layer(
42 | self.conv_cfg,
43 | self.inplanes,
44 | width,
45 | kernel_size=1,
46 | stride=self.conv1_stride,
47 | bias=False)
48 | self.add_module(self.norm1_name, norm1)
49 | fallback_on_stride = False
50 | self.with_modulated_dcn = False
51 | if self.with_dcn:
52 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
53 | if not self.with_dcn or fallback_on_stride:
54 | self.conv2 = build_conv_layer(
55 | self.conv_cfg,
56 | width,
57 | width,
58 | kernel_size=3,
59 | stride=self.conv2_stride,
60 | padding=self.dilation,
61 | dilation=self.dilation,
62 | groups=groups,
63 | bias=False)
64 | else:
65 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
66 | self.conv2 = build_conv_layer(
67 | self.dcn,
68 | width,
69 | width,
70 | kernel_size=3,
71 | stride=self.conv2_stride,
72 | padding=self.dilation,
73 | dilation=self.dilation,
74 | groups=groups,
75 | bias=False)
76 |
77 | self.add_module(self.norm2_name, norm2)
78 | self.conv3 = build_conv_layer(
79 | self.conv_cfg,
80 | width,
81 | self.planes * self.expansion,
82 | kernel_size=1,
83 | bias=False)
84 | self.add_module(self.norm3_name, norm3)
85 |
86 |
87 | @BACKBONES.register_module()
88 | class ResNeXt(ResNet):
89 | """ResNeXt backbone.
90 |
91 | This backbone is the implementation of `Aggregated
92 | Residual Transformations for Deep Neural
93 |     Networks <https://arxiv.org/abs/1611.05431>`_.
94 |
95 | Args:
96 |         depth (int): Depth of resnext, from {50, 101, 152}.
97 | in_channels (int): Number of input image channels. Normally 3.
98 | num_stages (int): Resnet stages, normally 4.
99 | groups (int): Group of resnext.
100 | base_width (int): Base width of resnext.
101 | strides (Sequence[int]): Strides of the first block of each stage.
102 | dilations (Sequence[int]): Dilation of each stage.
103 | out_indices (Sequence[int]): Output from which stages.
104 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
105 | layer is the 3x3 conv layer, otherwise the stride-two layer is
106 | the first 1x1 conv layer.
107 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means
108 | not freezing any parameters.
109 | norm_cfg (dict): dictionary to construct and config norm layer.
110 | norm_eval (bool): Whether to set norm layers to eval mode, namely,
111 | freeze running stats (mean and var). Note: Effect on Batch Norm
112 | and its variants only.
113 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some
114 | memory while slowing down the training speed.
115 | zero_init_residual (bool): whether to use zero init for last norm layer
116 | in resblocks to let them behave as identity.
117 |
118 | Example:
119 | >>> from mmseg.models import ResNeXt
120 | >>> import torch
121 | >>> self = ResNeXt(depth=50)
122 | >>> self.eval()
123 | >>> inputs = torch.rand(1, 3, 32, 32)
124 | >>> level_outputs = self.forward(inputs)
125 | >>> for level_out in level_outputs:
126 | ... print(tuple(level_out.shape))
127 | (1, 256, 8, 8)
128 | (1, 512, 4, 4)
129 | (1, 1024, 2, 2)
130 | (1, 2048, 1, 1)
131 | """
132 |
133 | arch_settings = {
134 | 50: (Bottleneck, (3, 4, 6, 3)),
135 | 101: (Bottleneck, (3, 4, 23, 3)),
136 | 152: (Bottleneck, (3, 8, 36, 3))
137 | }
138 |
139 | def __init__(self, groups=1, base_width=4, **kwargs):
140 | self.groups = groups
141 | self.base_width = base_width
142 | super(ResNeXt, self).__init__(**kwargs)
143 |
144 | def make_res_layer(self, **kwargs):
145 | """Pack all blocks in a stage into a ``ResLayer``"""
146 | return ResLayer(
147 | groups=self.groups,
148 | base_width=self.base_width,
149 | base_channels=self.base_channels,
150 | **kwargs)
151 |
--------------------------------------------------------------------------------
/segmentation/configs/_base_/datasets/ade20k.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | dataset_type = 'ADE20KDataset'
3 | data_root = '/data/ADEChallengeData2016'
4 | img_norm_cfg = dict(
5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
6 | crop_size = (512, 512)
7 | train_pipeline = [
8 | dict(type='LoadImageFromFile'),
9 | dict(type='LoadAnnotations', reduce_zero_label=True),
10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
12 | dict(type='RandomFlip', prob=0.5),
13 | dict(type='PhotoMetricDistortion'),
14 | dict(type='Normalize', **img_norm_cfg),
15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
16 | dict(type='DefaultFormatBundle'),
17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']),
18 | ]
19 | test_pipeline = [
20 | dict(type='LoadImageFromFile'),
21 | dict(
22 | type='MultiScaleFlipAug',
23 | img_scale=(2048, 512),
24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
25 | flip=False,
26 | transforms=[
27 | dict(type='Resize', keep_ratio=True),
28 | dict(type='RandomFlip'),
29 | dict(type='Normalize', **img_norm_cfg),
30 | dict(type='ImageToTensor', keys=['img']),
31 | dict(type='Collect', keys=['img']),
32 | ])
33 | ]
34 | data = dict(
35 | samples_per_gpu=4,
36 | workers_per_gpu=4,
37 | train=dict(
38 | type=dataset_type,
39 | data_root=data_root,
40 | img_dir='images/training',
41 | ann_dir='annotations/training',
42 | pipeline=train_pipeline),
43 | val=dict(
44 | type=dataset_type,
45 | data_root=data_root,
46 | img_dir='images/validation',
47 | ann_dir='annotations/validation',
48 | pipeline=test_pipeline),
49 | test=dict(
50 | type=dataset_type,
51 | data_root=data_root,
52 | img_dir='images/validation',
53 | ann_dir='annotations/validation',
54 | pipeline=test_pipeline))
55 |
--------------------------------------------------------------------------------
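The dictionaries above are consumed by mmseg's dataset builder. A hedged sketch of loading the training split; this only runs if `data_root` points at a local copy of ADEChallengeData2016:

```python
from mmcv import Config
from mmseg.datasets import build_dataset

cfg = Config.fromfile('segmentation/configs/_base_/datasets/ade20k.py')
train_set = build_dataset(cfg.data.train)
print(len(train_set))          # 20210 training images in ADE20K
print(train_set.CLASSES[:3])   # ('wall', 'building', 'sky')
```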
/segmentation/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | # yapf:disable
2 | log_config = dict(
3 | interval=50,
4 | hooks=[
5 | dict(type='TextLoggerHook', by_epoch=False),
6 | # dict(type='TensorboardLoggerHook')
7 | ])
8 | # yapf:enable
9 | dist_params = dict(backend='nccl')
10 | log_level = 'INFO'
11 | load_from = None
12 | resume_from = None
13 | workflow = [('train', 1)]
14 | cudnn_benchmark = True
15 |
--------------------------------------------------------------------------------
/segmentation/configs/_base_/schedules/schedule_40k.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
3 | optimizer_config = dict()
4 | # learning policy
5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False)
6 | # runtime settings
7 | runner = dict(type='IterBasedRunner', max_iters=40000)
8 | checkpoint_config = dict(by_epoch=False, interval=4000)
9 | evaluation = dict(interval=4000, metric='mIoU', pre_eval=True)
10 |
--------------------------------------------------------------------------------
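For reference, the `poly` policy above decays the learning rate from `lr` towards `min_lr` with exponent `power` over the 40k iterations. The sketch below states the interpolation formula used by mmcv's `PolyLrUpdaterHook` as an assumption and evaluates it at a few iterations:

```python
base_lr, min_lr, power, max_iters = 0.01, 1e-4, 0.9, 40000

def poly_lr(it):
    # lr interpolates between base_lr and min_lr with a (1 - progress) ** power schedule
    coeff = (1 - it / max_iters) ** power
    return (base_lr - min_lr) * coeff + min_lr

for it in (0, 20000, 40000):
    print(it, round(poly_lr(it), 6))  # 0 0.01 | 20000 0.005405 | 40000 0.0001
```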
/segmentation/configs/semfpn_vit-s16_512x512_40k_ade20k.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/datasets/ade20k.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
4 | ]
5 |
6 | # model settings
7 | norm_cfg = dict(type='SyncBN', requires_grad=True)
8 | model = dict(
9 | type='EncoderDecoder',
10 | pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth',
11 | backbone=dict(
12 | type='SelfPatch_ViT',
13 | img_size=(512, 512),
14 | patch_size=16,
15 | in_channels=3,
16 | embed_dims=384,
17 | num_layers=12,
18 | num_heads=6,
19 | mlp_ratio=4,
20 | out_indices=(2, 5, 8, 11),
21 | qkv_bias=True,
22 | drop_rate=0.0,
23 | attn_drop_rate=0.0,
24 | drop_path_rate=0.1,
25 | with_cls_token=True,
26 | norm_cfg=dict(type='LN', eps=1e-6),
27 | act_cfg=dict(type='GELU'),
28 | norm_eval=False,
29 | interpolate_mode='bicubic'),
30 | neck=dict(
31 | type='FPN',
32 | in_channels=[384, 384, 384, 384],
33 | out_channels=256,
34 | num_outs=4),
35 | decode_head=dict(
36 | type='FPNHead',
37 | in_channels=[256, 256, 256, 256],
38 | in_index=[0, 1, 2, 3],
39 | feature_strides=[4, 8, 16, 32],
40 | channels=128,
41 | dropout_ratio=0.1,
42 | num_classes=150,
43 | norm_cfg=norm_cfg,
44 | align_corners=False,
45 | loss_decode=dict(
46 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
47 | # model training and testing settings
48 | train_cfg=dict(),
49 | test_cfg=dict(mode='whole'))
50 |
51 |
52 | # AdamW optimizer, no weight decay for position embedding & layer norm
53 | # in backbone
54 | optimizer = dict(
55 | _delete_=True,
56 | type='AdamW',
57 | lr=0.00006,
58 | betas=(0.9, 0.999),
59 | weight_decay=0.01,
60 | paramwise_cfg=dict(
61 | custom_keys={
62 | 'pos_embed': dict(decay_mult=0.),
63 | 'cls_token': dict(decay_mult=0.),
64 | 'norm': dict(decay_mult=0.)
65 | }))
66 |
67 | lr_config = dict(
68 | _delete_=True,
69 | policy='poly',
70 | warmup='linear',
71 | warmup_iters=1500,
72 | warmup_ratio=1e-6,
73 | power=1.0,
74 | min_lr=0.0,
75 | by_epoch=False)
76 |
77 | # By default, models are trained with 4 images per GPU
78 | data = dict(samples_per_gpu=4)
79 |
--------------------------------------------------------------------------------
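A quick way to sanity-check how the `_base_` files and the overrides above merge is to load the config with mmcv. A hedged sketch; the printed values follow from the config files in this repo:

```python
from mmcv import Config

cfg = Config.fromfile('segmentation/configs/semfpn_vit-s16_512x512_40k_ade20k.py')
print(cfg.model.backbone.type)               # SelfPatch_ViT
print(cfg.optimizer.type, cfg.optimizer.lr)  # AdamW 6e-05
print(cfg.runner.max_iters)                  # 40000 (from schedule_40k.py)
print(cfg.data.samples_per_gpu)              # 4
```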
/segmentation/tools/analyze_logs.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | """Modified from https://github.com/open-
3 | mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py."""
4 | import argparse
5 | import json
6 | from collections import defaultdict
7 |
8 | import matplotlib.pyplot as plt
9 | import seaborn as sns
10 |
11 |
12 | def plot_curve(log_dicts, args):
13 | if args.backend is not None:
14 | plt.switch_backend(args.backend)
15 | sns.set_style(args.style)
16 | # if legend is None, use {filename}_{key} as legend
17 | legend = args.legend
18 | if legend is None:
19 | legend = []
20 | for json_log in args.json_logs:
21 | for metric in args.keys:
22 | legend.append(f'{json_log}_{metric}')
23 | assert len(legend) == (len(args.json_logs) * len(args.keys))
24 | metrics = args.keys
25 |
26 | num_metrics = len(metrics)
27 | for i, log_dict in enumerate(log_dicts):
28 | epochs = list(log_dict.keys())
29 | for j, metric in enumerate(metrics):
30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
31 | plot_epochs = []
32 | plot_iters = []
33 | plot_values = []
34 |             # In some log files, the iteration numbers are not correct; `pre_iter`
35 |             # is used to prevent generating wrong lines.
36 | pre_iter = -1
37 | for epoch in epochs:
38 | epoch_logs = log_dict[epoch]
39 | if metric not in epoch_logs.keys():
40 | continue
41 | if metric in ['mIoU', 'mAcc', 'aAcc']:
42 | plot_epochs.append(epoch)
43 | plot_values.append(epoch_logs[metric][0])
44 | else:
45 | for idx in range(len(epoch_logs[metric])):
46 | if pre_iter > epoch_logs['iter'][idx]:
47 | continue
48 | pre_iter = epoch_logs['iter'][idx]
49 | plot_iters.append(epoch_logs['iter'][idx])
50 | plot_values.append(epoch_logs[metric][idx])
51 | ax = plt.gca()
52 | label = legend[i * num_metrics + j]
53 | if metric in ['mIoU', 'mAcc', 'aAcc']:
54 | ax.set_xticks(plot_epochs)
55 | plt.xlabel('epoch')
56 | plt.plot(plot_epochs, plot_values, label=label, marker='o')
57 | else:
58 | plt.xlabel('iter')
59 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5)
60 | plt.legend()
61 | if args.title is not None:
62 | plt.title(args.title)
63 | if args.out is None:
64 | plt.show()
65 | else:
66 | print(f'save curve to: {args.out}')
67 | plt.savefig(args.out)
68 | plt.cla()
69 |
70 |
71 | def parse_args():
72 | parser = argparse.ArgumentParser(description='Analyze Json Log')
73 | parser.add_argument(
74 | 'json_logs',
75 | type=str,
76 | nargs='+',
77 | help='path of train log in json format')
78 | parser.add_argument(
79 | '--keys',
80 | type=str,
81 | nargs='+',
82 | default=['mIoU'],
83 | help='the metric that you want to plot')
84 | parser.add_argument('--title', type=str, help='title of figure')
85 | parser.add_argument(
86 | '--legend',
87 | type=str,
88 | nargs='+',
89 | default=None,
90 | help='legend of each plot')
91 | parser.add_argument(
92 | '--backend', type=str, default=None, help='backend of plt')
93 | parser.add_argument(
94 | '--style', type=str, default='dark', help='style of plt')
95 | parser.add_argument('--out', type=str, default=None)
96 | args = parser.parse_args()
97 | return args
98 |
99 |
100 | def load_json_logs(json_logs):
101 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict
102 |     # keys of sub dict are different metrics
103 | # value of sub dict is a list of corresponding values of all iterations
104 | log_dicts = [dict() for _ in json_logs]
105 | for json_log, log_dict in zip(json_logs, log_dicts):
106 | with open(json_log, 'r') as log_file:
107 | for line in log_file:
108 | log = json.loads(line.strip())
109 | # skip lines without `epoch` field
110 | if 'epoch' not in log:
111 | continue
112 | epoch = log.pop('epoch')
113 | if epoch not in log_dict:
114 | log_dict[epoch] = defaultdict(list)
115 | for k, v in log.items():
116 | log_dict[epoch][k].append(v)
117 | return log_dicts
118 |
119 |
120 | def main():
121 | args = parse_args()
122 | json_logs = args.json_logs
123 | for json_log in json_logs:
124 | assert json_log.endswith('.json')
125 | log_dicts = load_json_logs(json_logs)
126 | plot_curve(log_dicts, args)
127 |
128 |
129 | if __name__ == '__main__':
130 | main()
131 |
--------------------------------------------------------------------------------
/segmentation/tools/benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import time
4 |
5 | import torch
6 | from mmcv import Config
7 | from mmcv.parallel import MMDataParallel
8 | from mmcv.runner import load_checkpoint, wrap_fp16_model
9 |
10 | from mmseg.datasets import build_dataloader, build_dataset
11 | from mmseg.models import build_segmentor
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(description='MMSeg benchmark a model')
16 | parser.add_argument('config', help='test config file path')
17 | parser.add_argument('checkpoint', help='checkpoint file')
18 | parser.add_argument(
19 | '--log-interval', type=int, default=50, help='interval of logging')
20 | args = parser.parse_args()
21 | return args
22 |
23 |
24 | def main():
25 | args = parse_args()
26 |
27 | cfg = Config.fromfile(args.config)
28 | # set cudnn_benchmark
29 | torch.backends.cudnn.benchmark = False
30 | cfg.model.pretrained = None
31 | cfg.data.test.test_mode = True
32 |
33 | # build the dataloader
34 | # TODO: support multiple images per gpu (only minor changes are needed)
35 | dataset = build_dataset(cfg.data.test)
36 | data_loader = build_dataloader(
37 | dataset,
38 | samples_per_gpu=1,
39 | workers_per_gpu=cfg.data.workers_per_gpu,
40 | dist=False,
41 | shuffle=False)
42 |
43 | # build the model and load checkpoint
44 | cfg.model.train_cfg = None
45 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
46 | fp16_cfg = cfg.get('fp16', None)
47 | if fp16_cfg is not None:
48 | wrap_fp16_model(model)
49 | load_checkpoint(model, args.checkpoint, map_location='cpu')
50 |
51 | model = MMDataParallel(model, device_ids=[0])
52 |
53 | model.eval()
54 |
55 | # the first several iterations may be very slow so skip them
56 | num_warmup = 5
57 | pure_inf_time = 0
58 | total_iters = 200
59 |
60 |     # benchmark with 200 images and take the average
61 | for i, data in enumerate(data_loader):
62 |
63 | torch.cuda.synchronize()
64 | start_time = time.perf_counter()
65 |
66 | with torch.no_grad():
67 | model(return_loss=False, rescale=True, **data)
68 |
69 | torch.cuda.synchronize()
70 | elapsed = time.perf_counter() - start_time
71 |
72 | if i >= num_warmup:
73 | pure_inf_time += elapsed
74 | if (i + 1) % args.log_interval == 0:
75 | fps = (i + 1 - num_warmup) / pure_inf_time
76 | print(f'Done image [{i + 1:<3}/ {total_iters}], '
77 | f'fps: {fps:.2f} img / s')
78 |
79 | if (i + 1) == total_iters:
80 | fps = (i + 1 - num_warmup) / pure_inf_time
81 | print(f'Overall fps: {fps:.2f} img / s')
82 | break
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
87 |
--------------------------------------------------------------------------------
/segmentation/tools/browse_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 | from pathlib import Path
5 |
6 | import mmcv
7 | import numpy as np
8 | from mmcv import Config
9 |
10 | from mmseg.datasets.builder import build_dataset
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description='Browse a dataset')
15 | parser.add_argument('config', help='train config file path')
16 | parser.add_argument(
17 | '--show-origin',
18 | default=False,
19 | action='store_true',
20 | help='if True, omit all augmentation in pipeline,'
21 | ' show origin image and seg map')
22 | parser.add_argument(
23 | '--skip-type',
24 | type=str,
25 | nargs='+',
26 | default=['DefaultFormatBundle', 'Normalize', 'Collect'],
27 |         help='skip some useless pipeline ops; if `show-origin` is true, '
28 |         'all pipeline ops except `Load` will be skipped')
29 | parser.add_argument(
30 | '--output-dir',
31 | default='./output',
32 | type=str,
33 | help='If there is no display interface, you can save it')
34 | parser.add_argument('--show', default=False, action='store_true')
35 | parser.add_argument(
36 | '--show-interval',
37 | type=int,
38 | default=999,
39 | help='the interval of show (ms)')
40 | parser.add_argument(
41 | '--opacity',
42 | type=float,
43 | default=0.5,
44 | help='the opacity of semantic map')
45 | args = parser.parse_args()
46 | return args
47 |
48 |
49 | def imshow_semantic(img,
50 | seg,
51 | class_names,
52 | palette=None,
53 | win_name='',
54 | show=False,
55 | wait_time=0,
56 | out_file=None,
57 | opacity=0.5):
58 | """Draw `result` over `img`.
59 |
60 | Args:
61 | img (str or Tensor): The image to be displayed.
62 | seg (Tensor): The semantic segmentation results to draw over
63 | `img`.
64 | class_names (list[str]): Names of each classes.
65 | palette (list[list[int]]] | np.ndarray | None): The palette of
66 | segmentation map. If None is given, random palette will be
67 | generated. Default: None
68 | win_name (str): The window name.
69 | wait_time (int): Value of waitKey param.
70 | Default: 0.
71 | show (bool): Whether to show the image.
72 | Default: False.
73 | out_file (str or None): The filename to write the image.
74 | Default: None.
75 | opacity(float): Opacity of painted segmentation map.
76 | Default 0.5.
77 | Must be in (0, 1] range.
78 | Returns:
79 | img (Tensor): Only if not `show` or `out_file`
80 | """
81 | img = mmcv.imread(img)
82 | img = img.copy()
83 | if palette is None:
84 | palette = np.random.randint(0, 255, size=(len(class_names), 3))
85 | palette = np.array(palette)
86 | assert palette.shape[0] == len(class_names)
87 | assert palette.shape[1] == 3
88 | assert len(palette.shape) == 2
89 | assert 0 < opacity <= 1.0
90 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
91 | for label, color in enumerate(palette):
92 | color_seg[seg == label, :] = color
93 | # convert to BGR
94 | color_seg = color_seg[..., ::-1]
95 |
96 | img = img * (1 - opacity) + color_seg * opacity
97 | img = img.astype(np.uint8)
98 | # if out_file specified, do not show image in window
99 | if out_file is not None:
100 | show = False
101 |
102 | if show:
103 | mmcv.imshow(img, win_name, wait_time)
104 | if out_file is not None:
105 | mmcv.imwrite(img, out_file)
106 |
107 | if not (show or out_file):
108 | warnings.warn('show==False and out_file is not specified, only '
109 | 'result image will be returned')
110 | return img
111 |
112 |
113 | def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
114 | if show_origin is True:
115 | # only keep pipeline of Loading data and ann
116 | _data_cfg['pipeline'] = [
117 | x for x in _data_cfg.pipeline if 'Load' in x['type']
118 | ]
119 | else:
120 | _data_cfg['pipeline'] = [
121 | x for x in _data_cfg.pipeline if x['type'] not in skip_type
122 | ]
123 |
124 |
125 | def retrieve_data_cfg(config_path, skip_type, show_origin=False):
126 | cfg = Config.fromfile(config_path)
127 | train_data_cfg = cfg.data.train
128 | if isinstance(train_data_cfg, list):
129 | for _data_cfg in train_data_cfg:
130 | if 'pipeline' in _data_cfg:
131 | _retrieve_data_cfg(_data_cfg, skip_type, show_origin)
132 | elif 'dataset' in _data_cfg:
133 | _retrieve_data_cfg(_data_cfg['dataset'], skip_type,
134 | show_origin)
135 | else:
136 | raise ValueError
137 | elif 'dataset' in train_data_cfg:
138 | _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
139 | else:
140 | _retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
141 | return cfg
142 |
143 |
144 | def main():
145 | args = parse_args()
146 | cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
147 | dataset = build_dataset(cfg.data.train)
148 | progress_bar = mmcv.ProgressBar(len(dataset))
149 | for item in dataset:
150 | filename = os.path.join(args.output_dir,
151 | Path(item['filename']).name
152 | ) if args.output_dir is not None else None
153 | imshow_semantic(
154 | item['img'],
155 | item['gt_semantic_seg'],
156 | dataset.CLASSES,
157 | dataset.PALETTE,
158 | show=args.show,
159 | wait_time=args.show_interval,
160 | out_file=filename,
161 | opacity=args.opacity,
162 | )
163 | progress_bar.update()
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
168 |
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/chase_db1.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | import tempfile
6 | import zipfile
7 |
8 | import mmcv
9 |
10 | CHASE_DB1_LEN = 28 * 3
11 | TRAINING_LEN = 60
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(
16 | description='Convert CHASE_DB1 dataset to mmsegmentation format')
17 | parser.add_argument('dataset_path', help='path of CHASEDB1.zip')
18 | parser.add_argument('--tmp_dir', help='path of the temporary directory')
19 | parser.add_argument('-o', '--out_dir', help='output path')
20 | args = parser.parse_args()
21 | return args
22 |
23 |
24 | def main():
25 | args = parse_args()
26 | dataset_path = args.dataset_path
27 | if args.out_dir is None:
28 | out_dir = osp.join('data', 'CHASE_DB1')
29 | else:
30 | out_dir = args.out_dir
31 |
32 | print('Making directories...')
33 | mmcv.mkdir_or_exist(out_dir)
34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
40 |
41 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
42 | print('Extracting CHASEDB1.zip...')
43 | zip_file = zipfile.ZipFile(dataset_path)
44 | zip_file.extractall(tmp_dir)
45 |
46 | print('Generating training dataset...')
47 |
48 | assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \
49 | 'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN)
50 |
51 | for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
52 | img = mmcv.imread(osp.join(tmp_dir, img_name))
53 | if osp.splitext(img_name)[1] == '.jpg':
54 | mmcv.imwrite(
55 | img,
56 | osp.join(out_dir, 'images', 'training',
57 | osp.splitext(img_name)[0] + '.png'))
58 | else:
59 | # The annotation img should be divided by 128, because some of
60 | # the annotation imgs are not standard. We should set a
61 | # threshold to convert the nonstandard annotation imgs. The
62 | # value divided by 128 is equivalent to '1 if value >= 128
63 | # else 0'
64 | mmcv.imwrite(
65 | img[:, :, 0] // 128,
66 | osp.join(out_dir, 'annotations', 'training',
67 | osp.splitext(img_name)[0] + '.png'))
68 |
69 | for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
70 | img = mmcv.imread(osp.join(tmp_dir, img_name))
71 | if osp.splitext(img_name)[1] == '.jpg':
72 | mmcv.imwrite(
73 | img,
74 | osp.join(out_dir, 'images', 'validation',
75 | osp.splitext(img_name)[0] + '.png'))
76 | else:
77 | mmcv.imwrite(
78 | img[:, :, 0] // 128,
79 | osp.join(out_dir, 'annotations', 'validation',
80 | osp.splitext(img_name)[0] + '.png'))
81 |
82 | print('Removing the temporary files...')
83 |
84 | print('Done!')
85 |
86 |
87 | if __name__ == '__main__':
88 | main()
89 |
--------------------------------------------------------------------------------
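The `// 128` in the converter above is just a threshold: since the annotation pixels are 8-bit, integer division by 128 is equivalent to `1 if value >= 128 else 0`. A tiny check:

```python
import numpy as np

vals = np.array([0, 64, 127, 128, 200, 255], dtype=np.uint8)
print(vals // 128)                     # [0 0 0 1 1 1]
print((vals >= 128).astype(np.uint8))  # [0 0 0 1 1 1]
```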
/segmentation/tools/convert_datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 |
5 | import mmcv
6 | from cityscapesscripts.preparation.json2labelImg import json2labelImg
7 |
8 |
9 | def convert_json_to_label(json_file):
10 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png')
11 | json2labelImg(json_file, label_file, 'trainIds')
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(
16 | description='Convert Cityscapes annotations to TrainIds')
17 | parser.add_argument('cityscapes_path', help='cityscapes data path')
18 | parser.add_argument('--gt-dir', default='gtFine', type=str)
19 | parser.add_argument('-o', '--out-dir', help='output path')
20 | parser.add_argument(
21 | '--nproc', default=1, type=int, help='number of process')
22 | args = parser.parse_args()
23 | return args
24 |
25 |
26 | def main():
27 | args = parse_args()
28 | cityscapes_path = args.cityscapes_path
29 | out_dir = args.out_dir if args.out_dir else cityscapes_path
30 | mmcv.mkdir_or_exist(out_dir)
31 |
32 | gt_dir = osp.join(cityscapes_path, args.gt_dir)
33 |
34 | poly_files = []
35 | for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True):
36 | poly_file = osp.join(gt_dir, poly)
37 | poly_files.append(poly_file)
38 | if args.nproc > 1:
39 | mmcv.track_parallel_progress(convert_json_to_label, poly_files,
40 | args.nproc)
41 | else:
42 | mmcv.track_progress(convert_json_to_label, poly_files)
43 |
44 | split_names = ['train', 'val', 'test']
45 |
46 | for split in split_names:
47 | filenames = []
48 | for poly in mmcv.scandir(
49 | osp.join(gt_dir, split), '_polygons.json', recursive=True):
50 | filenames.append(poly.replace('_gtFine_polygons.json', ''))
51 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f:
52 | f.writelines(f + '\n' for f in filenames)
53 |
54 |
55 | if __name__ == '__main__':
56 | main()
57 |
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/coco_stuff10k.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import shutil
4 | from functools import partial
5 |
6 | import mmcv
7 | import numpy as np
8 | from PIL import Image
9 | from scipy.io import loadmat
10 |
11 | COCO_LEN = 10000
12 |
13 | clsID_to_trID = {
14 | 0: 0,
15 | 1: 1,
16 | 2: 2,
17 | 3: 3,
18 | 4: 4,
19 | 5: 5,
20 | 6: 6,
21 | 7: 7,
22 | 8: 8,
23 | 9: 9,
24 | 10: 10,
25 | 11: 11,
26 | 13: 12,
27 | 14: 13,
28 | 15: 14,
29 | 16: 15,
30 | 17: 16,
31 | 18: 17,
32 | 19: 18,
33 | 20: 19,
34 | 21: 20,
35 | 22: 21,
36 | 23: 22,
37 | 24: 23,
38 | 25: 24,
39 | 27: 25,
40 | 28: 26,
41 | 31: 27,
42 | 32: 28,
43 | 33: 29,
44 | 34: 30,
45 | 35: 31,
46 | 36: 32,
47 | 37: 33,
48 | 38: 34,
49 | 39: 35,
50 | 40: 36,
51 | 41: 37,
52 | 42: 38,
53 | 43: 39,
54 | 44: 40,
55 | 46: 41,
56 | 47: 42,
57 | 48: 43,
58 | 49: 44,
59 | 50: 45,
60 | 51: 46,
61 | 52: 47,
62 | 53: 48,
63 | 54: 49,
64 | 55: 50,
65 | 56: 51,
66 | 57: 52,
67 | 58: 53,
68 | 59: 54,
69 | 60: 55,
70 | 61: 56,
71 | 62: 57,
72 | 63: 58,
73 | 64: 59,
74 | 65: 60,
75 | 67: 61,
76 | 70: 62,
77 | 72: 63,
78 | 73: 64,
79 | 74: 65,
80 | 75: 66,
81 | 76: 67,
82 | 77: 68,
83 | 78: 69,
84 | 79: 70,
85 | 80: 71,
86 | 81: 72,
87 | 82: 73,
88 | 84: 74,
89 | 85: 75,
90 | 86: 76,
91 | 87: 77,
92 | 88: 78,
93 | 89: 79,
94 | 90: 80,
95 | 92: 81,
96 | 93: 82,
97 | 94: 83,
98 | 95: 84,
99 | 96: 85,
100 | 97: 86,
101 | 98: 87,
102 | 99: 88,
103 | 100: 89,
104 | 101: 90,
105 | 102: 91,
106 | 103: 92,
107 | 104: 93,
108 | 105: 94,
109 | 106: 95,
110 | 107: 96,
111 | 108: 97,
112 | 109: 98,
113 | 110: 99,
114 | 111: 100,
115 | 112: 101,
116 | 113: 102,
117 | 114: 103,
118 | 115: 104,
119 | 116: 105,
120 | 117: 106,
121 | 118: 107,
122 | 119: 108,
123 | 120: 109,
124 | 121: 110,
125 | 122: 111,
126 | 123: 112,
127 | 124: 113,
128 | 125: 114,
129 | 126: 115,
130 | 127: 116,
131 | 128: 117,
132 | 129: 118,
133 | 130: 119,
134 | 131: 120,
135 | 132: 121,
136 | 133: 122,
137 | 134: 123,
138 | 135: 124,
139 | 136: 125,
140 | 137: 126,
141 | 138: 127,
142 | 139: 128,
143 | 140: 129,
144 | 141: 130,
145 | 142: 131,
146 | 143: 132,
147 | 144: 133,
148 | 145: 134,
149 | 146: 135,
150 | 147: 136,
151 | 148: 137,
152 | 149: 138,
153 | 150: 139,
154 | 151: 140,
155 | 152: 141,
156 | 153: 142,
157 | 154: 143,
158 | 155: 144,
159 | 156: 145,
160 | 157: 146,
161 | 158: 147,
162 | 159: 148,
163 | 160: 149,
164 | 161: 150,
165 | 162: 151,
166 | 163: 152,
167 | 164: 153,
168 | 165: 154,
169 | 166: 155,
170 | 167: 156,
171 | 168: 157,
172 | 169: 158,
173 | 170: 159,
174 | 171: 160,
175 | 172: 161,
176 | 173: 162,
177 | 174: 163,
178 | 175: 164,
179 | 176: 165,
180 | 177: 166,
181 | 178: 167,
182 | 179: 168,
183 | 180: 169,
184 | 181: 170,
185 | 182: 171
186 | }
187 |
188 |
189 | def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir,
190 | out_mask_dir, is_train):
191 | imgpath, maskpath = tuple_path
192 | shutil.copyfile(
193 | osp.join(in_img_dir, imgpath),
194 | osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join(
195 | out_img_dir, 'test2014', imgpath))
196 | annotate = loadmat(osp.join(in_ann_dir, maskpath))
197 | mask = annotate['S'].astype(np.uint8)
198 | mask_copy = mask.copy()
199 | for clsID, trID in clsID_to_trID.items():
200 | mask_copy[mask == clsID] = trID
201 | seg_filename = osp.join(out_mask_dir, 'train2014',
202 | maskpath.split('.')[0] +
203 | '_labelTrainIds.png') if is_train else osp.join(
204 | out_mask_dir, 'test2014',
205 | maskpath.split('.')[0] + '_labelTrainIds.png')
206 | Image.fromarray(mask_copy).save(seg_filename, 'PNG')
207 |
208 |
209 | def generate_coco_list(folder):
210 | train_list = osp.join(folder, 'imageLists', 'train.txt')
211 | test_list = osp.join(folder, 'imageLists', 'test.txt')
212 | train_paths = []
213 | test_paths = []
214 |
215 | with open(train_list) as f:
216 | for filename in f:
217 | basename = filename.strip()
218 | imgpath = basename + '.jpg'
219 | maskpath = basename + '.mat'
220 | train_paths.append((imgpath, maskpath))
221 |
222 | with open(test_list) as f:
223 | for filename in f:
224 | basename = filename.strip()
225 | imgpath = basename + '.jpg'
226 | maskpath = basename + '.mat'
227 | test_paths.append((imgpath, maskpath))
228 |
229 | return train_paths, test_paths
230 |
231 |
232 | def parse_args():
233 | parser = argparse.ArgumentParser(
234 | description=\
235 | 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa
236 | parser.add_argument('coco_path', help='coco stuff path')
237 | parser.add_argument('-o', '--out_dir', help='output path')
238 | parser.add_argument(
239 |         '--nproc', default=16, type=int, help='number of processes')
240 | args = parser.parse_args()
241 | return args
242 |
243 |
244 | def main():
245 | args = parse_args()
246 | coco_path = args.coco_path
247 | nproc = args.nproc
248 |
249 | out_dir = args.out_dir or coco_path
250 | out_img_dir = osp.join(out_dir, 'images')
251 | out_mask_dir = osp.join(out_dir, 'annotations')
252 |
253 | mmcv.mkdir_or_exist(osp.join(out_img_dir, 'train2014'))
254 | mmcv.mkdir_or_exist(osp.join(out_img_dir, 'test2014'))
255 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2014'))
256 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'test2014'))
257 |
258 | train_list, test_list = generate_coco_list(coco_path)
259 | assert (len(train_list) +
260 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
261 | len(train_list), len(test_list))
262 |
263 | if args.nproc > 1:
264 | mmcv.track_parallel_progress(
265 | partial(
266 | convert_to_trainID,
267 | in_img_dir=osp.join(coco_path, 'images'),
268 | in_ann_dir=osp.join(coco_path, 'annotations'),
269 | out_img_dir=out_img_dir,
270 | out_mask_dir=out_mask_dir,
271 | is_train=True),
272 | train_list,
273 | nproc=nproc)
274 | mmcv.track_parallel_progress(
275 | partial(
276 | convert_to_trainID,
277 | in_img_dir=osp.join(coco_path, 'images'),
278 | in_ann_dir=osp.join(coco_path, 'annotations'),
279 | out_img_dir=out_img_dir,
280 | out_mask_dir=out_mask_dir,
281 | is_train=False),
282 | test_list,
283 | nproc=nproc)
284 | else:
285 | mmcv.track_progress(
286 | partial(
287 | convert_to_trainID,
288 | in_img_dir=osp.join(coco_path, 'images'),
289 | in_ann_dir=osp.join(coco_path, 'annotations'),
290 | out_img_dir=out_img_dir,
291 | out_mask_dir=out_mask_dir,
292 | is_train=True), train_list)
293 | mmcv.track_progress(
294 | partial(
295 | convert_to_trainID,
296 | in_img_dir=osp.join(coco_path, 'images'),
297 | in_ann_dir=osp.join(coco_path, 'annotations'),
298 | out_img_dir=out_img_dir,
299 | out_mask_dir=out_mask_dir,
300 | is_train=False), test_list)
301 |
302 | print('Done!')
303 |
304 |
305 | if __name__ == '__main__':
306 | main()
307 |
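A minimal sketch (not part of the repository; toy ids only, taken from the clsID_to_trID table above) of the relabeling idiom used by convert_to_trainID: pixels are compared against the untouched original mask while the copy is rewritten, so an already-remapped value can never be remapped a second time.

import numpy as np

toy_mapping = {53: 48, 65: 60, 67: 61}                 # illustrative subset of clsID_to_trID
mask = np.array([[53, 67], [65, 67]], dtype=np.uint8)  # fake annotation
relabeled = mask.copy()
for cls_id, train_id in toy_mapping.items():
    relabeled[mask == cls_id] = train_id               # compare against the original, write into the copy
print(relabeled)                                        # 53 -> 48, 65 -> 60, 67 -> 61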
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/coco_stuff164k.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path as osp
3 | import shutil
4 | from functools import partial
5 | from glob import glob
6 |
7 | import mmcv
8 | import numpy as np
9 | from PIL import Image
10 |
11 | COCO_LEN = 123287
12 |
13 | clsID_to_trID = {
14 | 0: 0,
15 | 1: 1,
16 | 2: 2,
17 | 3: 3,
18 | 4: 4,
19 | 5: 5,
20 | 6: 6,
21 | 7: 7,
22 | 8: 8,
23 | 9: 9,
24 | 10: 10,
25 | 12: 11,
26 | 13: 12,
27 | 14: 13,
28 | 15: 14,
29 | 16: 15,
30 | 17: 16,
31 | 18: 17,
32 | 19: 18,
33 | 20: 19,
34 | 21: 20,
35 | 22: 21,
36 | 23: 22,
37 | 24: 23,
38 | 26: 24,
39 | 27: 25,
40 | 30: 26,
41 | 31: 27,
42 | 32: 28,
43 | 33: 29,
44 | 34: 30,
45 | 35: 31,
46 | 36: 32,
47 | 37: 33,
48 | 38: 34,
49 | 39: 35,
50 | 40: 36,
51 | 41: 37,
52 | 42: 38,
53 | 43: 39,
54 | 45: 40,
55 | 46: 41,
56 | 47: 42,
57 | 48: 43,
58 | 49: 44,
59 | 50: 45,
60 | 51: 46,
61 | 52: 47,
62 | 53: 48,
63 | 54: 49,
64 | 55: 50,
65 | 56: 51,
66 | 57: 52,
67 | 58: 53,
68 | 59: 54,
69 | 60: 55,
70 | 61: 56,
71 | 62: 57,
72 | 63: 58,
73 | 64: 59,
74 | 66: 60,
75 | 69: 61,
76 | 71: 62,
77 | 72: 63,
78 | 73: 64,
79 | 74: 65,
80 | 75: 66,
81 | 76: 67,
82 | 77: 68,
83 | 78: 69,
84 | 79: 70,
85 | 80: 71,
86 | 81: 72,
87 | 83: 73,
88 | 84: 74,
89 | 85: 75,
90 | 86: 76,
91 | 87: 77,
92 | 88: 78,
93 | 89: 79,
94 | 91: 80,
95 | 92: 81,
96 | 93: 82,
97 | 94: 83,
98 | 95: 84,
99 | 96: 85,
100 | 97: 86,
101 | 98: 87,
102 | 99: 88,
103 | 100: 89,
104 | 101: 90,
105 | 102: 91,
106 | 103: 92,
107 | 104: 93,
108 | 105: 94,
109 | 106: 95,
110 | 107: 96,
111 | 108: 97,
112 | 109: 98,
113 | 110: 99,
114 | 111: 100,
115 | 112: 101,
116 | 113: 102,
117 | 114: 103,
118 | 115: 104,
119 | 116: 105,
120 | 117: 106,
121 | 118: 107,
122 | 119: 108,
123 | 120: 109,
124 | 121: 110,
125 | 122: 111,
126 | 123: 112,
127 | 124: 113,
128 | 125: 114,
129 | 126: 115,
130 | 127: 116,
131 | 128: 117,
132 | 129: 118,
133 | 130: 119,
134 | 131: 120,
135 | 132: 121,
136 | 133: 122,
137 | 134: 123,
138 | 135: 124,
139 | 136: 125,
140 | 137: 126,
141 | 138: 127,
142 | 139: 128,
143 | 140: 129,
144 | 141: 130,
145 | 142: 131,
146 | 143: 132,
147 | 144: 133,
148 | 145: 134,
149 | 146: 135,
150 | 147: 136,
151 | 148: 137,
152 | 149: 138,
153 | 150: 139,
154 | 151: 140,
155 | 152: 141,
156 | 153: 142,
157 | 154: 143,
158 | 155: 144,
159 | 156: 145,
160 | 157: 146,
161 | 158: 147,
162 | 159: 148,
163 | 160: 149,
164 | 161: 150,
165 | 162: 151,
166 | 163: 152,
167 | 164: 153,
168 | 165: 154,
169 | 166: 155,
170 | 167: 156,
171 | 168: 157,
172 | 169: 158,
173 | 170: 159,
174 | 171: 160,
175 | 172: 161,
176 | 173: 162,
177 | 174: 163,
178 | 175: 164,
179 | 176: 165,
180 | 177: 166,
181 | 178: 167,
182 | 179: 168,
183 | 180: 169,
184 | 181: 170,
185 | 255: 255
186 | }
187 |
188 |
189 | def convert_to_trainID(maskpath, out_mask_dir, is_train):
190 | mask = np.array(Image.open(maskpath))
191 | mask_copy = mask.copy()
192 | for clsID, trID in clsID_to_trID.items():
193 | mask_copy[mask == clsID] = trID
194 | seg_filename = osp.join(
195 | out_mask_dir, 'train2017',
196 | osp.basename(maskpath).split('.')[0] +
197 | '_labelTrainIds.png') if is_train else osp.join(
198 | out_mask_dir, 'val2017',
199 | osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png')
200 | Image.fromarray(mask_copy).save(seg_filename, 'PNG')
201 |
202 |
203 | def parse_args():
204 | parser = argparse.ArgumentParser(
205 | description=\
206 | 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa
207 | parser.add_argument('coco_path', help='coco stuff path')
208 | parser.add_argument('-o', '--out_dir', help='output path')
209 | parser.add_argument(
210 |         '--nproc', default=16, type=int, help='number of processes')
211 | args = parser.parse_args()
212 | return args
213 |
214 |
215 | def main():
216 | args = parse_args()
217 | coco_path = args.coco_path
218 | nproc = args.nproc
219 |
220 | out_dir = args.out_dir or coco_path
221 | out_img_dir = osp.join(out_dir, 'images')
222 | out_mask_dir = osp.join(out_dir, 'annotations')
223 |
224 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017'))
225 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017'))
226 |
227 | if out_dir != coco_path:
228 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir)
229 |
230 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png'))
231 | train_list = [file for file in train_list if '_labelTrainIds' not in file]
232 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png'))
233 | test_list = [file for file in test_list if '_labelTrainIds' not in file]
234 | assert (len(train_list) +
235 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format(
236 | len(train_list), len(test_list))
237 |
238 | if args.nproc > 1:
239 | mmcv.track_parallel_progress(
240 | partial(
241 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
242 | train_list,
243 | nproc=nproc)
244 | mmcv.track_parallel_progress(
245 | partial(
246 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
247 | test_list,
248 | nproc=nproc)
249 | else:
250 | mmcv.track_progress(
251 | partial(
252 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True),
253 | train_list)
254 | mmcv.track_progress(
255 | partial(
256 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False),
257 | test_list)
258 |
259 | print('Done!')
260 |
261 |
262 | if __name__ == '__main__':
263 | main()
264 |
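A small, self-contained sketch (independent of mmcv; convert below is a stand-in, not the real converter) of why these scripts wrap convert_to_trainID in functools.partial: mmcv.track_progress and mmcv.track_parallel_progress map a one-argument callable over the task list, so every other argument is bound up front and only the mask path varies per task.

from functools import partial

def convert(mask_path, out_mask_dir, is_train):
    split = 'train2017' if is_train else 'val2017'
    return f'{out_mask_dir}/{split}/{mask_path}'

worker = partial(convert, out_mask_dir='annotations', is_train=True)
print([worker(p) for p in ['000000009.png', '000000025.png']])
# ['annotations/train2017/000000009.png', 'annotations/train2017/000000025.png']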
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/drive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | import tempfile
6 | import zipfile
7 |
8 | import cv2
9 | import mmcv
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(
14 | description='Convert DRIVE dataset to mmsegmentation format')
15 | parser.add_argument(
16 | 'training_path', help='the training part of DRIVE dataset')
17 | parser.add_argument(
18 | 'testing_path', help='the testing part of DRIVE dataset')
19 | parser.add_argument('--tmp_dir', help='path of the temporary directory')
20 | parser.add_argument('-o', '--out_dir', help='output path')
21 | args = parser.parse_args()
22 | return args
23 |
24 |
25 | def main():
26 | args = parse_args()
27 | training_path = args.training_path
28 | testing_path = args.testing_path
29 | if args.out_dir is None:
30 | out_dir = osp.join('data', 'DRIVE')
31 | else:
32 | out_dir = args.out_dir
33 |
34 | print('Making directories...')
35 | mmcv.mkdir_or_exist(out_dir)
36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
40 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
41 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
42 |
43 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
44 | print('Extracting training.zip...')
45 | zip_file = zipfile.ZipFile(training_path)
46 | zip_file.extractall(tmp_dir)
47 |
48 | print('Generating training dataset...')
49 | now_dir = osp.join(tmp_dir, 'training', 'images')
50 | for img_name in os.listdir(now_dir):
51 | img = mmcv.imread(osp.join(now_dir, img_name))
52 | mmcv.imwrite(
53 | img,
54 | osp.join(
55 | out_dir, 'images', 'training',
56 | osp.splitext(img_name)[0].replace('_training', '') +
57 | '.png'))
58 |
59 | now_dir = osp.join(tmp_dir, 'training', '1st_manual')
60 | for img_name in os.listdir(now_dir):
61 | cap = cv2.VideoCapture(osp.join(now_dir, img_name))
62 | ret, img = cap.read()
63 | mmcv.imwrite(
64 | img[:, :, 0] // 128,
65 | osp.join(out_dir, 'annotations', 'training',
66 | osp.splitext(img_name)[0] + '.png'))
67 |
68 | print('Extracting test.zip...')
69 | zip_file = zipfile.ZipFile(testing_path)
70 | zip_file.extractall(tmp_dir)
71 |
72 | print('Generating validation dataset...')
73 | now_dir = osp.join(tmp_dir, 'test', 'images')
74 | for img_name in os.listdir(now_dir):
75 | img = mmcv.imread(osp.join(now_dir, img_name))
76 | mmcv.imwrite(
77 | img,
78 | osp.join(
79 | out_dir, 'images', 'validation',
80 | osp.splitext(img_name)[0].replace('_test', '') + '.png'))
81 |
82 | now_dir = osp.join(tmp_dir, 'test', '1st_manual')
83 | if osp.exists(now_dir):
84 | for img_name in os.listdir(now_dir):
85 | cap = cv2.VideoCapture(osp.join(now_dir, img_name))
86 | ret, img = cap.read()
87 | # The annotation img should be divided by 128, because some of
88 | # the annotation imgs are not standard. We should set a
89 | # threshold to convert the nonstandard annotation imgs. The
90 | # value divided by 128 is equivalent to '1 if value >= 128
91 | # else 0'
92 | mmcv.imwrite(
93 | img[:, :, 0] // 128,
94 | osp.join(out_dir, 'annotations', 'validation',
95 | osp.splitext(img_name)[0] + '.png'))
96 |
97 | now_dir = osp.join(tmp_dir, 'test', '2nd_manual')
98 | if osp.exists(now_dir):
99 | for img_name in os.listdir(now_dir):
100 | cap = cv2.VideoCapture(osp.join(now_dir, img_name))
101 | ret, img = cap.read()
102 | mmcv.imwrite(
103 | img[:, :, 0] // 128,
104 | osp.join(out_dir, 'annotations', 'validation',
105 | osp.splitext(img_name)[0] + '.png'))
106 |
107 | print('Removing the temporary files...')
108 |
109 | print('Done!')
110 |
111 |
112 | if __name__ == '__main__':
113 | main()
114 |
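A quick check (assuming single-channel uint8 annotations, as in the comment above) that integer division by 128 is the same as thresholding at 128, which is why the converter stores img[:, :, 0] // 128 as the label map.

import numpy as np

values = np.array([0, 64, 127, 128, 200, 255], dtype=np.uint8)
assert np.array_equal(values // 128, (values >= 128).astype(np.uint8))
print(values // 128)  # [0 0 0 1 1 1]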
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/hrf.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | import tempfile
6 | import zipfile
7 |
8 | import mmcv
9 |
10 | HRF_LEN = 15
11 | TRAINING_LEN = 5
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser(
16 | description='Convert HRF dataset to mmsegmentation format')
17 | parser.add_argument('healthy_path', help='the path of healthy.zip')
18 | parser.add_argument(
19 | 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip')
20 | parser.add_argument('glaucoma_path', help='the path of glaucoma.zip')
21 | parser.add_argument(
22 | 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip')
23 | parser.add_argument(
24 | 'diabetic_retinopathy_path',
25 | help='the path of diabetic_retinopathy.zip')
26 | parser.add_argument(
27 | 'diabetic_retinopathy_manualsegm_path',
28 | help='the path of diabetic_retinopathy_manualsegm.zip')
29 | parser.add_argument('--tmp_dir', help='path of the temporary directory')
30 | parser.add_argument('-o', '--out_dir', help='output path')
31 | args = parser.parse_args()
32 | return args
33 |
34 |
35 | def main():
36 | args = parse_args()
37 | images_path = [
38 | args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path
39 | ]
40 | annotations_path = [
41 | args.healthy_manualsegm_path, args.glaucoma_manualsegm_path,
42 | args.diabetic_retinopathy_manualsegm_path
43 | ]
44 | if args.out_dir is None:
45 | out_dir = osp.join('data', 'HRF')
46 | else:
47 | out_dir = args.out_dir
48 |
49 | print('Making directories...')
50 | mmcv.mkdir_or_exist(out_dir)
51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
52 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
53 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
54 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
55 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
56 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
57 |
58 | print('Generating images...')
59 | for now_path in images_path:
60 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
61 | zip_file = zipfile.ZipFile(now_path)
62 | zip_file.extractall(tmp_dir)
63 |
64 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \
65 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN)
66 |
67 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
68 | img = mmcv.imread(osp.join(tmp_dir, filename))
69 | mmcv.imwrite(
70 | img,
71 | osp.join(out_dir, 'images', 'training',
72 | osp.splitext(filename)[0] + '.png'))
73 | for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
74 | img = mmcv.imread(osp.join(tmp_dir, filename))
75 | mmcv.imwrite(
76 | img,
77 | osp.join(out_dir, 'images', 'validation',
78 | osp.splitext(filename)[0] + '.png'))
79 |
80 | print('Generating annotations...')
81 | for now_path in annotations_path:
82 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
83 | zip_file = zipfile.ZipFile(now_path)
84 | zip_file.extractall(tmp_dir)
85 |
86 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \
87 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN)
88 |
89 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]:
90 | img = mmcv.imread(osp.join(tmp_dir, filename))
91 | # The annotation img should be divided by 128, because some of
92 | # the annotation imgs are not standard. We should set a
93 | # threshold to convert the nonstandard annotation imgs. The
94 | # value divided by 128 is equivalent to '1 if value >= 128
95 | # else 0'
96 | mmcv.imwrite(
97 | img[:, :, 0] // 128,
98 | osp.join(out_dir, 'annotations', 'training',
99 | osp.splitext(filename)[0] + '.png'))
100 | for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]:
101 | img = mmcv.imread(osp.join(tmp_dir, filename))
102 | mmcv.imwrite(
103 | img[:, :, 0] // 128,
104 | osp.join(out_dir, 'annotations', 'validation',
105 | osp.splitext(filename)[0] + '.png'))
106 |
107 | print('Done!')
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
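A toy illustration of the split rule used above: the contents of each zip are ordered by filename and the first TRAINING_LEN entries go to training, the remainder to validation. The filenames here are made up; only the counts (HRF_LEN = 15, TRAINING_LEN = 5) come from the script.

TRAINING_LEN = 5
filenames = sorted(f'{i:02d}_dr.jpg' for i in range(1, 16))   # 15 fake files
train, val = filenames[:TRAINING_LEN], filenames[TRAINING_LEN:]
print(len(train), len(val))  # 5 10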
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/pascal_context.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 | from functools import partial
5 |
6 | import mmcv
7 | import numpy as np
8 | from detail import Detail
9 | from PIL import Image
10 |
11 | _mapping = np.sort(
12 | np.array([
13 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284,
14 | 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59,
15 | 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355,
16 | 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115
17 | ]))
18 | _key = np.array(range(len(_mapping))).astype('uint8')
19 |
20 |
21 | def generate_labels(img_id, detail, out_dir):
22 |
23 | def _class_to_index(mask, _mapping, _key):
24 | # assert the values
25 | values = np.unique(mask)
26 | for i in range(len(values)):
27 | assert (values[i] in _mapping)
28 | index = np.digitize(mask.ravel(), _mapping, right=True)
29 | return _key[index].reshape(mask.shape)
30 |
31 | mask = Image.fromarray(
32 | _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key))
33 | filename = img_id['file_name']
34 | mask.save(osp.join(out_dir, filename.replace('jpg', 'png')))
35 | return osp.splitext(osp.basename(filename))[0]
36 |
37 |
38 | def parse_args():
39 | parser = argparse.ArgumentParser(
40 | description='Convert PASCAL VOC annotations to mmsegmentation format')
41 | parser.add_argument('devkit_path', help='pascal voc devkit path')
42 |     parser.add_argument('json_path', help='annotation json filepath')
43 | parser.add_argument('-o', '--out_dir', help='output path')
44 | args = parser.parse_args()
45 | return args
46 |
47 |
48 | def main():
49 | args = parse_args()
50 | devkit_path = args.devkit_path
51 | if args.out_dir is None:
52 | out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext')
53 | else:
54 | out_dir = args.out_dir
55 | json_path = args.json_path
56 | mmcv.mkdir_or_exist(out_dir)
57 | img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages')
58 |
59 | train_detail = Detail(json_path, img_dir, 'train')
60 | train_ids = train_detail.getImgs()
61 |
62 | val_detail = Detail(json_path, img_dir, 'val')
63 | val_ids = val_detail.getImgs()
64 |
65 | mmcv.mkdir_or_exist(
66 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext'))
67 |
68 | train_list = mmcv.track_progress(
69 | partial(generate_labels, detail=train_detail, out_dir=out_dir),
70 | train_ids)
71 | with open(
72 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
73 | 'train.txt'), 'w') as f:
74 | f.writelines(line + '\n' for line in sorted(train_list))
75 |
76 | val_list = mmcv.track_progress(
77 | partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids)
78 | with open(
79 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext',
80 | 'val.txt'), 'w') as f:
81 | f.writelines(line + '\n' for line in sorted(val_list))
82 |
83 | print('Done!')
84 |
85 |
86 | if __name__ == '__main__':
87 | main()
88 |
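A minimal illustration (toy label ids, not the real Pascal Context mapping) of how _class_to_index works: because _mapping is sorted, np.digitize with right=True returns, for every pixel value that appears in _mapping, the position of that value in the sorted array, i.e. its contiguous train id.

import numpy as np

mapping = np.sort(np.array([0, 2, 9, 18]))
key = np.arange(len(mapping)).astype('uint8')
mask = np.array([[0, 9], [18, 2]])
index = np.digitize(mask.ravel(), mapping, right=True)
print(key[index].reshape(mask.shape))  # [[0 2] [3 1]]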
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/stare.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import gzip
4 | import os
5 | import os.path as osp
6 | import tarfile
7 | import tempfile
8 |
9 | import mmcv
10 |
11 | STARE_LEN = 20
12 | TRAINING_LEN = 10
13 |
14 |
15 | def un_gz(src, dst):
16 | g_file = gzip.GzipFile(src)
17 | with open(dst, 'wb+') as f:
18 | f.write(g_file.read())
19 | g_file.close()
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(
24 | description='Convert STARE dataset to mmsegmentation format')
25 | parser.add_argument('image_path', help='the path of stare-images.tar')
26 | parser.add_argument('labels_ah', help='the path of labels-ah.tar')
27 | parser.add_argument('labels_vk', help='the path of labels-vk.tar')
28 | parser.add_argument('--tmp_dir', help='path of the temporary directory')
29 | parser.add_argument('-o', '--out_dir', help='output path')
30 | args = parser.parse_args()
31 | return args
32 |
33 |
34 | def main():
35 | args = parse_args()
36 | image_path = args.image_path
37 | labels_ah = args.labels_ah
38 | labels_vk = args.labels_vk
39 | if args.out_dir is None:
40 | out_dir = osp.join('data', 'STARE')
41 | else:
42 | out_dir = args.out_dir
43 |
44 | print('Making directories...')
45 | mmcv.mkdir_or_exist(out_dir)
46 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images'))
47 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training'))
48 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation'))
49 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations'))
50 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training'))
51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation'))
52 |
53 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
54 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz'))
55 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files'))
56 |
57 | print('Extracting stare-images.tar...')
58 | with tarfile.open(image_path) as f:
59 | f.extractall(osp.join(tmp_dir, 'gz'))
60 |
61 | for filename in os.listdir(osp.join(tmp_dir, 'gz')):
62 | un_gz(
63 | osp.join(tmp_dir, 'gz', filename),
64 | osp.join(tmp_dir, 'files',
65 | osp.splitext(filename)[0]))
66 |
67 | now_dir = osp.join(tmp_dir, 'files')
68 |
69 | assert len(os.listdir(now_dir)) == STARE_LEN, \
70 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN)
71 |
72 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
73 | img = mmcv.imread(osp.join(now_dir, filename))
74 | mmcv.imwrite(
75 | img,
76 | osp.join(out_dir, 'images', 'training',
77 | osp.splitext(filename)[0] + '.png'))
78 |
79 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
80 | img = mmcv.imread(osp.join(now_dir, filename))
81 | mmcv.imwrite(
82 | img,
83 | osp.join(out_dir, 'images', 'validation',
84 | osp.splitext(filename)[0] + '.png'))
85 |
86 | print('Removing the temporary files...')
87 |
88 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
89 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz'))
90 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files'))
91 |
92 | print('Extracting labels-ah.tar...')
93 | with tarfile.open(labels_ah) as f:
94 | f.extractall(osp.join(tmp_dir, 'gz'))
95 |
96 | for filename in os.listdir(osp.join(tmp_dir, 'gz')):
97 | un_gz(
98 | osp.join(tmp_dir, 'gz', filename),
99 | osp.join(tmp_dir, 'files',
100 | osp.splitext(filename)[0]))
101 |
102 | now_dir = osp.join(tmp_dir, 'files')
103 |
104 | assert len(os.listdir(now_dir)) == STARE_LEN, \
105 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN)
106 |
107 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
108 | img = mmcv.imread(osp.join(now_dir, filename))
109 | # The annotation img should be divided by 128, because some of
110 | # the annotation imgs are not standard. We should set a threshold
111 | # to convert the nonstandard annotation imgs. The value divided by
112 |             # 128 is equivalent to '1 if value >= 128 else 0'
113 | mmcv.imwrite(
114 | img[:, :, 0] // 128,
115 | osp.join(out_dir, 'annotations', 'training',
116 | osp.splitext(filename)[0] + '.png'))
117 |
118 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
119 | img = mmcv.imread(osp.join(now_dir, filename))
120 | mmcv.imwrite(
121 | img[:, :, 0] // 128,
122 | osp.join(out_dir, 'annotations', 'validation',
123 | osp.splitext(filename)[0] + '.png'))
124 |
125 | print('Removing the temporary files...')
126 |
127 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir:
128 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz'))
129 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files'))
130 |
131 | print('Extracting labels-vk.tar...')
132 | with tarfile.open(labels_vk) as f:
133 | f.extractall(osp.join(tmp_dir, 'gz'))
134 |
135 | for filename in os.listdir(osp.join(tmp_dir, 'gz')):
136 | un_gz(
137 | osp.join(tmp_dir, 'gz', filename),
138 | osp.join(tmp_dir, 'files',
139 | osp.splitext(filename)[0]))
140 |
141 | now_dir = osp.join(tmp_dir, 'files')
142 |
143 | assert len(os.listdir(now_dir)) == STARE_LEN, \
144 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN)
145 |
146 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]:
147 | img = mmcv.imread(osp.join(now_dir, filename))
148 | mmcv.imwrite(
149 | img[:, :, 0] // 128,
150 | osp.join(out_dir, 'annotations', 'training',
151 | osp.splitext(filename)[0] + '.png'))
152 |
153 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]:
154 | img = mmcv.imread(osp.join(now_dir, filename))
155 | mmcv.imwrite(
156 | img[:, :, 0] // 128,
157 | osp.join(out_dir, 'annotations', 'validation',
158 | osp.splitext(filename)[0] + '.png'))
159 |
160 | print('Removing the temporary files...')
161 |
162 | print('Done!')
163 |
164 |
165 | if __name__ == '__main__':
166 | main()
167 |
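A self-contained round trip (temporary files, made-up name) showing what un_gz does: read one .gz member and write the decompressed bytes to a sibling path with the '.gz' suffix stripped, exactly as the osp.splitext call above derives the destination name.

import gzip
import os.path as osp
import tempfile

with tempfile.TemporaryDirectory() as d:
    src = osp.join(d, 'im0001.ppm.gz')
    with gzip.open(src, 'wb') as f:
        f.write(b'toy payload')
    dst = osp.join(d, osp.splitext(osp.basename(src))[0])  # drops '.gz'
    with gzip.GzipFile(src) as g, open(dst, 'wb+') as out:
        out.write(g.read())
    print(osp.basename(dst))  # im0001.ppm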
--------------------------------------------------------------------------------
/segmentation/tools/convert_datasets/voc_aug.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 | from functools import partial
5 |
6 | import mmcv
7 | import numpy as np
8 | from PIL import Image
9 | from scipy.io import loadmat
10 |
11 | AUG_LEN = 10582
12 |
13 |
14 | def convert_mat(mat_file, in_dir, out_dir):
15 | data = loadmat(osp.join(in_dir, mat_file))
16 | mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
17 | seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png'))
18 | Image.fromarray(mask).save(seg_filename, 'PNG')
19 |
20 |
21 | def generate_aug_list(merged_list, excluded_list):
22 | return list(set(merged_list) - set(excluded_list))
23 |
24 |
25 | def parse_args():
26 | parser = argparse.ArgumentParser(
27 | description='Convert PASCAL VOC annotations to mmsegmentation format')
28 | parser.add_argument('devkit_path', help='pascal voc devkit path')
29 | parser.add_argument('aug_path', help='pascal voc aug path')
30 | parser.add_argument('-o', '--out_dir', help='output path')
31 | parser.add_argument(
32 |         '--nproc', default=1, type=int, help='number of processes')
33 | args = parser.parse_args()
34 | return args
35 |
36 |
37 | def main():
38 | args = parse_args()
39 | devkit_path = args.devkit_path
40 | aug_path = args.aug_path
41 | nproc = args.nproc
42 | if args.out_dir is None:
43 | out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
44 | else:
45 | out_dir = args.out_dir
46 | mmcv.mkdir_or_exist(out_dir)
47 | in_dir = osp.join(aug_path, 'dataset', 'cls')
48 |
49 | mmcv.track_parallel_progress(
50 | partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
51 | list(mmcv.scandir(in_dir, suffix='.mat')),
52 | nproc=nproc)
53 |
54 | full_aug_list = []
55 | with open(osp.join(aug_path, 'dataset', 'train.txt')) as f:
56 | full_aug_list += [line.strip() for line in f]
57 | with open(osp.join(aug_path, 'dataset', 'val.txt')) as f:
58 | full_aug_list += [line.strip() for line in f]
59 |
60 | with open(
61 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
62 | 'train.txt')) as f:
63 | ori_train_list = [line.strip() for line in f]
64 | with open(
65 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
66 | 'val.txt')) as f:
67 | val_list = [line.strip() for line in f]
68 |
69 | aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
70 | val_list)
71 | assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format(
72 | AUG_LEN)
73 |
74 | with open(
75 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation',
76 | 'trainaug.txt'), 'w') as f:
77 | f.writelines(line + '\n' for line in aug_train_list)
78 |
79 | aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
80 | assert len(aug_list) == AUG_LEN - len(
81 | ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN -
82 | len(ori_train_list))
83 | with open(
84 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'),
85 | 'w') as f:
86 | f.writelines(line + '\n' for line in aug_list)
87 |
88 | print('Done!')
89 |
90 |
91 | if __name__ == '__main__':
92 | main()
93 |
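A toy version (made-up ids) of how generate_aug_list is used above: the augmented training list is the SBD lists plus the original VOC train list with the VOC val ids removed, so no validation image leaks into training. Note that set() also collapses duplicate ids between the merged lists.

def generate_aug_list(merged_list, excluded_list):
    return list(set(merged_list) - set(excluded_list))

merged = ['img_a', 'img_b', 'img_c', 'img_b']   # duplicates collapse
val = ['img_c']
print(sorted(generate_aug_list(merged, val)))   # ['img_a', 'img_b']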
--------------------------------------------------------------------------------
/segmentation/tools/deploy_test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | import shutil
6 | import warnings
7 | from typing import Any, Iterable
8 |
9 | import mmcv
10 | import numpy as np
11 | import torch
12 | from mmcv.parallel import MMDataParallel
13 | from mmcv.runner import get_dist_info
14 | from mmcv.utils import DictAction
15 |
16 | from mmseg.apis import single_gpu_test
17 | from mmseg.datasets import build_dataloader, build_dataset
18 | from mmseg.models.segmentors.base import BaseSegmentor
19 | from mmseg.ops import resize
20 |
21 |
22 | class ONNXRuntimeSegmentor(BaseSegmentor):
23 |
24 | def __init__(self, onnx_file: str, cfg: Any, device_id: int):
25 | super(ONNXRuntimeSegmentor, self).__init__()
26 | import onnxruntime as ort
27 |
28 | # get the custom op path
29 | ort_custom_op_path = ''
30 | try:
31 | from mmcv.ops import get_onnxruntime_op_path
32 | ort_custom_op_path = get_onnxruntime_op_path()
33 | except (ImportError, ModuleNotFoundError):
34 | warnings.warn('If input model has custom op from mmcv, \
35 | you may have to build mmcv with ONNXRuntime from source.')
36 | session_options = ort.SessionOptions()
37 | # register custom op for onnxruntime
38 | if osp.exists(ort_custom_op_path):
39 | session_options.register_custom_ops_library(ort_custom_op_path)
40 | sess = ort.InferenceSession(onnx_file, session_options)
41 | providers = ['CPUExecutionProvider']
42 | options = [{}]
43 | is_cuda_available = ort.get_device() == 'GPU'
44 | if is_cuda_available:
45 | providers.insert(0, 'CUDAExecutionProvider')
46 | options.insert(0, {'device_id': device_id})
47 |
48 | sess.set_providers(providers, options)
49 |
50 | self.sess = sess
51 | self.device_id = device_id
52 | self.io_binding = sess.io_binding()
53 | self.output_names = [_.name for _ in sess.get_outputs()]
54 | for name in self.output_names:
55 | self.io_binding.bind_output(name)
56 | self.cfg = cfg
57 | self.test_mode = cfg.model.test_cfg.mode
58 | self.is_cuda_available = is_cuda_available
59 |
60 | def extract_feat(self, imgs):
61 | raise NotImplementedError('This method is not implemented.')
62 |
63 | def encode_decode(self, img, img_metas):
64 | raise NotImplementedError('This method is not implemented.')
65 |
66 | def forward_train(self, imgs, img_metas, **kwargs):
67 | raise NotImplementedError('This method is not implemented.')
68 |
69 | def simple_test(self, img: torch.Tensor, img_meta: Iterable,
70 | **kwargs) -> list:
71 | if not self.is_cuda_available:
72 | img = img.detach().cpu()
73 | elif self.device_id >= 0:
74 | img = img.cuda(self.device_id)
75 | device_type = img.device.type
76 | self.io_binding.bind_input(
77 | name='input',
78 | device_type=device_type,
79 | device_id=self.device_id,
80 | element_type=np.float32,
81 | shape=img.shape,
82 | buffer_ptr=img.data_ptr())
83 | self.sess.run_with_iobinding(self.io_binding)
84 | seg_pred = self.io_binding.copy_outputs_to_cpu()[0]
85 |         # 'whole' test mode may output a different resolution; resize back to ori_shape
86 | ori_shape = img_meta[0]['ori_shape']
87 | if not (ori_shape[0] == seg_pred.shape[-2]
88 | and ori_shape[1] == seg_pred.shape[-1]):
89 | seg_pred = torch.from_numpy(seg_pred).float()
90 | seg_pred = resize(
91 | seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
92 | seg_pred = seg_pred.long().detach().cpu().numpy()
93 | seg_pred = seg_pred[0]
94 | seg_pred = list(seg_pred)
95 | return seg_pred
96 |
97 | def aug_test(self, imgs, img_metas, **kwargs):
98 | raise NotImplementedError('This method is not implemented.')
99 |
100 |
101 | class TensorRTSegmentor(BaseSegmentor):
102 |
103 | def __init__(self, trt_file: str, cfg: Any, device_id: int):
104 | super(TensorRTSegmentor, self).__init__()
105 | from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
106 | try:
107 | load_tensorrt_plugin()
108 | except (ImportError, ModuleNotFoundError):
109 | warnings.warn('If input model has custom op from mmcv, \
110 | you may have to build mmcv with TensorRT from source.')
111 | model = TRTWraper(
112 | trt_file, input_names=['input'], output_names=['output'])
113 |
114 | self.model = model
115 | self.device_id = device_id
116 | self.cfg = cfg
117 | self.test_mode = cfg.model.test_cfg.mode
118 |
119 | def extract_feat(self, imgs):
120 | raise NotImplementedError('This method is not implemented.')
121 |
122 | def encode_decode(self, img, img_metas):
123 | raise NotImplementedError('This method is not implemented.')
124 |
125 | def forward_train(self, imgs, img_metas, **kwargs):
126 | raise NotImplementedError('This method is not implemented.')
127 |
128 | def simple_test(self, img: torch.Tensor, img_meta: Iterable,
129 | **kwargs) -> list:
130 | with torch.cuda.device(self.device_id), torch.no_grad():
131 | seg_pred = self.model({'input': img})['output']
132 | seg_pred = seg_pred.detach().cpu().numpy()
133 |             # 'whole' test mode may output a different resolution; resize back to ori_shape
134 | ori_shape = img_meta[0]['ori_shape']
135 | if not (ori_shape[0] == seg_pred.shape[-2]
136 | and ori_shape[1] == seg_pred.shape[-1]):
137 | seg_pred = torch.from_numpy(seg_pred).float()
138 | seg_pred = resize(
139 | seg_pred, size=tuple(ori_shape[:2]), mode='nearest')
140 | seg_pred = seg_pred.long().detach().cpu().numpy()
141 | seg_pred = seg_pred[0]
142 | seg_pred = list(seg_pred)
143 | return seg_pred
144 |
145 | def aug_test(self, imgs, img_metas, **kwargs):
146 | raise NotImplementedError('This method is not implemented.')
147 |
148 |
149 | def parse_args() -> argparse.Namespace:
150 | parser = argparse.ArgumentParser(
151 | description='mmseg backend test (and eval)')
152 | parser.add_argument('config', help='test config file path')
153 | parser.add_argument('model', help='Input model file')
154 | parser.add_argument(
155 | '--backend',
156 | help='Backend of the model.',
157 | choices=['onnxruntime', 'tensorrt'])
158 | parser.add_argument('--out', help='output result file in pickle format')
159 | parser.add_argument(
160 | '--format-only',
161 | action='store_true',
162 |         help='Format the output results without performing evaluation. It is '
163 | 'useful when you want to format the result to a specific format and '
164 | 'submit it to the test server')
165 | parser.add_argument(
166 | '--eval',
167 | type=str,
168 | nargs='+',
169 | help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
170 | ' for generic datasets, and "cityscapes" for Cityscapes')
171 | parser.add_argument('--show', action='store_true', help='show results')
172 | parser.add_argument(
173 | '--show-dir', help='directory where painted images will be saved')
174 | parser.add_argument(
175 | '--options', nargs='+', action=DictAction, help='custom options')
176 | parser.add_argument(
177 | '--eval-options',
178 | nargs='+',
179 | action=DictAction,
180 | help='custom options for evaluation')
181 | parser.add_argument(
182 | '--opacity',
183 | type=float,
184 | default=0.5,
185 | help='Opacity of painted segmentation map. In (0, 1] range.')
186 | parser.add_argument('--local_rank', type=int, default=0)
187 | args = parser.parse_args()
188 | if 'LOCAL_RANK' not in os.environ:
189 | os.environ['LOCAL_RANK'] = str(args.local_rank)
190 | return args
191 |
192 |
193 | def main():
194 | args = parse_args()
195 |
196 | assert args.out or args.eval or args.format_only or args.show \
197 | or args.show_dir, \
198 | ('Please specify at least one operation (save/eval/format/show the '
199 | 'results / save the results) with the argument "--out", "--eval"'
200 | ', "--format-only", "--show" or "--show-dir"')
201 |
202 | if args.eval and args.format_only:
203 | raise ValueError('--eval and --format_only cannot be both specified')
204 |
205 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
206 | raise ValueError('The output file must be a pkl file.')
207 |
208 | cfg = mmcv.Config.fromfile(args.config)
209 | if args.options is not None:
210 | cfg.merge_from_dict(args.options)
211 | cfg.model.pretrained = None
212 | cfg.data.test.test_mode = True
213 |
214 | # init distributed env first, since logger depends on the dist info.
215 | distributed = False
216 |
217 | # build the dataloader
218 | # TODO: support multiple images per gpu (only minor changes are needed)
219 | dataset = build_dataset(cfg.data.test)
220 | data_loader = build_dataloader(
221 | dataset,
222 | samples_per_gpu=1,
223 | workers_per_gpu=cfg.data.workers_per_gpu,
224 | dist=distributed,
225 | shuffle=False)
226 |
227 | # load onnx config and meta
228 | cfg.model.train_cfg = None
229 |
230 | if args.backend == 'onnxruntime':
231 | model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0)
232 | elif args.backend == 'tensorrt':
233 | model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0)
234 |
235 | model.CLASSES = dataset.CLASSES
236 | model.PALETTE = dataset.PALETTE
237 |
238 | # clean gpu memory when starting a new evaluation.
239 | torch.cuda.empty_cache()
240 | eval_kwargs = {} if args.eval_options is None else args.eval_options
241 |
242 | # Deprecated
243 | efficient_test = eval_kwargs.get('efficient_test', False)
244 | if efficient_test:
245 | warnings.warn(
246 | '``efficient_test=True`` does not have effect in tools/test.py, '
247 | 'the evaluation and format results are CPU memory efficient by '
248 | 'default')
249 |
250 | eval_on_format_results = (
251 | args.eval is not None and 'cityscapes' in args.eval)
252 | if eval_on_format_results:
253 | assert len(args.eval) == 1, 'eval on format results is not ' \
254 | 'applicable for metrics other than ' \
255 | 'cityscapes'
256 | if args.format_only or eval_on_format_results:
257 | if 'imgfile_prefix' in eval_kwargs:
258 | tmpdir = eval_kwargs['imgfile_prefix']
259 | else:
260 | tmpdir = '.format_cityscapes'
261 | eval_kwargs.setdefault('imgfile_prefix', tmpdir)
262 | mmcv.mkdir_or_exist(tmpdir)
263 | else:
264 | tmpdir = None
265 |
266 | model = MMDataParallel(model, device_ids=[0])
267 | results = single_gpu_test(
268 | model,
269 | data_loader,
270 | args.show,
271 | args.show_dir,
272 | False,
273 | args.opacity,
274 | pre_eval=args.eval is not None and not eval_on_format_results,
275 | format_only=args.format_only or eval_on_format_results,
276 | format_args=eval_kwargs)
277 |
278 | rank, _ = get_dist_info()
279 | if rank == 0:
280 | if args.out:
281 | warnings.warn(
282 | 'The behavior of ``args.out`` has been changed since MMSeg '
283 | 'v0.16, the pickled outputs could be seg map as type of '
284 | 'np.array, pre-eval results or file paths for '
285 | '``dataset.format_results()``.')
286 | print(f'\nwriting results to {args.out}')
287 | mmcv.dump(results, args.out)
288 | if args.eval:
289 | dataset.evaluate(results, args.eval, **eval_kwargs)
290 | if tmpdir is not None and eval_on_format_results:
291 | # remove tmp dir when cityscapes evaluation
292 | shutil.rmtree(tmpdir)
293 |
294 |
295 | if __name__ == '__main__':
296 | main()
297 |
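A small sketch (no ONNX model required; ort.get_device() is replaced by a hard-coded flag) of the provider selection in ONNXRuntimeSegmentor.__init__ above: when a GPU build of onnxruntime is available, CUDAExecutionProvider and its per-provider options are inserted ahead of the CPU fallback before sess.set_providers is called.

providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = True          # stand-in for ort.get_device() == 'GPU'
device_id = 0
if is_cuda_available:
    providers.insert(0, 'CUDAExecutionProvider')
    options.insert(0, {'device_id': device_id})
print(list(zip(providers, options)))
# [('CUDAExecutionProvider', {'device_id': 0}), ('CPUExecutionProvider', {})]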
--------------------------------------------------------------------------------
/segmentation/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
10 |
--------------------------------------------------------------------------------
/segmentation/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 |
7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
10 |
--------------------------------------------------------------------------------
/segmentation/tools/get_flops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 |
4 | from mmcv import Config
5 | from mmcv.cnn import get_model_complexity_info
6 |
7 | from mmseg.models import build_segmentor
8 |
9 |
10 | def parse_args():
11 | parser = argparse.ArgumentParser(description='Train a segmentor')
12 | parser.add_argument('config', help='train config file path')
13 | parser.add_argument(
14 | '--shape',
15 | type=int,
16 | nargs='+',
17 | default=[2048, 1024],
18 | help='input image size')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | def main():
24 |
25 | args = parse_args()
26 |
27 | if len(args.shape) == 1:
28 | input_shape = (3, args.shape[0], args.shape[0])
29 | elif len(args.shape) == 2:
30 | input_shape = (3, ) + tuple(args.shape)
31 | else:
32 | raise ValueError('invalid input shape')
33 |
34 | cfg = Config.fromfile(args.config)
35 | cfg.model.pretrained = None
36 | model = build_segmentor(
37 | cfg.model,
38 | train_cfg=cfg.get('train_cfg'),
39 | test_cfg=cfg.get('test_cfg')).cuda()
40 | model.eval()
41 |
42 | if hasattr(model, 'forward_dummy'):
43 | model.forward = model.forward_dummy
44 | else:
45 | raise NotImplementedError(
46 |             'FLOPs counter is currently not supported with {}'.
47 | format(model.__class__.__name__))
48 |
49 | flops, params = get_model_complexity_info(model, input_shape)
50 | split_line = '=' * 30
51 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
52 | split_line, input_shape, flops, params))
53 | print('!!!Please be cautious if you use the results in papers. '
54 | 'You may need to check if all ops are supported and verify that the '
55 | 'flops computation is correct.')
56 |
57 |
58 | if __name__ == '__main__':
59 | main()
60 |
--------------------------------------------------------------------------------
/segmentation/tools/model_converters/mit2mmseg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 | from collections import OrderedDict
5 |
6 | import mmcv
7 | import torch
8 | from mmcv.runner import CheckpointLoader
9 |
10 |
11 | def convert_mit(ckpt):
12 | new_ckpt = OrderedDict()
13 | # Process the concat between q linear weights and kv linear weights
14 | for k, v in ckpt.items():
15 | if k.startswith('head'):
16 | continue
17 |         # patch embedding conversion
18 | elif k.startswith('patch_embed'):
19 | stage_i = int(k.split('.')[0].replace('patch_embed', ''))
20 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
21 | new_v = v
22 | if 'proj.' in new_k:
23 | new_k = new_k.replace('proj.', 'projection.')
24 |         # transformer encoder layer conversion
25 | elif k.startswith('block'):
26 | stage_i = int(k.split('.')[0].replace('block', ''))
27 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
28 | new_v = v
29 | if 'attn.q.' in new_k:
30 | sub_item_k = k.replace('q.', 'kv.')
31 | new_k = new_k.replace('q.', 'attn.in_proj_')
32 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
33 | elif 'attn.kv.' in new_k:
34 | continue
35 | elif 'attn.proj.' in new_k:
36 | new_k = new_k.replace('proj.', 'attn.out_proj.')
37 | elif 'attn.sr.' in new_k:
38 | new_k = new_k.replace('sr.', 'sr.')
39 | elif 'mlp.' in new_k:
40 | string = f'{new_k}-'
41 | new_k = new_k.replace('mlp.', 'ffn.layers.')
42 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
43 | new_v = v.reshape((*v.shape, 1, 1))
44 | new_k = new_k.replace('fc1.', '0.')
45 | new_k = new_k.replace('dwconv.dwconv.', '1.')
46 | new_k = new_k.replace('fc2.', '4.')
47 | string += f'{new_k} {v.shape}-{new_v.shape}'
48 |         # norm layer conversion
49 | elif k.startswith('norm'):
50 | stage_i = int(k.split('.')[0].replace('norm', ''))
51 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
52 | new_v = v
53 | else:
54 | new_k = k
55 | new_v = v
56 | new_ckpt[new_k] = new_v
57 | return new_ckpt
58 |
59 |
60 | def main():
61 | parser = argparse.ArgumentParser(
62 | description='Convert keys in official pretrained segformer to '
63 | 'MMSegmentation style.')
64 | parser.add_argument('src', help='src model path or url')
65 | # The dst path must be a full path of the new checkpoint.
66 | parser.add_argument('dst', help='save path')
67 | args = parser.parse_args()
68 |
69 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
70 | if 'state_dict' in checkpoint:
71 | state_dict = checkpoint['state_dict']
72 | elif 'model' in checkpoint:
73 | state_dict = checkpoint['model']
74 | else:
75 | state_dict = checkpoint
76 | weight = convert_mit(state_dict)
77 | mmcv.mkdir_or_exist(osp.dirname(args.dst))
78 | torch.save(weight, args.dst)
79 |
80 |
81 | if __name__ == '__main__':
82 | main()
83 |
--------------------------------------------------------------------------------
/segmentation/tools/model_converters/swin2mmseg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 | from collections import OrderedDict
5 |
6 | import mmcv
7 | import torch
8 | from mmcv.runner import CheckpointLoader
9 |
10 |
11 | def convert_swin(ckpt):
12 | new_ckpt = OrderedDict()
13 |
14 | def correct_unfold_reduction_order(x):
15 | out_channel, in_channel = x.shape
16 | x = x.reshape(out_channel, 4, in_channel // 4)
17 | x = x[:, [0, 2, 1, 3], :].transpose(1,
18 | 2).reshape(out_channel, in_channel)
19 | return x
20 |
21 | def correct_unfold_norm_order(x):
22 | in_channel = x.shape[0]
23 | x = x.reshape(4, in_channel // 4)
24 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
25 | return x
26 |
27 | for k, v in ckpt.items():
28 | if k.startswith('head'):
29 | continue
30 | elif k.startswith('layers'):
31 | new_v = v
32 | if 'attn.' in k:
33 | new_k = k.replace('attn.', 'attn.w_msa.')
34 | elif 'mlp.' in k:
35 | if 'mlp.fc1.' in k:
36 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
37 | elif 'mlp.fc2.' in k:
38 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
39 | else:
40 | new_k = k.replace('mlp.', 'ffn.')
41 | elif 'downsample' in k:
42 | new_k = k
43 | if 'reduction.' in k:
44 | new_v = correct_unfold_reduction_order(v)
45 | elif 'norm.' in k:
46 | new_v = correct_unfold_norm_order(v)
47 | else:
48 | new_k = k
49 | new_k = new_k.replace('layers', 'stages', 1)
50 | elif k.startswith('patch_embed'):
51 | new_v = v
52 | if 'proj' in k:
53 | new_k = k.replace('proj', 'projection')
54 | else:
55 | new_k = k
56 | else:
57 | new_v = v
58 | new_k = k
59 |
60 | new_ckpt[new_k] = new_v
61 |
62 | return new_ckpt
63 |
64 |
65 | def main():
66 | parser = argparse.ArgumentParser(
67 |         description='Convert keys in official pretrained swin models to '
68 | 'MMSegmentation style.')
69 | parser.add_argument('src', help='src model path or url')
70 | # The dst path must be a full path of the new checkpoint.
71 | parser.add_argument('dst', help='save path')
72 | args = parser.parse_args()
73 |
74 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
75 | if 'state_dict' in checkpoint:
76 | state_dict = checkpoint['state_dict']
77 | elif 'model' in checkpoint:
78 | state_dict = checkpoint['model']
79 | else:
80 | state_dict = checkpoint
81 | weight = convert_swin(state_dict)
82 | mmcv.mkdir_or_exist(osp.dirname(args.dst))
83 | torch.save(weight, args.dst)
84 |
85 |
86 | if __name__ == '__main__':
87 | main()
88 |
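A toy check (numpy stand-in for the torch code above, with in_channel fixed to 4 so each patch-merging group has one element) of correct_unfold_reduction_order: the four groups of columns are reordered from [0, 1, 2, 3] to [0, 2, 1, 3].

import numpy as np

x = np.arange(8).reshape(2, 4)                  # (out_channel, in_channel)
y = (x.reshape(2, 4, 1)[:, [0, 2, 1, 3], :]
       .transpose(0, 2, 1)                      # numpy equivalent of torch .transpose(1, 2)
       .reshape(2, 4))
print(y)  # [[0 2 1 3] [4 6 5 7]]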
--------------------------------------------------------------------------------
/segmentation/tools/model_converters/vit2mmseg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os.path as osp
4 | from collections import OrderedDict
5 |
6 | import mmcv
7 | import torch
8 | from mmcv.runner import CheckpointLoader
9 |
10 |
11 | def convert_vit(ckpt):
12 |
13 | new_ckpt = OrderedDict()
14 |
15 | for k, v in ckpt.items():
16 | # if k.startswith('class_token'):
17 | # new_k = k.replace('class_token.', 'cls_token.')
18 | # if k.startswith('pos_embedding'):
19 | # new_k = k.replace('pos_embedding.', 'pos_embed.')
20 | if k.startswith('head'):
21 | continue
22 | if k.startswith('norm'):
23 | new_k = k.replace('norm.', 'ln1.')
24 | elif k.startswith('patch_embed'):
25 | if 'proj' in k:
26 | new_k = k.replace('proj', 'projection')
27 | else:
28 | new_k = k
29 | elif k.startswith('blocks'):
30 | if 'norm' in k:
31 | new_k = k.replace('norm', 'ln')
32 | elif 'mlp.fc1' in k:
33 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0')
34 | elif 'mlp.fc2' in k:
35 | new_k = k.replace('mlp.fc2', 'ffn.layers.1')
36 | elif 'attn.qkv' in k:
37 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_')
38 | elif 'attn.proj' in k:
39 | new_k = k.replace('attn.proj', 'attn.attn.out_proj')
40 | else:
41 | new_k = k
42 | new_k = new_k.replace('blocks.', 'layers.')
43 | else:
44 | new_k = k
45 | new_ckpt[new_k] = v
46 |
47 | return new_ckpt
48 |
49 |
50 | def main():
51 | parser = argparse.ArgumentParser(
52 | description='Convert keys in timm pretrained vit models to '
53 | 'MMSegmentation style.')
54 | parser.add_argument('src', help='src model path or url')
55 | # The dst path must be a full path of the new checkpoint.
56 | parser.add_argument('dst', help='save path')
57 | args = parser.parse_args()
58 |
59 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
60 | if 'state_dict' in checkpoint:
61 | # timm checkpoint
62 | state_dict = checkpoint['state_dict']
63 | elif 'model' in checkpoint:
64 | # deit checkpoint
65 | state_dict = checkpoint['model']
66 | elif 'teacher' in checkpoint:
67 |         # dino-style checkpoint (e.g. SelfPatch), weights stored under the 'teacher' key
68 | state_dict = checkpoint['teacher']
69 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
70 | else:
71 | state_dict = checkpoint
72 | weight = convert_vit(state_dict)
73 | mmcv.mkdir_or_exist(osp.dirname(args.dst))
74 | torch.save(weight, args.dst)
75 |
76 |
77 | if __name__ == '__main__':
78 | main()
79 |
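A minimal sketch (toy keys, simplified rules) of the renaming pattern shared by the three converters in this folder: walk the source state_dict, skip the classification head, derive a new key for every remaining entry, and collect the tensors under the new names in an OrderedDict.

from collections import OrderedDict

src = OrderedDict([('blocks.0.norm1.weight', 1),
                   ('patch_embed.proj.bias', 2),
                   ('head.weight', 3)])
dst = OrderedDict()
for k, v in src.items():
    if k.startswith('head'):        # the head is dropped, as in convert_vit above
        continue
    new_k = (k.replace('blocks.', 'layers.')
              .replace('proj', 'projection')
              .replace('norm', 'ln'))
    dst[new_k] = v
print(list(dst))  # ['layers.0.ln1.weight', 'patch_embed.projection.bias']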
--------------------------------------------------------------------------------
/segmentation/tools/onnx2tensorrt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | from typing import Iterable, Optional, Union
6 |
7 | import matplotlib.pyplot as plt
8 | import mmcv
9 | import numpy as np
10 | import onnxruntime as ort
11 | import torch
12 | from mmcv.ops import get_onnxruntime_op_path
13 | from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
14 | save_trt_engine)
15 |
16 | from mmseg.apis.inference import LoadImage
17 | from mmseg.datasets import DATASETS
18 | from mmseg.datasets.pipelines import Compose
19 |
20 |
21 | def get_GiB(x: int):
22 |     """Return x GiB expressed in bytes."""
23 | return x * (1 << 30)
24 |
25 |
26 | def _prepare_input_img(img_path: str,
27 | test_pipeline: Iterable[dict],
28 | shape: Optional[Iterable] = None,
29 | rescale_shape: Optional[Iterable] = None) -> dict:
30 | # build the data pipeline
31 | if shape is not None:
32 | test_pipeline[1]['img_scale'] = (shape[1], shape[0])
33 | test_pipeline[1]['transforms'][0]['keep_ratio'] = False
34 | test_pipeline = [LoadImage()] + test_pipeline[1:]
35 | test_pipeline = Compose(test_pipeline)
36 | # prepare data
37 | data = dict(img=img_path)
38 | data = test_pipeline(data)
39 | imgs = data['img']
40 | img_metas = [i.data for i in data['img_metas']]
41 |
42 | if rescale_shape is not None:
43 | for img_meta in img_metas:
44 | img_meta['ori_shape'] = tuple(rescale_shape) + (3, )
45 |
46 | mm_inputs = {'imgs': imgs, 'img_metas': img_metas}
47 |
48 | return mm_inputs
49 |
50 |
51 | def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
52 | # update img and its meta list
53 | N = img_list[0].size(0)
54 | img_meta = img_meta_list[0][0]
55 | img_shape = img_meta['img_shape']
56 | ori_shape = img_meta['ori_shape']
57 | pad_shape = img_meta['pad_shape']
58 | new_img_meta_list = [[{
59 | 'img_shape':
60 | img_shape,
61 | 'ori_shape':
62 | ori_shape,
63 | 'pad_shape':
64 | pad_shape,
65 | 'filename':
66 | img_meta['filename'],
67 | 'scale_factor':
68 | (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
69 | 'flip':
70 | False,
71 | } for _ in range(N)]]
72 |
73 | return img_list, new_img_meta_list
74 |
75 |
76 | def show_result_pyplot(img: Union[str, np.ndarray],
77 | result: np.ndarray,
78 | palette: Optional[Iterable] = None,
79 | fig_size: Iterable[int] = (15, 10),
80 | opacity: float = 0.5,
81 | title: str = '',
82 | block: bool = True):
83 | img = mmcv.imread(img)
84 | img = img.copy()
85 | seg = result[0]
86 | seg = mmcv.imresize(seg, img.shape[:2][::-1])
87 | palette = np.array(palette)
88 | assert palette.shape[1] == 3
89 | assert len(palette.shape) == 2
90 | assert 0 < opacity <= 1.0
91 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
92 | for label, color in enumerate(palette):
93 | color_seg[seg == label, :] = color
94 | # convert to BGR
95 | color_seg = color_seg[..., ::-1]
96 |
97 | img = img * (1 - opacity) + color_seg * opacity
98 | img = img.astype(np.uint8)
99 |
100 | plt.figure(figsize=fig_size)
101 | plt.imshow(mmcv.bgr2rgb(img))
102 | plt.title(title)
103 | plt.tight_layout()
104 | plt.show(block=block)
105 |
106 |
107 | def onnx2tensorrt(onnx_file: str,
108 | trt_file: str,
109 | config: dict,
110 | input_config: dict,
111 | fp16: bool = False,
112 | verify: bool = False,
113 | show: bool = False,
114 | dataset: str = 'CityscapesDataset',
115 | workspace_size: int = 1,
116 | verbose: bool = False):
117 | import tensorrt as trt
118 | min_shape = input_config['min_shape']
119 | max_shape = input_config['max_shape']
120 |     # create trt engine and wrapper
121 | opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
122 | max_workspace_size = get_GiB(workspace_size)
123 | trt_engine = onnx2trt(
124 | onnx_file,
125 | opt_shape_dict,
126 | log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
127 | fp16_mode=fp16,
128 | max_workspace_size=max_workspace_size)
129 | save_dir, _ = osp.split(trt_file)
130 | if save_dir:
131 | os.makedirs(save_dir, exist_ok=True)
132 | save_trt_engine(trt_engine, trt_file)
133 | print(f'Successfully created TensorRT engine: {trt_file}')
134 |
135 | if verify:
136 | inputs = _prepare_input_img(
137 | input_config['input_path'],
138 | config.data.test.pipeline,
139 | shape=min_shape[2:])
140 |
141 | imgs = inputs['imgs']
142 | img_metas = inputs['img_metas']
143 | img_list = [img[None, :] for img in imgs]
144 | img_meta_list = [[img_meta] for img_meta in img_metas]
145 | # update img_meta
146 | img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
147 |
148 | if max_shape[0] > 1:
149 |             # concatenate flipped images for batch test
150 | flip_img_list = [_.flip(-1) for _ in img_list]
151 | img_list = [
152 | torch.cat((ori_img, flip_img), 0)
153 | for ori_img, flip_img in zip(img_list, flip_img_list)
154 | ]
155 |
156 | # Get results from ONNXRuntime
157 | ort_custom_op_path = get_onnxruntime_op_path()
158 | session_options = ort.SessionOptions()
159 | if osp.exists(ort_custom_op_path):
160 | session_options.register_custom_ops_library(ort_custom_op_path)
161 | sess = ort.InferenceSession(onnx_file, session_options)
162 | sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode
163 | onnx_output = sess.run(['output'],
164 | {'input': img_list[0].detach().numpy()})[0][0]
165 |
166 | # Get results from TensorRT
167 | trt_model = TRTWraper(trt_file, ['input'], ['output'])
168 | with torch.no_grad():
169 | trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
170 | trt_output = trt_outputs['output'][0].cpu().detach().numpy()
171 |
172 | if show:
173 | dataset = DATASETS.get(dataset)
174 | assert dataset is not None
175 | palette = dataset.PALETTE
176 |
177 | show_result_pyplot(
178 | input_config['input_path'],
179 | (onnx_output[0].astype(np.uint8), ),
180 | palette=palette,
181 | title='ONNXRuntime',
182 | block=False)
183 | show_result_pyplot(
184 | input_config['input_path'], (trt_output[0].astype(np.uint8), ),
185 | palette=palette,
186 | title='TensorRT')
187 |
188 | np.testing.assert_allclose(
189 | onnx_output, trt_output, rtol=1e-03, atol=1e-05)
190 | print('The outputs of TensorRT and ONNXRuntime are all close.')
191 |
192 |
193 | def parse_args():
194 | parser = argparse.ArgumentParser(
195 | description='Convert MMSegmentation models from ONNX to TensorRT')
196 | parser.add_argument('config', help='Config file of the model')
197 | parser.add_argument('model', help='Path to the input ONNX model')
198 | parser.add_argument(
199 | '--trt-file', type=str, help='Path to the output TensorRT engine')
200 | parser.add_argument(
201 | '--max-shape',
202 | type=int,
203 | nargs=4,
204 | default=[1, 3, 400, 600],
205 | help='Maximum shape of model input.')
206 | parser.add_argument(
207 | '--min-shape',
208 | type=int,
209 | nargs=4,
210 | default=[1, 3, 400, 600],
211 | help='Minimum shape of model input.')
212 | parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
213 | parser.add_argument(
214 | '--workspace-size',
215 | type=int,
216 | default=1,
217 | help='Max workspace size in GiB')
218 | parser.add_argument(
219 | '--input-img', type=str, default='', help='Image for test')
220 | parser.add_argument(
221 | '--show', action='store_true', help='Whether to show output results')
222 | parser.add_argument(
223 | '--dataset',
224 | type=str,
225 | default='CityscapesDataset',
226 | help='Dataset name')
227 | parser.add_argument(
228 | '--verify',
229 | action='store_true',
230 | help='Verify the outputs of ONNXRuntime and TensorRT')
231 | parser.add_argument(
232 | '--verbose',
233 | action='store_true',
234 | help='Whether to show verbose logging messages while creating '
235 | 'the TensorRT engine.')
236 | args = parser.parse_args()
237 | return args
238 |
239 |
240 | if __name__ == '__main__':
241 |
242 | assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
243 | args = parse_args()
244 |
245 | if not args.input_img:
246 | args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')
247 |
248 | # check arguments
249 | assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
250 | assert osp.exists(args.model), \
251 | 'ONNX model {} not found.'.format(args.model)
252 | assert args.workspace_size >= 0, 'Workspace size less than 0.'
253 | assert DATASETS.get(args.dataset) is not None, \
254 | 'Dataset {} not found.'.format(args.dataset)
255 | for max_value, min_value in zip(args.max_shape, args.min_shape):
256 | assert max_value >= min_value, \
257 | 'max_shape should be larger than min_shape'
258 |
259 | input_config = {
260 | 'min_shape': args.min_shape,
261 | 'max_shape': args.max_shape,
262 | 'input_path': args.input_img
263 | }
264 |
265 | cfg = mmcv.Config.fromfile(args.config)
266 | onnx2tensorrt(
267 | args.model,
268 | args.trt_file,
269 | cfg,
270 | input_config,
271 | fp16=args.fp16,
272 | verify=args.verify,
273 | show=args.show,
274 | dataset=args.dataset,
275 | workspace_size=args.workspace_size,
276 | verbose=args.verbose)
277 |
--------------------------------------------------------------------------------
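For reference, a minimal sketch of driving the conversion above from Python rather than from the command line. The ONNX/engine paths are placeholders, and the import assumes segmentation/tools is on PYTHONPATH; only the keyword names mirror the function signature above.

    import mmcv
    from onnx2tensorrt import onnx2tensorrt  # hypothetical import; assumes segmentation/tools is on PYTHONPATH

    cfg = mmcv.Config.fromfile('configs/semfpn_vit-s16_512x512_40k_ade20k.py')
    input_config = {
        'min_shape': [1, 3, 512, 512],  # NCHW; the smallest input the engine must accept
        'max_shape': [1, 3, 512, 512],  # keep min == max for a static-shape engine
        'input_path': 'demo/demo.png',  # only read when verify=True
    }
    onnx2tensorrt('model.onnx', 'model.trt', cfg, input_config, fp16=True, verify=False)
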
/segmentation/tools/print_config.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 |
4 | from mmcv import Config, DictAction
5 |
6 | from mmseg.apis import init_segmentor
7 |
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='Print the whole config')
11 | parser.add_argument('config', help='config file path')
12 | parser.add_argument(
13 | '--graph', action='store_true', help='print the model graph')
14 | parser.add_argument(
15 | '--options', nargs='+', action=DictAction, help='arguments in dict')
16 | args = parser.parse_args()
17 |
18 | return args
19 |
20 |
21 | def main():
22 | args = parse_args()
23 |
24 | cfg = Config.fromfile(args.config)
25 | if args.options is not None:
26 | cfg.merge_from_dict(args.options)
27 | print(f'Config:\n{cfg.pretty_text}')
28 | # dump config
29 | cfg.dump('example.py')
30 | # dump the model graph
31 | if args.graph:
32 | model = init_segmentor(args.config, device='cpu')
33 | print(f'Model graph:\n{str(model)}')
34 | with open('example-graph.txt', 'w') as f:
35 | f.writelines(str(model))
36 |
37 |
38 | if __name__ == '__main__':
39 | main()
40 |
--------------------------------------------------------------------------------
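The `--options` flag above is parsed by mmcv's DictAction into a nested override dict; a hedged sketch of the equivalent programmatic call (the dotted keys are illustrative and may not exist in every config):

    from mmcv import Config

    cfg = Config.fromfile('configs/semfpn_vit-s16_512x512_40k_ade20k.py')
    # the same kind of override that `--options data.samples_per_gpu=2` produces
    cfg.merge_from_dict({'data.samples_per_gpu': 2, 'optimizer.lr': 1e-4})
    print(cfg.pretty_text)
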
/segmentation/tools/publish_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import subprocess
4 |
5 | import torch
6 |
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser(
10 | description='Process a checkpoint to be published')
11 | parser.add_argument('in_file', help='input checkpoint filename')
12 | parser.add_argument('out_file', help='output checkpoint filename')
13 | args = parser.parse_args()
14 | return args
15 |
16 |
17 | def process_checkpoint(in_file, out_file):
18 | checkpoint = torch.load(in_file, map_location='cpu')
19 | # remove optimizer for smaller file size
20 | if 'optimizer' in checkpoint:
21 | del checkpoint['optimizer']
22 | # if it is necessary to remove some sensitive data in checkpoint['meta'],
23 | # add the code here.
24 | torch.save(checkpoint, out_file)
25 | sha = subprocess.check_output(['sha256sum', out_file]).decode()
26 | final_file = out_file[:-4] + '-{}.pth'.format(sha[:8])  # slice off '.pth'; rstrip('.pth') strips characters, not the suffix
27 | subprocess.Popen(['mv', out_file, final_file])
28 |
29 |
30 | def main():
31 | args = parse_args()
32 | process_checkpoint(args.in_file, args.out_file)
33 |
34 |
35 | if __name__ == '__main__':
36 | main()
37 |
--------------------------------------------------------------------------------
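The script above names the published file `<out_file stem>-<first 8 hex chars of its SHA-256>.pth`. A pure-Python sketch of the same naming, avoiding the `sha256sum` and `mv` subprocesses (not part of the repo):

    import hashlib

    def sha_suffixed_name(out_file):
        # reproduce the `<name>-<8-char sha256 prefix>.pth` convention used above
        with open(out_file, 'rb') as f:
            sha = hashlib.sha256(f.read()).hexdigest()
        base = out_file[:-4] if out_file.endswith('.pth') else out_file
        return '{}-{}.pth'.format(base, sha[:8])
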
/segmentation/tools/pytorch2torchscript.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 |
4 | import mmcv
5 | import numpy as np
6 | import torch
7 | import torch._C
8 | import torch.serialization
9 | from mmcv.runner import load_checkpoint
10 | from torch import nn
11 |
12 | from mmseg.models import build_segmentor
13 |
14 | torch.manual_seed(3)
15 |
16 |
17 | def digit_version(version_str):
18 | digit_version = []
19 | for x in version_str.split('.'):
20 | if x.isdigit():
21 | digit_version.append(int(x))
22 | elif x.find('rc') != -1:
23 | patch_version = x.split('rc')
24 | digit_version.append(int(patch_version[0]) - 1)
25 | digit_version.append(int(patch_version[1]))
26 | return digit_version
27 |
28 |
29 | def check_torch_version():
30 | torch_minimum_version = '1.8.0'
31 | torch_version = digit_version(torch.__version__)
32 |
33 | assert (torch_version >= digit_version(torch_minimum_version)), \
34 | f'Torch=={torch.__version__} is not supported for converting to ' \
35 | f'torchscript. Please install pytorch>={torch_minimum_version}.'
36 |
37 |
38 | def _convert_batchnorm(module):
39 | module_output = module
40 | if isinstance(module, torch.nn.SyncBatchNorm):
41 | module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
42 | module.momentum, module.affine,
43 | module.track_running_stats)
44 | if module.affine:
45 | module_output.weight.data = module.weight.data.clone().detach()
46 | module_output.bias.data = module.bias.data.clone().detach()
47 | # keep requires_grad unchanged
48 | module_output.weight.requires_grad = module.weight.requires_grad
49 | module_output.bias.requires_grad = module.bias.requires_grad
50 | module_output.running_mean = module.running_mean
51 | module_output.running_var = module.running_var
52 | module_output.num_batches_tracked = module.num_batches_tracked
53 | for name, child in module.named_children():
54 | module_output.add_module(name, _convert_batchnorm(child))
55 | del module
56 | return module_output
57 |
58 |
59 | def _demo_mm_inputs(input_shape, num_classes):
60 | """Create a superset of inputs needed to run test or train batches.
61 |
62 | Args:
63 | input_shape (tuple):
64 | input batch dimensions
65 | num_classes (int):
66 | number of semantic classes
67 | """
68 | (N, C, H, W) = input_shape
69 | rng = np.random.RandomState(0)
70 | imgs = rng.rand(*input_shape)
71 | segs = rng.randint(
72 | low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
73 | img_metas = [{
74 | 'img_shape': (H, W, C),
75 | 'ori_shape': (H, W, C),
76 | 'pad_shape': (H, W, C),
77 | 'filename': '.png',
78 | 'scale_factor': 1.0,
79 | 'flip': False,
80 | } for _ in range(N)]
81 | mm_inputs = {
82 | 'imgs': torch.FloatTensor(imgs).requires_grad_(True),
83 | 'img_metas': img_metas,
84 | 'gt_semantic_seg': torch.LongTensor(segs)
85 | }
86 | return mm_inputs
87 |
88 |
89 | def pytorch2libtorch(model,
90 | input_shape,
91 | show=False,
92 | output_file='tmp.pt',
93 | verify=False):
94 | """Export Pytorch model to TorchScript model and verify the outputs are
95 | same between Pytorch and TorchScript.
96 |
97 | Args:
98 | model (nn.Module): Pytorch model we want to export.
99 | input_shape (tuple): Use this input shape to construct
100 | the corresponding dummy input and execute the model.
101 | show (bool): Whether to print the computation graph. Default: False.
102 | output_file (string): The path to where we store the
103 | output TorchScript model. Default: `tmp.pt`.
104 | verify (bool): Whether to compare the outputs between
105 | Pytorch and TorchScript. Default: False.
106 | """
107 | if isinstance(model.decode_head, nn.ModuleList):
108 | num_classes = model.decode_head[-1].num_classes
109 | else:
110 | num_classes = model.decode_head.num_classes
111 |
112 | mm_inputs = _demo_mm_inputs(input_shape, num_classes)
113 |
114 | imgs = mm_inputs.pop('imgs')
115 |
116 | # replace the original forward with forward_dummy
117 | model.forward = model.forward_dummy
118 | model.eval()
119 | traced_model = torch.jit.trace(
120 | model,
121 | example_inputs=imgs,
122 | check_trace=verify,
123 | )
124 |
125 | if show:
126 | print(traced_model.graph)
127 |
128 | traced_model.save(output_file)
129 | print('Successfully exported TorchScript model: {}'.format(output_file))
130 |
131 |
132 | def parse_args():
133 | parser = argparse.ArgumentParser(
134 | description='Convert MMSeg to TorchScript')
135 | parser.add_argument('config', help='test config file path')
136 | parser.add_argument('--checkpoint', help='checkpoint file', default=None)
137 | parser.add_argument(
138 | '--show', action='store_true', help='show TorchScript graph')
139 | parser.add_argument(
140 | '--verify', action='store_true', help='verify the TorchScript model')
141 | parser.add_argument('--output-file', type=str, default='tmp.pt')
142 | parser.add_argument(
143 | '--shape',
144 | type=int,
145 | nargs='+',
146 | default=[512, 512],
147 | help='input image size (height, width)')
148 | args = parser.parse_args()
149 | return args
150 |
151 |
152 | if __name__ == '__main__':
153 | args = parse_args()
154 | check_torch_version()
155 |
156 | if len(args.shape) == 1:
157 | input_shape = (1, 3, args.shape[0], args.shape[0])
158 | elif len(args.shape) == 2:
159 | input_shape = (
160 | 1,
161 | 3,
162 | ) + tuple(args.shape)
163 | else:
164 | raise ValueError('invalid input shape')
165 |
166 | cfg = mmcv.Config.fromfile(args.config)
167 | cfg.model.pretrained = None
168 |
169 | # build the model and load checkpoint
170 | cfg.model.train_cfg = None
171 | segmentor = build_segmentor(
172 | cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg'))
173 | # convert SyncBN to BN
174 | segmentor = _convert_batchnorm(segmentor)
175 |
176 | if args.checkpoint:
177 | load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
178 |
179 | # convert the PyTorch model to LibTorch model
180 | pytorch2libtorch(
181 | segmentor,
182 | input_shape,
183 | show=args.show,
184 | output_file=args.output_file,
185 | verify=args.verify)
186 |
--------------------------------------------------------------------------------
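A short sketch of consuming the exported TorchScript module; `tmp.pt` and the 512x512 input mirror the script's defaults, and the output type depends on the model's `forward_dummy`:

    import torch

    ts_model = torch.jit.load('tmp.pt').eval()
    dummy = torch.rand(1, 3, 512, 512)  # matches the default --shape 512 512
    with torch.no_grad():
        out = ts_model(dummy)  # typically the raw segmentation logits from forward_dummy
    print(type(out), getattr(out, 'shape', None))
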
/segmentation/tools/slurm_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | CHECKPOINT=$4
9 | GPUS=${GPUS:-4}
10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4}
11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
12 | PY_ARGS=${@:5}
13 | SRUN_ARGS=${SRUN_ARGS:-""}
14 |
15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
16 | srun -p ${PARTITION} \
17 | --job-name=${JOB_NAME} \
18 | --gres=gpu:${GPUS_PER_NODE} \
19 | --ntasks=${GPUS} \
20 | --ntasks-per-node=${GPUS_PER_NODE} \
21 | --cpus-per-task=${CPUS_PER_TASK} \
22 | --kill-on-bad-exit=1 \
23 | ${SRUN_ARGS} \
24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS}
25 |
--------------------------------------------------------------------------------
/segmentation/tools/slurm_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 |
5 | PARTITION=$1
6 | JOB_NAME=$2
7 | CONFIG=$3
8 | GPUS=${GPUS:-4}
9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4}
10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5}
11 | SRUN_ARGS=${SRUN_ARGS:-""}
12 | PY_ARGS=${@:4}
13 |
14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
15 | srun -p ${PARTITION} \
16 | --job-name=${JOB_NAME} \
17 | --gres=gpu:${GPUS_PER_NODE} \
18 | --ntasks=${GPUS} \
19 | --ntasks-per-node=${GPUS_PER_NODE} \
20 | --cpus-per-task=${CPUS_PER_TASK} \
21 | --kill-on-bad-exit=1 \
22 | ${SRUN_ARGS} \
23 | python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS}
24 |
--------------------------------------------------------------------------------
/segmentation/tools/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import os.path as osp
5 | import shutil
6 | import time
7 | import warnings
8 |
9 | import mmcv
10 | import torch
11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
12 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
13 | wrap_fp16_model)
14 | from mmcv.utils import DictAction
15 |
16 | from mmseg.apis import multi_gpu_test, single_gpu_test
17 | from mmseg.datasets import build_dataloader, build_dataset
18 | from mmseg.models import build_segmentor
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser(
23 | description='mmseg test (and eval) a model')
24 | parser.add_argument('config', help='test config file path')
25 | parser.add_argument('checkpoint', help='checkpoint file')
26 | parser.add_argument(
27 | '--work-dir',
28 | help=('if specified, the evaluation metric results will be dumped '
29 | 'into the directory as json'))
30 | parser.add_argument(
31 | '--aug-test', action='store_true', help='Use Flip and Multi scale aug')
32 | parser.add_argument('--out', help='output result file in pickle format')
33 | parser.add_argument(
34 | '--format-only',
35 | action='store_true',
36 | help='Format the output results without performing evaluation. It is '
37 | 'useful when you want to format the result to a specific format and '
38 | 'submit it to the test server')
39 | parser.add_argument(
40 | '--eval',
41 | type=str,
42 | nargs='+',
43 | help='evaluation metrics, which depends on the dataset, e.g., "mIoU"'
44 | ' for generic datasets, and "cityscapes" for Cityscapes')
45 | parser.add_argument('--show', action='store_true', help='show results')
46 | parser.add_argument(
47 | '--show-dir', help='directory where painted images will be saved')
48 | parser.add_argument(
49 | '--gpu-collect',
50 | action='store_true',
51 | help='whether to use gpu to collect results.')
52 | parser.add_argument(
53 | '--tmpdir',
54 | help='tmp directory used for collecting results from multiple '
55 | 'workers, available when gpu_collect is not specified')
56 | parser.add_argument(
57 | '--options', nargs='+', action=DictAction, help='custom options')
58 | parser.add_argument(
59 | '--eval-options',
60 | nargs='+',
61 | action=DictAction,
62 | help='custom options for evaluation')
63 | parser.add_argument(
64 | '--launcher',
65 | choices=['none', 'pytorch', 'slurm', 'mpi'],
66 | default='none',
67 | help='job launcher')
68 | parser.add_argument(
69 | '--opacity',
70 | type=float,
71 | default=0.5,
72 | help='Opacity of painted segmentation map. In (0, 1] range.')
73 | parser.add_argument('--local_rank', type=int, default=0)
74 | args = parser.parse_args()
75 | if 'LOCAL_RANK' not in os.environ:
76 | os.environ['LOCAL_RANK'] = str(args.local_rank)
77 | return args
78 |
79 |
80 | def main():
81 | args = parse_args()
82 |
83 | assert args.out or args.eval or args.format_only or args.show \
84 | or args.show_dir, \
85 | ('Please specify at least one operation (save/eval/format/show the '
86 | 'results / save the results) with the argument "--out", "--eval"'
87 | ', "--format-only", "--show" or "--show-dir"')
88 |
89 | if args.eval and args.format_only:
90 | raise ValueError('--eval and --format_only cannot be both specified')
91 |
92 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
93 | raise ValueError('The output file must be a pkl file.')
94 |
95 | cfg = mmcv.Config.fromfile(args.config)
96 | if args.options is not None:
97 | cfg.merge_from_dict(args.options)
98 | # set cudnn_benchmark
99 | if cfg.get('cudnn_benchmark', False):
100 | torch.backends.cudnn.benchmark = True
101 | if args.aug_test:
102 | # hard code index
103 | cfg.data.test.pipeline[1].img_ratios = [
104 | 0.5, 0.75, 1.0, 1.25, 1.5, 1.75
105 | ]
106 | cfg.data.test.pipeline[1].flip = True
107 | cfg.model.pretrained = None
108 | cfg.data.test.test_mode = True
109 |
110 | # init distributed env first, since logger depends on the dist info.
111 | if args.launcher == 'none':
112 | distributed = False
113 | else:
114 | distributed = True
115 | init_dist(args.launcher, **cfg.dist_params)
116 |
117 | rank, _ = get_dist_info()
118 | # only create the work_dir (and the eval json below) on rank 0 when requested
119 | if args.work_dir is not None and rank == 0:
120 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir))
121 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
122 | json_file = osp.join(args.work_dir, f'eval_{timestamp}.json')
123 |
124 | # build the dataloader
125 | # TODO: support multiple images per gpu (only minor changes are needed)
126 | dataset = build_dataset(cfg.data.test)
127 | data_loader = build_dataloader(
128 | dataset,
129 | samples_per_gpu=1,
130 | workers_per_gpu=cfg.data.workers_per_gpu,
131 | dist=distributed,
132 | shuffle=False)
133 |
134 | # build the model and load checkpoint
135 | cfg.model.train_cfg = None
136 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg'))
137 | fp16_cfg = cfg.get('fp16', None)
138 | if fp16_cfg is not None:
139 | wrap_fp16_model(model)
140 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
141 | if 'CLASSES' in checkpoint.get('meta', {}):
142 | model.CLASSES = checkpoint['meta']['CLASSES']
143 | else:
144 | print('"CLASSES" not found in meta, use dataset.CLASSES instead')
145 | model.CLASSES = dataset.CLASSES
146 | if 'PALETTE' in checkpoint.get('meta', {}):
147 | model.PALETTE = checkpoint['meta']['PALETTE']
148 | else:
149 | print('"PALETTE" not found in meta, use dataset.PALETTE instead')
150 | model.PALETTE = dataset.PALETTE
151 |
152 | # clean gpu memory when starting a new evaluation.
153 | torch.cuda.empty_cache()
154 | eval_kwargs = {} if args.eval_options is None else args.eval_options
155 |
156 | # Deprecated
157 | efficient_test = eval_kwargs.get('efficient_test', False)
158 | if efficient_test:
159 | warnings.warn(
160 | '``efficient_test=True`` does not have effect in tools/test.py, '
161 | 'the evaluation and format results are CPU memory efficient by '
162 | 'default')
163 |
164 | eval_on_format_results = (
165 | args.eval is not None and 'cityscapes' in args.eval)
166 | if eval_on_format_results:
167 | assert len(args.eval) == 1, 'eval on format results is not ' \
168 | 'applicable for metrics other than ' \
169 | 'cityscapes'
170 | if args.format_only or eval_on_format_results:
171 | if 'imgfile_prefix' in eval_kwargs:
172 | tmpdir = eval_kwargs['imgfile_prefix']
173 | else:
174 | tmpdir = '.format_cityscapes'
175 | eval_kwargs.setdefault('imgfile_prefix', tmpdir)
176 | mmcv.mkdir_or_exist(tmpdir)
177 | else:
178 | tmpdir = None
179 |
180 | if not distributed:
181 | model = MMDataParallel(model, device_ids=[0])
182 | results = single_gpu_test(
183 | model,
184 | data_loader,
185 | args.show,
186 | args.show_dir,
187 | False,
188 | args.opacity,
189 | pre_eval=args.eval is not None and not eval_on_format_results,
190 | format_only=args.format_only or eval_on_format_results,
191 | format_args=eval_kwargs)
192 | else:
193 | model = MMDistributedDataParallel(
194 | model.cuda(),
195 | device_ids=[torch.cuda.current_device()],
196 | broadcast_buffers=False)
197 | results = multi_gpu_test(
198 | model,
199 | data_loader,
200 | args.tmpdir,
201 | args.gpu_collect,
202 | False,
203 | pre_eval=args.eval is not None and not eval_on_format_results,
204 | format_only=args.format_only or eval_on_format_results,
205 | format_args=eval_kwargs)
206 |
207 | rank, _ = get_dist_info()
208 | if rank == 0:
209 | if args.out:
210 | warnings.warn(
211 | 'The behavior of ``args.out`` has been changed since MMSeg '
212 | 'v0.16, the pickled outputs could be seg map as type of '
213 | 'np.array, pre-eval results or file paths for '
214 | '``dataset.format_results()``.')
215 | print(f'\nwriting results to {args.out}')
216 | mmcv.dump(results, args.out)
217 | if args.eval:
218 | eval_kwargs.update(metric=args.eval)
219 | metric = dataset.evaluate(results, **eval_kwargs)
220 | metric_dict = dict(config=args.config, metric=metric)
221 | if args.work_dir is not None and rank == 0:
222 | mmcv.dump(metric_dict, json_file, indent=4)
223 | if tmpdir is not None and eval_on_format_results:
224 | # remove the tmp dir after cityscapes evaluation
225 | shutil.rmtree(tmpdir)
226 |
227 |
228 | if __name__ == '__main__':
229 | main()
230 |
--------------------------------------------------------------------------------
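Results dumped with `--out results.pkl` can be re-scored offline; a hedged sketch that assumes the pickle holds pre-eval results for the same test split:

    import mmcv
    from mmseg.datasets import build_dataset

    cfg = mmcv.Config.fromfile('configs/semfpn_vit-s16_512x512_40k_ade20k.py')
    cfg.data.test.test_mode = True
    dataset = build_dataset(cfg.data.test)
    results = mmcv.load('results.pkl')  # produced by `tools/test.py ... --out results.pkl`
    print(dataset.evaluate(results, metric='mIoU'))
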
/segmentation/tools/torchserve/mmseg2torchserve.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from argparse import ArgumentParser, Namespace
3 | from pathlib import Path
4 | from tempfile import TemporaryDirectory
5 |
6 | import mmcv
7 |
8 | try:
9 | from model_archiver.model_packaging import package_model
10 | from model_archiver.model_packaging_utils import ModelExportUtils
11 | except ImportError:
12 | package_model = None
13 |
14 |
15 | def mmseg2torchserve(
16 | config_file: str,
17 | checkpoint_file: str,
18 | output_folder: str,
19 | model_name: str,
20 | model_version: str = '1.0',
21 | force: bool = False,
22 | ):
23 | """Converts mmsegmentation model (config + checkpoint) to TorchServe
24 | `.mar`.
25 |
26 | Args:
27 | config_file:
28 | In MMSegmentation config format.
29 | The contents vary for each task repository.
30 | checkpoint_file:
31 | In MMSegmentation checkpoint format.
32 | The contents vary for each task repository.
33 | output_folder:
34 | Folder where `{model_name}.mar` will be created.
35 | The file created will be in TorchServe archive format.
36 | model_name:
37 | If not None, used for naming the `{model_name}.mar` file
38 | that will be created under `output_folder`.
39 | If None, `{Path(checkpoint_file).stem}` will be used.
40 | model_version:
41 | Model's version.
42 | force:
43 | If True, if there is an existing `{model_name}.mar`
44 | file under `output_folder` it will be overwritten.
45 | """
46 | mmcv.mkdir_or_exist(output_folder)
47 |
48 | config = mmcv.Config.fromfile(config_file)
49 |
50 | with TemporaryDirectory() as tmpdir:
51 | config.dump(f'{tmpdir}/config.py')
52 |
53 | args = Namespace(
54 | **{
55 | 'model_file': f'{tmpdir}/config.py',
56 | 'serialized_file': checkpoint_file,
57 | 'handler': f'{Path(__file__).parent}/mmseg_handler.py',
58 | 'model_name': model_name or Path(checkpoint_file).stem,
59 | 'version': model_version,
60 | 'export_path': output_folder,
61 | 'force': force,
62 | 'requirements_file': None,
63 | 'extra_files': None,
64 | 'runtime': 'python',
65 | 'archive_format': 'default'
66 | })
67 | manifest = ModelExportUtils.generate_manifest_json(args)
68 | package_model(args, manifest)
69 |
70 |
71 | def parse_args():
72 | parser = ArgumentParser(
73 | description='Convert mmseg models to TorchServe `.mar` format.')
74 | parser.add_argument('config', type=str, help='config file path')
75 | parser.add_argument('checkpoint', type=str, help='checkpoint file path')
76 | parser.add_argument(
77 | '--output-folder',
78 | type=str,
79 | required=True,
80 | help='Folder where `{model_name}.mar` will be created.')
81 | parser.add_argument(
82 | '--model-name',
83 | type=str,
84 | default=None,
85 | help='If not None, used for naming the `{model_name}.mar` '
86 | 'file that will be created under `output_folder`. '
87 | 'If None, `{Path(checkpoint_file).stem}` will be used.')
88 | parser.add_argument(
89 | '--model-version',
90 | type=str,
91 | default='1.0',
92 | help='Number used for versioning.')
93 | parser.add_argument(
94 | '-f',
95 | '--force',
96 | action='store_true',
97 | help='overwrite the existing `{model_name}.mar`')
98 | args = parser.parse_args()
99 |
100 | return args
101 |
102 |
103 | if __name__ == '__main__':
104 | args = parse_args()
105 |
106 | if package_model is None:
107 | raise ImportError('`torch-model-archiver` is required. '
108 | 'Try: pip install torch-model-archiver')
109 |
110 | mmseg2torchserve(args.config, args.checkpoint, args.output_folder,
111 | args.model_name, args.model_version, args.force)
112 |
--------------------------------------------------------------------------------
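A hedged example of calling the packaging function directly; the checkpoint path, output folder, and model name are placeholders:

    from mmseg2torchserve import mmseg2torchserve  # hypothetical import; run from tools/torchserve

    mmseg2torchserve(
        config_file='configs/semfpn_vit-s16_512x512_40k_ade20k.py',
        checkpoint_file='work_dirs/semfpn_vit-s16_512x512_40k_ade20k/latest.pth',
        output_folder='model_store',  # point `torchserve --model-store` at this folder
        model_name='selfpatch_semfpn',
        force=True)
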
/segmentation/tools/torchserve/mmseg_handler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import base64
3 | import os
4 |
5 | import cv2
6 | import mmcv
7 | import torch
8 | from mmcv.cnn.utils.sync_bn import revert_sync_batchnorm
9 | from ts.torch_handler.base_handler import BaseHandler
10 |
11 | from mmseg.apis import inference_segmentor, init_segmentor
12 |
13 |
14 | class MMsegHandler(BaseHandler):
15 |
16 | def initialize(self, context):
17 | properties = context.system_properties
18 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
19 | self.device = torch.device(self.map_location + ':' +
20 | str(properties.get('gpu_id')) if torch.cuda.
21 | is_available() else self.map_location)
22 | self.manifest = context.manifest
23 |
24 | model_dir = properties.get('model_dir')
25 | serialized_file = self.manifest['model']['serializedFile']
26 | checkpoint = os.path.join(model_dir, serialized_file)
27 | self.config_file = os.path.join(model_dir, 'config.py')
28 |
29 | self.model = init_segmentor(self.config_file, checkpoint, self.device)
30 | self.model = revert_sync_batchnorm(self.model)
31 | self.initialized = True
32 |
33 | def preprocess(self, data):
34 | images = []
35 |
36 | for row in data:
37 | image = row.get('data') or row.get('body')
38 | if isinstance(image, str):
39 | image = base64.b64decode(image)
40 | image = mmcv.imfrombytes(image)
41 | images.append(image)
42 |
43 | return images
44 |
45 | def inference(self, data, *args, **kwargs):
46 | results = [inference_segmentor(self.model, img) for img in data]
47 | return results
48 |
49 | def postprocess(self, data):
50 | output = []
51 |
52 | for image_result in data:
53 | _, buffer = cv2.imencode('.png', image_result[0].astype('uint8'))
54 | base64_data = base64.b64encode(buffer.tobytes())
55 | base64_str = str(base64_data, 'utf-8')
56 | output.append(base64_str)
57 | return output
58 |
--------------------------------------------------------------------------------
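The handler returns each segmentation map as a base64-encoded PNG string; a client-side sketch (not part of the repo) for decoding one element of that response back into an HxW label map:

    import base64

    import cv2
    import numpy as np

    def decode_seg_response(b64_str):
        # inverse of MMsegHandler.postprocess: base64 PNG -> HxW uint8 label map
        png_bytes = base64.b64decode(b64_str)
        return cv2.imdecode(np.frombuffer(png_bytes, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
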
/segmentation/tools/torchserve/test_torchserve.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from argparse import ArgumentParser
3 | from io import BytesIO
4 |
5 | import matplotlib.pyplot as plt
6 | import mmcv
7 | import requests
8 |
9 | from mmseg.apis import inference_segmentor, init_segmentor
10 |
11 |
12 | def parse_args():
13 | parser = ArgumentParser(
14 | description='Compare the results of TorchServe and PyTorch, '
15 | 'and visualize them.')
16 | parser.add_argument('img', help='Image file')
17 | parser.add_argument('config', help='Config file')
18 | parser.add_argument('checkpoint', help='Checkpoint file')
19 | parser.add_argument('model_name', help='The model name in the server')
20 | parser.add_argument(
21 | '--inference-addr',
22 | default='127.0.0.1:8080',
23 | help='Address and port of the inference server')
24 | parser.add_argument(
25 | '--result-image',
26 | type=str,
27 | default=None,
28 | help='save server output in result-image')
29 | parser.add_argument(
30 | '--device', default='cuda:0', help='Device used for inference')
31 |
32 | args = parser.parse_args()
33 | return args
34 |
35 |
36 | def main(args):
37 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name
38 | with open(args.img, 'rb') as image:
39 | tmp_res = requests.post(url, image)
40 | base64_str = tmp_res.content
41 | buffer = base64.b64decode(base64_str)
42 | if args.result_image:
43 | with open(args.result_image, 'wb') as out_image:
44 | out_image.write(buffer)
45 | plt.imshow(mmcv.imread(args.result_image, 'grayscale'))
46 | plt.show()
47 | else:
48 | plt.imshow(plt.imread(BytesIO(buffer)))
49 | plt.show()
50 | model = init_segmentor(args.config, args.checkpoint, args.device)
51 | image = mmcv.imread(args.img)
52 | result = inference_segmentor(model, image)
53 | plt.imshow(result[0])
54 | plt.show()
55 |
56 |
57 | if __name__ == '__main__':
58 | args = parse_args()
59 | main(args)
60 |
--------------------------------------------------------------------------------
/segmentation/tools/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import copy
4 | import os
5 | import os.path as osp
6 | import time
7 | import warnings
8 |
9 | import mmcv
10 | import torch
11 | from mmcv.cnn.utils import revert_sync_batchnorm
12 | from mmcv.runner import get_dist_info, init_dist
13 | from mmcv.utils import Config, DictAction, get_git_hash
14 |
15 | from mmseg import __version__
16 | from mmseg.apis import set_random_seed, train_segmentor
17 | from mmseg.datasets import build_dataset
18 | from mmseg.models import build_segmentor
19 | from mmseg.utils import collect_env, get_root_logger
20 | from backbones import vit_SelfPatch  # register the SelfPatch ViT backbone with mmseg
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description='Train a segmentor')
24 | parser.add_argument('config', help='train config file path')
25 | parser.add_argument('--work-dir', help='the dir to save logs and models')
26 | parser.add_argument(
27 | '--load-from', help='the checkpoint file to load weights from')
28 | parser.add_argument(
29 | '--resume-from', help='the checkpoint file to resume from')
30 | parser.add_argument(
31 | '--no-validate',
32 | action='store_true',
33 | help='whether not to evaluate the checkpoint during training')
34 | group_gpus = parser.add_mutually_exclusive_group()
35 | group_gpus.add_argument(
36 | '--gpus',
37 | type=int,
38 | help='number of gpus to use '
39 | '(only applicable to non-distributed training)')
40 | group_gpus.add_argument(
41 | '--gpu-ids',
42 | type=int,
43 | nargs='+',
44 | help='ids of gpus to use '
45 | '(only applicable to non-distributed training)')
46 | parser.add_argument('--seed', type=int, default=None, help='random seed')
47 | parser.add_argument(
48 | '--deterministic',
49 | action='store_true',
50 | help='whether to set deterministic options for CUDNN backend.')
51 | parser.add_argument(
52 | '--options', nargs='+', action=DictAction, help='custom options')
53 | parser.add_argument(
54 | '--launcher',
55 | choices=['none', 'pytorch', 'slurm', 'mpi'],
56 | default='none',
57 | help='job launcher')
58 | parser.add_argument('--local_rank', type=int, default=0)
59 | args = parser.parse_args()
60 | if 'LOCAL_RANK' not in os.environ:
61 | os.environ['LOCAL_RANK'] = str(args.local_rank)
62 |
63 | return args
64 |
65 |
66 | def main():
67 | args = parse_args()
68 |
69 | cfg = Config.fromfile(args.config)
70 | if args.options is not None:
71 | cfg.merge_from_dict(args.options)
72 | # set cudnn_benchmark
73 | if cfg.get('cudnn_benchmark', False):
74 | torch.backends.cudnn.benchmark = True
75 |
76 | # work_dir is determined in this priority: CLI > segment in file > filename
77 | if args.work_dir is not None:
78 | # update configs according to CLI args if args.work_dir is not None
79 | cfg.work_dir = args.work_dir
80 | elif cfg.get('work_dir', None) is None:
81 | # use config filename as default work_dir if cfg.work_dir is None
82 | cfg.work_dir = osp.join('./work_dirs',
83 | osp.splitext(osp.basename(args.config))[0])
84 | if args.load_from is not None:
85 | cfg.load_from = args.load_from
86 | if args.resume_from is not None:
87 | cfg.resume_from = args.resume_from
88 | if args.gpu_ids is not None:
89 | cfg.gpu_ids = args.gpu_ids
90 | else:
91 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
92 |
93 | # init distributed env first, since logger depends on the dist info.
94 | if args.launcher == 'none':
95 | distributed = False
96 | else:
97 | distributed = True
98 | init_dist(args.launcher, **cfg.dist_params)
99 | # gpu_ids is used to calculate the iteration number when resuming a checkpoint
100 | _, world_size = get_dist_info()
101 | cfg.gpu_ids = range(world_size)
102 |
103 | # create work_dir
104 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
105 | # dump config
106 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
107 | # init the logger before other steps
108 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
109 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
110 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
111 |
112 | # init the meta dict to record some important information such as
113 | # environment info and seed, which will be logged
114 | meta = dict()
115 | # log env info
116 | env_info_dict = collect_env()
117 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()])
118 | dash_line = '-' * 60 + '\n'
119 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
120 | dash_line)
121 | meta['env_info'] = env_info
122 |
123 | # log some basic info
124 | logger.info(f'Distributed training: {distributed}')
125 | logger.info(f'Config:\n{cfg.pretty_text}')
126 |
127 | # set random seeds
128 | if args.seed is not None:
129 | logger.info(f'Set random seed to {args.seed}, deterministic: '
130 | f'{args.deterministic}')
131 | set_random_seed(args.seed, deterministic=args.deterministic)
132 | cfg.seed = args.seed
133 | meta['seed'] = args.seed
134 | meta['exp_name'] = osp.basename(args.config)
135 |
136 | model = build_segmentor(
137 | cfg.model,
138 | train_cfg=cfg.get('train_cfg'),
139 | test_cfg=cfg.get('test_cfg'))
140 | model.init_weights()
141 |
142 | # SyncBN is not supported for DP
143 | if not distributed:
144 | warnings.warn(
145 | 'SyncBN is only supported with DDP. To be compatible with DP, '
146 | 'we convert SyncBN to BN. Please use dist_train.sh which can '
147 | 'avoid this error.')
148 | model = revert_sync_batchnorm(model)
149 |
150 | logger.info(model)
151 |
152 | datasets = [build_dataset(cfg.data.train)]
153 | if len(cfg.workflow) == 2:
154 | val_dataset = copy.deepcopy(cfg.data.val)
155 | val_dataset.pipeline = cfg.data.train.pipeline
156 | datasets.append(build_dataset(val_dataset))
157 | if cfg.checkpoint_config is not None:
158 | # save mmseg version, config file content and class names in
159 | # checkpoints as meta data
160 | cfg.checkpoint_config.meta = dict(
161 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
162 | config=cfg.pretty_text,
163 | CLASSES=datasets[0].CLASSES,
164 | PALETTE=datasets[0].PALETTE)
165 | # add an attribute for visualization convenience
166 | model.CLASSES = datasets[0].CLASSES
167 | # passing checkpoint meta for saving best checkpoint
168 | meta.update(cfg.checkpoint_config.meta)
169 | train_segmentor(
170 | model,
171 | datasets,
172 | cfg,
173 | distributed=distributed,
174 | validate=(not args.no_validate),
175 | timestamp=timestamp,
176 | meta=meta)
177 |
178 |
179 | if __name__ == '__main__':
180 | main()
181 |
--------------------------------------------------------------------------------