├── LICENSE ├── README.md ├── detection ├── README.md ├── backbone │ └── vit_SelfPatch.py ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ ├── coco_detection.py │ │ │ ├── coco_instance.py │ │ │ └── coco_instance_semantic.py │ │ ├── default_runtime.py │ │ ├── models │ │ │ └── mask_rcnn_selfpatch_p16.py │ │ └── schedules │ │ │ ├── schedule_1x.py │ │ │ ├── schedule_20e.py │ │ │ └── schedule_2x.py │ └── selfpatch │ │ └── mask_rcnn_vit_small_12_p16_1x_coco.py └── tools │ ├── dist_test.sh │ ├── dist_train.sh │ ├── slurm_test.sh │ ├── slurm_train.sh │ ├── test.py │ └── train.py ├── eval_video_segmentation.py ├── main_selfpatch.py ├── segmentation ├── README.md ├── backbones │ ├── __init__.py │ ├── bisenetv1.py │ ├── bisenetv2.py │ ├── cgnet.py │ ├── fast_scnn.py │ ├── hrnet.py │ ├── mit.py │ ├── mobilenet_v2.py │ ├── mobilenet_v3.py │ ├── resnest.py │ ├── resnet.py │ ├── resnext.py │ ├── swin.py │ ├── unet.py │ └── vit_SelfPatch.py ├── configs │ ├── _base_ │ │ ├── datasets │ │ │ └── ade20k.py │ │ ├── default_runtime.py │ │ └── schedules │ │ │ └── schedule_40k.py │ └── semfpn_vit-s16_512x512_40k_ade20k.py └── tools │ ├── analyze_logs.py │ ├── benchmark.py │ ├── browse_dataset.py │ ├── convert_datasets │ ├── chase_db1.py │ ├── cityscapes.py │ ├── coco_stuff10k.py │ ├── coco_stuff164k.py │ ├── drive.py │ ├── hrf.py │ ├── pascal_context.py │ ├── stare.py │ └── voc_aug.py │ ├── deploy_test.py │ ├── dist_test.sh │ ├── dist_train.sh │ ├── get_flops.py │ ├── model_converters │ ├── mit2mmseg.py │ ├── swin2mmseg.py │ └── vit2mmseg.py │ ├── onnx2tensorrt.py │ ├── print_config.py │ ├── publish_model.py │ ├── pytorch2onnx.py │ ├── pytorch2torchscript.py │ ├── slurm_test.sh │ ├── slurm_train.sh │ ├── test.py │ ├── torchserve │ ├── mmseg2torchserve.py │ ├── mmseg_handler.py │ └── test_torchserve.py │ └── train.py ├── selfpatch_vision_transformer.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Patch-level Representation Learning for Self-supervised Vision Transformers (SelfPatch) 2 | 3 | PyTorch implementation for "Patch-level Representation Learning for Self-supervised Vision Transformers" (accepted Oral presentation in CVPR 2022) 4 | 5 |

6 | [figure: thumbnail] 7 |

8 | 9 | ## Requirements 10 | - `torch==1.7.0` 11 | - `torchvision==0.8.1` 12 | 13 | ## Pretraining on ImageNet 14 | ``` 15 | python -m torch.distributed.launch --nproc_per_node=8 main_selfpatch.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir --local_crops_number 8 --patch_size 16 --batch_size_per_gpu 128 --out_dim_selfpatch 4096 --k_num 4 16 | ``` 17 | 18 | ## Pretrained weights on ImageNet 19 | You can download the weights of the models pretrained on ImageNet. All models use the `ViT-S/16` architecture. For the detection and segmentation downstream tasks, please check `SelfPatch/detection` and `SelfPatch/segmentation`. A minimal loading sketch for these checkpoints is shown after the qualitative examples below. 20 | 21 | | backbone | arch | checkpoint | 22 | | ------------- | ------------- | ------------- | 23 | | DINO | ViT-S/16 | download (pretrained model from VISSL) | 24 | | DINO + SelfPatch | ViT-S/16 | download | 25 | 26 | ## Evaluating video object segmentation on the DAVIS 2017 dataset 27 | Step 1. Prepare DAVIS 2017 data 28 | 29 | ``` 30 | cd $HOME 31 | git clone https://github.com/davisvideochallenge/davis-2017 32 | cd davis-2017 33 | ./data/get_davis.sh 34 | ``` 35 | 36 | Step 2. Run video object segmentation 37 | 38 | ``` 39 | python eval_video_segmentation.py --data_path /path/to/davis-2017/DAVIS/ --output_dir /path/to/saving_dir --pretrained_weights /path/to/model_dir --arch vit_small --patch_size 16 40 | ``` 41 | 42 | Step 3. Evaluate the obtained segmentation 43 | 44 | ``` 45 | git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davis2017-evaluation 46 | cd $HOME/davis2017-evaluation 47 | python /path/to/davis2017-evaluation/evaluation_method.py --task semi-supervised --davis_path /path/to/davis-2017/DAVIS --results_path /path/to/saving_dir 48 | ``` 49 | 50 | ### Video object segmentation examples on the DAVIS 2017 dataset 51 | 52 | Video (left), DINO (middle) and our SelfPatch (right) 53 |

54 | [figure: video object segmentation examples]

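For reference, below is a minimal sketch of how the released ViT-S/16 checkpoints could be loaded for feature extraction. It assumes a DINO-style checkpoint format and a DINO-style `vit_small` factory in `selfpatch_vision_transformer.py`; adjust the key handling to the actual checkpoint you download.

```python
# Minimal loading sketch. Assumptions (not guaranteed by this repository): the
# released checkpoint is a DINO-style state dict, and selfpatch_vision_transformer.py
# exposes a vit_small() factory as in the DINO codebase this project builds on.
import torch
import selfpatch_vision_transformer as vits

model = vits.vit_small(patch_size=16)
ckpt = torch.load('/path/to/checkpoint.pth', map_location='cpu')
# Full training checkpoints usually nest the backbone under 'teacher'/'student'.
if isinstance(ckpt, dict) and 'teacher' in ckpt:
    ckpt = ckpt['teacher']
state_dict = {k.replace('module.', '').replace('backbone.', ''): v
              for k, v in ckpt.items()}
msg = model.load_state_dict(state_dict, strict=False)
print(msg)  # inspect missing / unexpected keys

model.eval()
with torch.no_grad():
    cls_feat = model(torch.randn(1, 3, 224, 224))  # [CLS] feature, dim 384 for ViT-S/16
```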
56 | 57 | 58 | ## Acknowledgement 59 | Our code base is built partly upon the packages: 60 | DINO, mmdetection, mmsegmentation and XCiT 61 | 62 | ## Citation 63 | If you use this code for your research, please cite our papers. 64 | ``` 65 | @InProceedings{Yun_2022_CVPR, 66 | author = {Yun, Sukmin and Lee, Hankook and Kim, Jaehyung and Shin, Jinwoo}, 67 | title = {Patch-Level Representation Learning for Self-Supervised Vision Transformers}, 68 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 69 | month = {June}, 70 | year = {2022}, 71 | pages = {8354-8363} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /detection/README.md: -------------------------------------------------------------------------------- 1 | ## Evaluating object detection and instance segmentation on the COCO dataset 2 | Step 1. Prepare COCO dataset 3 | 4 | The dataset can be downloaded at `https://cocodataset.org/#download` 5 | 6 | Step 2. Install mmdetection 7 | 8 | ``` 9 | git clone https://github.com/open-mmlab/mmdetection.git 10 | ``` 11 | 12 | Step 3. Fine-tune on the COCO dataset 13 | 14 | ``` 15 | tools/dist_train.sh configs/selfpatch/mask_rcnn_vit_small_12_p16_1x_coco.py [number of gpu] --work-dir /path/to/saving_dir --seed 0 --deterministic --options model.pretrained=/path/to/model_dir 16 | ``` 17 | 18 | ## Pretrained weights on MS-COCO 19 | You can download the weights of the fine-tuned models on object detection and instance segmentation tasks. All models are fine-tuned with `Mask R-CNN`. 20 | 21 | | backbone | arch | bbox mAP | mask mAP | checkpoint | 22 | | ------------- | ------------- | ------------- | ------------- | ------------- | 23 | | DINO | ViT-S/16 + Mask R-CNN | 40.8 | 37.3 | download | 24 | | DINO + SelfPatch | ViT-S/16 + Mask R-CNN | 42.1 | 38.5 | download | 25 | 26 | ## Acknowledgement 27 | This code is built using the mmdetection libray. 28 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_detection.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | dataset_type = 'CocoDataset' 4 | data_root = 'data/coco/' 5 | img_norm_cfg = dict( 6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', with_bbox=True), 10 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 11 | dict(type='RandomFlip', flip_ratio=0.5), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size_divisor=32), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(1333, 800), 22 | flip=False, 23 | transforms=[ 24 | dict(type='Resize', keep_ratio=True), 25 | dict(type='RandomFlip'), 26 | dict(type='Normalize', **img_norm_cfg), 27 | dict(type='Pad', size_divisor=32), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | samples_per_gpu=2, 34 | workers_per_gpu=2, 35 | train=dict( 36 | type=dataset_type, 37 | ann_file=data_root + 'annotations/instances_train2017.json', 38 | img_prefix=data_root + 'train2017/', 39 | pipeline=train_pipeline), 40 | val=dict( 41 | type=dataset_type, 42 | ann_file=data_root + 'annotations/instances_val2017.json', 43 | img_prefix=data_root + 'val2017/', 44 | pipeline=test_pipeline), 45 | test=dict( 46 | type=dataset_type, 47 | ann_file=data_root + 'annotations/instances_val2017.json', 48 | img_prefix=data_root + 'val2017/', 49 | pipeline=test_pipeline)) 50 | evaluation = dict(interval=1, metric='bbox') 51 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_instance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | dataset_type = 'CocoDataset' 4 | data_root = '/data/COCO/' 5 | img_norm_cfg = dict( 6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 10 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 11 | dict(type='RandomFlip', flip_ratio=0.5), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='Pad', size_divisor=32), 14 | dict(type='DefaultFormatBundle'), 15 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 16 | ] 17 | test_pipeline = [ 18 | dict(type='LoadImageFromFile'), 19 | dict( 20 | type='MultiScaleFlipAug', 21 | img_scale=(1333, 800), 22 | flip=False, 23 | transforms=[ 24 | dict(type='Resize', keep_ratio=True), 25 | dict(type='RandomFlip'), 26 | dict(type='Normalize', **img_norm_cfg), 27 | dict(type='Pad', size_divisor=32), 28 | dict(type='ImageToTensor', keys=['img']), 29 | dict(type='Collect', keys=['img']), 30 | ]) 31 | ] 32 | data = dict( 33 | samples_per_gpu=2, 34 | workers_per_gpu=2, 35 | train=dict( 36 | type=dataset_type, 37 | ann_file=data_root + 'annotations/instances_train2017.json', 38 | img_prefix=data_root + 'train2017/', 39 | pipeline=train_pipeline), 40 | val=dict( 41 | type=dataset_type, 42 | ann_file=data_root + 'annotations/instances_val2017.json', 43 | img_prefix=data_root + 'val2017/', 44 | pipeline=test_pipeline), 45 | test=dict( 46 | type=dataset_type, 47 | ann_file=data_root + 'annotations/instances_val2017.json', 48 | img_prefix=data_root + 'val2017/', 49 | pipeline=test_pipeline)) 50 | evaluation = dict(metric=['bbox', 'segm']) 51 | -------------------------------------------------------------------------------- /detection/configs/_base_/datasets/coco_instance_semantic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | dataset_type = 'CocoDataset' 4 | data_root = 'data/coco/' 5 | img_norm_cfg = dict( 6 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict( 10 | type='LoadAnnotations', with_bbox=True, with_mask=True, with_seg=True), 11 | dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), 12 | dict(type='RandomFlip', flip_ratio=0.5), 13 | dict(type='Normalize', **img_norm_cfg), 14 | dict(type='Pad', size_divisor=32), 15 | dict(type='SegRescale', scale_factor=1 / 8), 16 | dict(type='DefaultFormatBundle'), 17 | dict( 18 | type='Collect', 19 | keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg']), 20 | ] 21 | test_pipeline = [ 22 | dict(type='LoadImageFromFile'), 23 | dict( 24 | type='MultiScaleFlipAug', 25 | img_scale=(1333, 800), 26 | flip=False, 27 | transforms=[ 28 | dict(type='Resize', keep_ratio=True), 29 | dict(type='RandomFlip', flip_ratio=0.5), 30 | dict(type='Normalize', **img_norm_cfg), 31 | dict(type='Pad', size_divisor=32), 32 | dict(type='ImageToTensor', keys=['img']), 33 | dict(type='Collect', keys=['img']), 34 | ]) 35 | ] 36 | data = dict( 37 | samples_per_gpu=2, 38 | workers_per_gpu=2, 39 | train=dict( 40 | type=dataset_type, 41 | ann_file=data_root + 'annotations/instances_train2017.json', 42 | img_prefix=data_root + 'train2017/', 43 | seg_prefix=data_root + 'stuffthingmaps/train2017/', 44 | pipeline=train_pipeline), 45 | val=dict( 46 | type=dataset_type, 47 | ann_file=data_root + 'annotations/instances_val2017.json', 48 | img_prefix=data_root + 'val2017/', 49 | pipeline=test_pipeline), 50 | test=dict( 51 | type=dataset_type, 52 | ann_file=data_root + 'annotations/instances_val2017.json', 53 | img_prefix=data_root + 'val2017/', 54 | pipeline=test_pipeline)) 55 | evaluation = dict(metric=['bbox', 'segm']) 56 | -------------------------------------------------------------------------------- /detection/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | checkpoint_config = dict(interval=1) 4 | # yapf:disable 5 | log_config = dict( 6 | interval=50, 7 | hooks=[ 8 | dict(type='TextLoggerHook'), 9 | # dict(type='TensorboardLoggerHook') 10 | ]) 11 | # yapf:enable 12 | custom_hooks = [dict(type='NumClassCheckHook')] 13 | 14 | dist_params = dict(backend='nccl') 15 | log_level = 'INFO' 16 | load_from = None 17 | resume_from = None 18 | workflow = [('train', 1)] 19 | -------------------------------------------------------------------------------- /detection/configs/_base_/models/mask_rcnn_selfpatch_p16.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
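# Base model config: Mask R-CNN with a SelfPatch ViT-S/16 backbone (patch size 16,
# embed dim 384, 12 blocks, 8 heads). Four 384-channel feature maps from the
# transformer feed an FPN neck (256 channels, 5 outputs), followed by standard
# RPN and RoI heads for the 80 COCO classes.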
3 | # model settings 4 | model = dict( 5 | type='MaskRCNN', 6 | pretrained=None, 7 | backbone=dict( 8 | type='SelfPatch_ViT', 9 | patch_size=16, 10 | embed_dim=384, 11 | depth=12, 12 | num_heads=8, 13 | mlp_ratio=4, 14 | qkv_bias=True, 15 | ), 16 | neck=dict( 17 | type='FPN', 18 | in_channels=[384, 384, 384, 384], 19 | out_channels=256, 20 | num_outs=5), 21 | rpn_head=dict( 22 | type='RPNHead', 23 | in_channels=256, 24 | feat_channels=256, 25 | anchor_generator=dict( 26 | type='AnchorGenerator', 27 | scales=[8], 28 | ratios=[0.5, 1.0, 2.0], 29 | strides=[4, 8, 16, 32, 64]), 30 | bbox_coder=dict( 31 | type='DeltaXYWHBBoxCoder', 32 | target_means=[.0, .0, .0, .0], 33 | target_stds=[1.0, 1.0, 1.0, 1.0]), 34 | loss_cls=dict( 35 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 36 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 37 | roi_head=dict( 38 | type='StandardRoIHead', 39 | bbox_roi_extractor=dict( 40 | type='SingleRoIExtractor', 41 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 42 | out_channels=256, 43 | featmap_strides=[16, 16, 16, 16]), 44 | bbox_head=dict( 45 | type='Shared2FCBBoxHead', 46 | in_channels=256, 47 | fc_out_channels=1024, 48 | roi_feat_size=7, 49 | num_classes=80, 50 | bbox_coder=dict( 51 | type='DeltaXYWHBBoxCoder', 52 | target_means=[0., 0., 0., 0.], 53 | target_stds=[0.1, 0.1, 0.2, 0.2]), 54 | reg_class_agnostic=False, 55 | loss_cls=dict( 56 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 57 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 58 | mask_roi_extractor=dict( 59 | type='SingleRoIExtractor', 60 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 61 | out_channels=256, 62 | featmap_strides=[16, 16, 16, 16]), 63 | mask_head=dict( 64 | type='FCNMaskHead', 65 | num_convs=4, 66 | in_channels=256, 67 | conv_out_channels=256, 68 | num_classes=80, 69 | loss_mask=dict( 70 | type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))), 71 | # model training and testing settings 72 | train_cfg=dict( 73 | rpn=dict( 74 | assigner=dict( 75 | type='MaxIoUAssigner', 76 | pos_iou_thr=0.7, 77 | neg_iou_thr=0.3, 78 | min_pos_iou=0.3, 79 | match_low_quality=True, 80 | ignore_iof_thr=-1), 81 | sampler=dict( 82 | type='RandomSampler', 83 | num=256, 84 | pos_fraction=0.5, 85 | neg_pos_ub=-1, 86 | add_gt_as_proposals=False), 87 | allowed_border=-1, 88 | pos_weight=-1, 89 | debug=False), 90 | rpn_proposal=dict( 91 | nms_pre=2000, 92 | max_per_img=1000, 93 | nms=dict(type='nms', iou_threshold=0.7), 94 | min_bbox_size=0), 95 | rcnn=dict( 96 | assigner=dict( 97 | type='MaxIoUAssigner', 98 | pos_iou_thr=0.5, 99 | neg_iou_thr=0.5, 100 | min_pos_iou=0.5, 101 | match_low_quality=True, 102 | ignore_iof_thr=-1), 103 | sampler=dict( 104 | type='RandomSampler', 105 | num=512, 106 | pos_fraction=0.25, 107 | neg_pos_ub=-1, 108 | add_gt_as_proposals=True), 109 | mask_size=28, 110 | pos_weight=-1, 111 | debug=False)), 112 | test_cfg=dict( 113 | rpn=dict( 114 | nms_pre=1000, 115 | max_per_img=1000, 116 | nms=dict(type='nms', iou_threshold=0.7), 117 | min_bbox_size=0), 118 | rcnn=dict( 119 | score_thr=0.05, 120 | nms=dict(type='nms', iou_threshold=0.5), 121 | max_per_img=100, 122 | mask_thr_binary=0.5))) 123 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_1x.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
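# Standard mmdetection 1x schedule: SGD with momentum, 12 epochs, 500-iteration
# linear warmup, and learning-rate steps at epochs 8 and 11.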
3 | # optimizer 4 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 5 | optimizer_config = dict(grad_clip=None) 6 | # learning policy 7 | lr_config = dict( 8 | policy='step', 9 | warmup='linear', 10 | warmup_iters=500, 11 | warmup_ratio=0.001, 12 | step=[8, 11]) 13 | runner = dict(type='EpochBasedRunner', max_epochs=12) 14 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_20e.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # optimizer 4 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 5 | optimizer_config = dict(grad_clip=None) 6 | # learning policy 7 | lr_config = dict( 8 | policy='step', 9 | warmup='linear', 10 | warmup_iters=500, 11 | warmup_ratio=0.001, 12 | step=[16, 19]) 13 | runner = dict(type='EpochBasedRunner', max_epochs=20) 14 | -------------------------------------------------------------------------------- /detection/configs/_base_/schedules/schedule_2x.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | # optimizer 4 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 5 | optimizer_config = dict(grad_clip=None) 6 | # learning policy 7 | lr_config = dict( 8 | policy='step', 9 | warmup='linear', 10 | warmup_iters=500, 11 | warmup_ratio=0.001, 12 | step=[16, 22]) 13 | runner = dict(type='EpochBasedRunner', max_epochs=24) 14 | -------------------------------------------------------------------------------- /detection/configs/selfpatch/mask_rcnn_vit_small_12_p16_1x_coco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
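# Fine-tuning config: combines the SelfPatch Mask R-CNN base model, the COCO
# instance dataset, and the 1x (12-epoch) schedule. The backbone returns features
# from blocks [3, 5, 7, 11]; augmentation follows DETR / Sparse R-CNN multi-scale
# resizing and cropping; the optimizer is switched to AdamW with no weight decay
# on pos_embed, cls_token, and norm layers; training uses fp16 via DistOptimizerHook.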
3 | """ 4 | Hyperparameters modifed from 5 | https://github.com/SwinTransformer/Swin-Transformer-Object-Detection 6 | """ 7 | 8 | _base_ = [ 9 | '../_base_/models/mask_rcnn_selfpatch_p16.py', 10 | '../_base_/datasets/coco_instance.py', 11 | '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' 12 | ] 13 | 14 | model = dict( 15 | backbone=dict( 16 | type='SelfPatch_ViT', 17 | patch_size=16, 18 | embed_dim=384, 19 | depth=12, 20 | num_heads=8, 21 | mlp_ratio=4, 22 | qkv_bias=True, 23 | eta=1.0, 24 | drop_path_rate=0.05, 25 | out_indices=[3, 5, 7, 11] 26 | ), 27 | neck=dict(in_channels=[384, 384, 384, 384]), 28 | roi_head=dict( 29 | type='StandardRoIHead', 30 | bbox_roi_extractor=dict( 31 | type='SingleRoIExtractor', 32 | roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), 33 | out_channels=256, 34 | featmap_strides=[4, 8, 16, 32]), 35 | mask_roi_extractor=dict( 36 | type='SingleRoIExtractor', 37 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 38 | out_channels=256, 39 | featmap_strides=[4, 8, 16, 32]) 40 | ), 41 | ) 42 | 43 | img_norm_cfg = dict( 44 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 45 | 46 | # augmentation strategy originates from DETR / Sparse RCNN 47 | train_pipeline = [ 48 | dict(type='LoadImageFromFile'), 49 | dict(type='LoadAnnotations', with_bbox=True, with_mask=True), 50 | dict(type='RandomFlip', flip_ratio=0.5), 51 | dict(type='AutoAugment', 52 | policies=[ 53 | [ 54 | dict(type='Resize', 55 | img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333), 56 | (608, 1333), (640, 1333), (672, 1333), (704, 1333), 57 | (736, 1333), (768, 1333), (800, 1333)], 58 | multiscale_mode='value', 59 | keep_ratio=True) 60 | ], 61 | [ 62 | dict(type='Resize', 63 | img_scale=[(400, 1333), (500, 1333), (600, 1333)], 64 | multiscale_mode='value', 65 | keep_ratio=True), 66 | dict(type='RandomCrop', 67 | crop_type='absolute_range', 68 | crop_size=(384, 600), 69 | allow_negative_crop=True), 70 | dict(type='Resize', 71 | img_scale=[(480, 1333), (512, 1333), (544, 1333), 72 | (576, 1333), (608, 1333), (640, 1333), 73 | (672, 1333), (704, 1333), (736, 1333), 74 | (768, 1333), (800, 1333)], 75 | multiscale_mode='value', 76 | override=True, 77 | keep_ratio=True) 78 | ] 79 | ]), 80 | dict(type='Normalize', **img_norm_cfg), 81 | dict(type='Pad', size_divisor=32), 82 | dict(type='DefaultFormatBundle'), 83 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']), 84 | ] 85 | data = dict(samples_per_gpu=1,train=dict(pipeline=train_pipeline)) 86 | 87 | optimizer = dict(_delete_=True, type='AdamW', lr=0.0001/2, betas=(0.9, 0.999), weight_decay=0.05, 88 | paramwise_cfg=dict(custom_keys={'pos_embed': dict(decay_mult=0.), 89 | 'cls_token': dict(decay_mult=0.), 90 | 'norm': dict(decay_mult=0.)})) 91 | lr_config = dict(warmup_iters=500*2,step=[8, 11]) 92 | runner = dict(type='EpochBasedRunnerAmp', max_epochs=12) 93 | 94 | # do not use mmdet version fp16 95 | fp16 = None 96 | optimizer_config = dict( 97 | type="DistOptimizerHook", 98 | update_interval=1, 99 | grad_clip=None, 100 | coalesce=True, 101 | bucket_size_mb=-1, 102 | use_fp16=True, 103 | ) 104 | -------------------------------------------------------------------------------- /detection/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m 
torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /detection/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /detection/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /detection/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${@:5} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --mem 350GB \ 23 | --kill-on-bad-exit=0 \ 24 | ${SRUN_ARGS} \ 25 | python -u tools/train.py ${CONFIG} --work-dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} -------------------------------------------------------------------------------- /detection/tools/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 
3 | """ 4 | Testing script modified from 5 | https://github.com/open-mmlab/mmdetection 6 | """ 7 | import argparse 8 | import os 9 | import warnings 10 | 11 | import mmcv 12 | import torch 13 | from mmcv import Config, DictAction 14 | from mmcv.cnn import fuse_conv_bn 15 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 16 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 17 | wrap_fp16_model) 18 | 19 | from mmdet.apis import multi_gpu_test, single_gpu_test 20 | from mmdet.datasets import (build_dataloader, build_dataset, 21 | replace_ImageToTensor) 22 | from mmdet.models import build_detector, build_backbone 23 | 24 | from backbone import xcit 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser( 29 | description='MMDet test (and eval) a model') 30 | parser.add_argument('config', help='test config file path') 31 | parser.add_argument('checkpoint', help='checkpoint file') 32 | parser.add_argument('--out', help='output result file in pickle format') 33 | parser.add_argument( 34 | '--fuse-conv-bn', 35 | action='store_true', 36 | help='Whether to fuse conv and bn, this will slightly increase' 37 | 'the inference speed') 38 | parser.add_argument( 39 | '--format-only', 40 | action='store_true', 41 | help='Format the output results without perform evaluation. It is' 42 | 'useful when you want to format the result to a specific format and ' 43 | 'submit it to the test server') 44 | parser.add_argument( 45 | '--eval', 46 | type=str, 47 | nargs='+', 48 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 49 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') 50 | parser.add_argument('--show', action='store_true', help='show results') 51 | parser.add_argument( 52 | '--show-dir', help='directory where painted images will be saved') 53 | parser.add_argument( 54 | '--show-score-thr', 55 | type=float, 56 | default=0.3, 57 | help='score threshold (default: 0.3)') 58 | parser.add_argument( 59 | '--gpu-collect', 60 | action='store_true', 61 | help='whether to use gpu to collect results.') 62 | parser.add_argument( 63 | '--tmpdir', 64 | help='tmp directory used for collecting results from multiple ' 65 | 'workers, available when gpu-collect is not specified') 66 | parser.add_argument( 67 | '--cfg-options', 68 | nargs='+', 69 | action=DictAction, 70 | help='override some settings in the used config, the key-value pair ' 71 | 'in xxx=yyy format will be merged into config file. If the value to ' 72 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 73 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 74 | 'Note that the quotation marks are necessary and that no white space ' 75 | 'is allowed.') 76 | parser.add_argument( 77 | '--options', 78 | nargs='+', 79 | action=DictAction, 80 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 81 | 'format will be kwargs for dataset.evaluate() function (deprecate), ' 82 | 'change to --eval-options instead.') 83 | parser.add_argument( 84 | '--eval-options', 85 | nargs='+', 86 | action=DictAction, 87 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 88 | 'format will be kwargs for dataset.evaluate() function') 89 | parser.add_argument( 90 | '--launcher', 91 | choices=['none', 'pytorch', 'slurm', 'mpi'], 92 | default='none', 93 | help='job launcher') 94 | parser.add_argument('--local_rank', type=int, default=0) 95 | args = parser.parse_args() 96 | if 'LOCAL_RANK' not in os.environ: 97 | os.environ['LOCAL_RANK'] = str(args.local_rank) 98 | 99 | if args.options and args.eval_options: 100 | raise ValueError( 101 | '--options and --eval-options cannot be both ' 102 | 'specified, --options is deprecated in favor of --eval-options') 103 | if args.options: 104 | warnings.warn('--options is deprecated in favor of --eval-options') 105 | args.eval_options = args.options 106 | return args 107 | 108 | 109 | def main(): 110 | args = parse_args() 111 | 112 | assert args.out or args.eval or args.format_only or args.show \ 113 | or args.show_dir, \ 114 | ('Please specify at least one operation (save/eval/format/show the ' 115 | 'results / save the results) with the argument "--out", "--eval"' 116 | ', "--format-only", "--show" or "--show-dir"') 117 | 118 | if args.eval and args.format_only: 119 | raise ValueError('--eval and --format_only cannot be both specified') 120 | 121 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 122 | raise ValueError('The output file must be a pkl file.') 123 | 124 | cfg = Config.fromfile(args.config) 125 | if args.cfg_options is not None: 126 | cfg.merge_from_dict(args.cfg_options) 127 | # import modules from string list. 
128 | if cfg.get('custom_imports', None): 129 | from mmcv.utils import import_modules_from_strings 130 | import_modules_from_strings(**cfg['custom_imports']) 131 | # set cudnn_benchmark 132 | if cfg.get('cudnn_benchmark', False): 133 | torch.backends.cudnn.benchmark = True 134 | cfg.model.pretrained = None 135 | if cfg.model.get('neck'): 136 | if isinstance(cfg.model.neck, list): 137 | for neck_cfg in cfg.model.neck: 138 | if neck_cfg.get('rfp_backbone'): 139 | if neck_cfg.rfp_backbone.get('pretrained'): 140 | neck_cfg.rfp_backbone.pretrained = None 141 | elif cfg.model.neck.get('rfp_backbone'): 142 | if cfg.model.neck.rfp_backbone.get('pretrained'): 143 | cfg.model.neck.rfp_backbone.pretrained = None 144 | 145 | # in case the test dataset is concatenated 146 | samples_per_gpu = 1 147 | if isinstance(cfg.data.test, dict): 148 | cfg.data.test.test_mode = True 149 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) 150 | if samples_per_gpu > 1: 151 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 152 | cfg.data.test.pipeline = replace_ImageToTensor( 153 | cfg.data.test.pipeline) 154 | elif isinstance(cfg.data.test, list): 155 | for ds_cfg in cfg.data.test: 156 | ds_cfg.test_mode = True 157 | samples_per_gpu = max( 158 | [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test]) 159 | if samples_per_gpu > 1: 160 | for ds_cfg in cfg.data.test: 161 | ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline) 162 | 163 | # init distributed env first, since logger depends on the dist info. 164 | if args.launcher == 'none': 165 | distributed = False 166 | else: 167 | distributed = True 168 | init_dist(args.launcher, **cfg.dist_params) 169 | 170 | # build the dataloader 171 | dataset = build_dataset(cfg.data.test) 172 | data_loader = build_dataloader( 173 | dataset, 174 | samples_per_gpu=samples_per_gpu, 175 | workers_per_gpu=cfg.data.workers_per_gpu, 176 | dist=distributed, 177 | shuffle=False) 178 | 179 | # build the model and load checkpoint 180 | cfg.model.train_cfg = None 181 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) 182 | fp16_cfg = cfg.get('fp16', None) 183 | if fp16_cfg is not None: 184 | wrap_fp16_model(model) 185 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 186 | if args.fuse_conv_bn: 187 | model = fuse_conv_bn(model) 188 | # old versions did not save class info in checkpoints, this walkaround is 189 | # for backward compatibility 190 | if 'CLASSES' in checkpoint.get('meta', {}): 191 | model.CLASSES = checkpoint['meta']['CLASSES'] 192 | else: 193 | model.CLASSES = dataset.CLASSES 194 | 195 | if not distributed: 196 | model = MMDataParallel(model, device_ids=[0]) 197 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, 198 | args.show_score_thr) 199 | else: 200 | model = MMDistributedDataParallel( 201 | model.cuda(), 202 | device_ids=[torch.cuda.current_device()], 203 | broadcast_buffers=False) 204 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, 205 | args.gpu_collect) 206 | 207 | rank, _ = get_dist_info() 208 | if rank == 0: 209 | if args.out: 210 | print(f'\nwriting results to {args.out}') 211 | mmcv.dump(outputs, args.out) 212 | kwargs = {} if args.eval_options is None else args.eval_options 213 | if args.format_only: 214 | dataset.format_results(outputs, **kwargs) 215 | if args.eval: 216 | eval_kwargs = cfg.get('evaluation', {}).copy() 217 | # hard-code way to remove EvalHook args 218 | for key in [ 219 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 220 | 'rule' 221 | ]: 
222 | eval_kwargs.pop(key, None) 223 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 224 | print(dataset.evaluate(outputs, **eval_kwargs)) 225 | 226 | 227 | if __name__ == '__main__': 228 | main() 229 | -------------------------------------------------------------------------------- /detection/tools/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2015-present, Facebook, Inc. 2 | # All rights reserved. 3 | """ 4 | Training script modified from 5 | https://github.com/open-mmlab/mmdetection 6 | """ 7 | import argparse 8 | import copy 9 | import os 10 | import os.path as osp 11 | import time 12 | import warnings 13 | 14 | import mmcv 15 | import torch 16 | from mmcv import Config, DictAction 17 | from mmcv.runner import get_dist_info, init_dist 18 | from mmcv.utils import get_git_hash 19 | 20 | from mmdet import __version__ 21 | from mmdet.apis import set_random_seed, train_detector 22 | from mmdet.datasets import build_dataset 23 | from mmdet.models import build_detector 24 | from mmdet.utils import collect_env, get_root_logger 25 | 26 | from backbone import vit_PASS 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='Train a detector') 31 | parser.add_argument('config', help='train config file path') 32 | parser.add_argument('--work-dir', help='the dir to save logs and models') 33 | parser.add_argument( 34 | '--resume-from', help='the checkpoint file to resume from') 35 | parser.add_argument( 36 | '--no-validate', 37 | action='store_true', 38 | help='whether not to evaluate the checkpoint during training') 39 | group_gpus = parser.add_mutually_exclusive_group() 40 | group_gpus.add_argument( 41 | '--gpus', 42 | type=int, 43 | help='number of gpus to use ' 44 | '(only applicable to non-distributed training)') 45 | group_gpus.add_argument( 46 | '--gpu-ids', 47 | type=int, 48 | nargs='+', 49 | help='ids of gpus to use ' 50 | '(only applicable to non-distributed training)') 51 | parser.add_argument('--seed', type=int, default=None, help='random seed') 52 | parser.add_argument( 53 | '--deterministic', 54 | action='store_true', 55 | help='whether to set deterministic options for CUDNN backend.') 56 | parser.add_argument( 57 | '--options', 58 | nargs='+', 59 | action=DictAction, 60 | help='override some settings in the used config, the key-value pair ' 61 | 'in xxx=yyy format will be merged into config file (deprecate), ' 62 | 'change to --cfg-options instead.') 63 | parser.add_argument( 64 | '--cfg-options', 65 | nargs='+', 66 | action=DictAction, 67 | help='override some settings in the used config, the key-value pair ' 68 | 'in xxx=yyy format will be merged into config file. If the value to ' 69 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 70 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 71 | 'Note that the quotation marks are necessary and that no white space ' 72 | 'is allowed.') 73 | parser.add_argument( 74 | '--launcher', 75 | choices=['none', 'pytorch', 'slurm', 'mpi'], 76 | default='none', 77 | help='job launcher') 78 | parser.add_argument('--local_rank', type=int, default=0) 79 | args = parser.parse_args() 80 | if 'LOCAL_RANK' not in os.environ: 81 | os.environ['LOCAL_RANK'] = str(args.local_rank) 82 | 83 | if args.options and args.cfg_options: 84 | raise ValueError( 85 | '--options and --cfg-options cannot be both ' 86 | 'specified, --options is deprecated in favor of --cfg-options') 87 | if args.options: 88 | warnings.warn('--options is deprecated in favor of --cfg-options') 89 | args.cfg_options = args.options 90 | 91 | return args 92 | 93 | 94 | def main(): 95 | args = parse_args() 96 | 97 | cfg = Config.fromfile(args.config) 98 | if args.cfg_options is not None: 99 | cfg.merge_from_dict(args.cfg_options) 100 | # import modules from string list. 101 | if cfg.get('custom_imports', None): 102 | from mmcv.utils import import_modules_from_strings 103 | import_modules_from_strings(**cfg['custom_imports']) 104 | # set cudnn_benchmark 105 | if cfg.get('cudnn_benchmark', False): 106 | torch.backends.cudnn.benchmark = True 107 | 108 | # work_dir is determined in this priority: CLI > segment in file > filename 109 | if args.work_dir is not None: 110 | # update configs according to CLI args if args.work_dir is not None 111 | cfg.work_dir = args.work_dir 112 | elif cfg.get('work_dir', None) is None: 113 | # use config filename as default work_dir if cfg.work_dir is None 114 | cfg.work_dir = osp.join('./work_dirs', 115 | osp.splitext(osp.basename(args.config))[0]) 116 | if args.resume_from is not None: 117 | cfg.resume_from = args.resume_from 118 | if args.gpu_ids is not None: 119 | cfg.gpu_ids = args.gpu_ids 120 | else: 121 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 122 | 123 | # init distributed env first, since logger depends on the dist info. 
124 | if args.launcher == 'none': 125 | distributed = False 126 | else: 127 | distributed = True 128 | init_dist(args.launcher, **cfg.dist_params) 129 | # re-set gpu_ids with distributed training mode 130 | _, world_size = get_dist_info() 131 | cfg.gpu_ids = range(world_size) 132 | 133 | # create work_dir 134 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 135 | # dump config 136 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 137 | # init the logger before other steps 138 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 139 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 140 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 141 | 142 | # init the meta dict to record some important information such as 143 | # environment info and seed, which will be logged 144 | meta = dict() 145 | # log env info 146 | env_info_dict = collect_env() 147 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 148 | dash_line = '-' * 60 + '\n' 149 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 150 | dash_line) 151 | meta['env_info'] = env_info 152 | meta['config'] = cfg.pretty_text 153 | # log some basic info 154 | logger.info(f'Distributed training: {distributed}') 155 | logger.info(f'Config:\n{cfg.pretty_text}') 156 | 157 | # set random seeds 158 | if args.seed is not None: 159 | logger.info(f'Set random seed to {args.seed}, ' 160 | f'deterministic: {args.deterministic}') 161 | set_random_seed(args.seed, deterministic=args.deterministic) 162 | cfg.seed = args.seed 163 | meta['seed'] = args.seed 164 | meta['exp_name'] = osp.basename(args.config) 165 | 166 | model = build_detector( 167 | cfg.model, 168 | train_cfg=cfg.get('train_cfg'), 169 | test_cfg=cfg.get('test_cfg')) 170 | 171 | datasets = [build_dataset(cfg.data.train)] 172 | if len(cfg.workflow) == 2: 173 | val_dataset = copy.deepcopy(cfg.data.val) 174 | val_dataset.pipeline = cfg.data.train.pipeline 175 | datasets.append(build_dataset(val_dataset)) 176 | if cfg.checkpoint_config is not None: 177 | # save mmdet version, config file content and class names in 178 | # checkpoints as meta data 179 | cfg.checkpoint_config.meta = dict( 180 | mmdet_version=__version__ + get_git_hash()[:7], 181 | CLASSES=datasets[0].CLASSES) 182 | # add an attribute for visualization convenience 183 | model.CLASSES = datasets[0].CLASSES 184 | train_detector( 185 | model, 186 | datasets, 187 | cfg, 188 | distributed=distributed, 189 | validate=(not args.no_validate), 190 | timestamp=timestamp, 191 | meta=meta) 192 | 193 | 194 | if __name__ == '__main__': 195 | main() 196 | -------------------------------------------------------------------------------- /segmentation/README.md: -------------------------------------------------------------------------------- 1 | ## Evaluating semantic segmentation on the ADE20K dataset 2 | 3 | Step 1. Prepare ADE20K dataset 4 | 5 | The dataset can be downloaded at 6 | `http://groups.csail.mit.edu/vision/datasets/ADE20K/toolkit/index_ade20k.pkl` 7 | 8 | or following instruction of `https://github.com/CSAILVision/ADE20K` 9 | 10 | Step 2. Install mmsegmentation 11 | 12 | ``` 13 | git clone https://github.com/open-mmlab/mmsegmentation.git 14 | ``` 15 | 16 | Step 3. Convert your model 17 | 18 | ``` 19 | python tools/model_converters/vit2mmseg.py /path/to/model_dir /path/to/saving_dir 20 | ``` 21 | 22 | Step 4. 
Fine-tune on the ADE20K dataset 23 | 24 | ``` 25 | tools/dist_train.sh configs/selfpatch/semfpn_vit-s16_512x512_40k_ade20k.py [number of gpus] --work-dir /path/to/saving_dir --seed 0 --deterministic --options model.pretrained=/path/to/model_dir 26 | ``` 27 | 28 | ## Pretrained weights on ADE20K 29 | You can download the weights of the models fine-tuned for the semantic segmentation task. We provide fine-tuned models with `Semantic FPN` (40k iterations) and `UperNet` (160k iterations). 30 | 31 | | backbone | arch | iterations | mIoU | checkpoint | 32 | | ------------- | ------------- | ------------- | ------------- | ------------- | 33 | | DINO | ViT-S/16 + Semantic FPN | 40k | 38.3 | download | 34 | | DINO + SelfPatch | ViT-S/16 + Semantic FPN | 40k | 41.2 | download | 35 | | DINO | ViT-S/16 + UperNet | 160k | 42.3 | download | 36 | | DINO + SelfPatch | ViT-S/16 + UperNet | 160k | 43.2 | download | 37 | 38 | ## Acknowledgement 39 | This code is built using the mmsegmentation library. The optimization hyperparameters are adopted from XCiT. 40 | -------------------------------------------------------------------------------- /segmentation/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .bisenetv1 import BiSeNetV1 3 | from .bisenetv2 import BiSeNetV2 4 | from .cgnet import CGNet 5 | from .fast_scnn import FastSCNN 6 | from .hrnet import HRNet 7 | from .mit import MixVisionTransformer 8 | from .mobilenet_v2 import MobileNetV2 9 | from .mobilenet_v3 import MobileNetV3 10 | from .resnest import ResNeSt 11 | from .resnet import ResNet, ResNetV1c, ResNetV1d 12 | from .resnext import ResNeXt 13 | from .swin import SwinTransformer 14 | from .unet import UNet 15 | from .vit_SelfPatch import SelfPatch_ViT 16 | 17 | __all__ = [ 18 | 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 19 | 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 20 | 'SelfPatch_ViT', 'SwinTransformer', 'MixVisionTransformer', 21 | 'BiSeNetV1', 'BiSeNetV2' 22 | ] 23 | -------------------------------------------------------------------------------- /segmentation/backbones/mobilenet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import torch.nn as nn 5 | from mmcv.cnn import ConvModule 6 | from mmcv.runner import BaseModule 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | from ..builder import BACKBONES 10 | from ..utils import InvertedResidual, make_divisible 11 | 12 | 13 | @BACKBONES.register_module() 14 | class MobileNetV2(BaseModule): 15 | """MobileNetV2 backbone. 16 | 17 | This backbone is the implementation of 18 | `MobileNetV2: Inverted Residuals and Linear Bottlenecks 19 | `_. 20 | 21 | Args: 22 | widen_factor (float): Width multiplier, multiply number of 23 | channels in each layer by this amount. Default: 1.0. 24 | strides (Sequence[int], optional): Strides of the first block of each 25 | layer. If not specified, default config in ``arch_setting`` will 26 | be used. 27 | dilations (Sequence[int]): Dilation of each layer. 28 | out_indices (None or Sequence[int]): Output from which stages. 29 | Default: (7, ). 30 | frozen_stages (int): Stages to be frozen (all param fixed). 31 | Default: -1, which means not freezing any parameters. 32 | conv_cfg (dict): Config dict for convolution layer. 33 | Default: None, which means using conv2d.
34 | norm_cfg (dict): Config dict for normalization layer. 35 | Default: dict(type='BN'). 36 | act_cfg (dict): Config dict for activation layer. 37 | Default: dict(type='ReLU6'). 38 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 39 | freeze running stats (mean and var). Note: Effect on Batch Norm 40 | and its variants only. Default: False. 41 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 42 | memory while slowing down the training speed. Default: False. 43 | pretrained (str, optional): model pretrained path. Default: None 44 | init_cfg (dict or list[dict], optional): Initialization config dict. 45 | Default: None 46 | """ 47 | 48 | # Parameters to build layers. 3 parameters are needed to construct a 49 | # layer, from left to right: expand_ratio, channel, num_blocks. 50 | arch_settings = [[1, 16, 1], [6, 24, 2], [6, 32, 3], [6, 64, 4], 51 | [6, 96, 3], [6, 160, 3], [6, 320, 1]] 52 | 53 | def __init__(self, 54 | widen_factor=1., 55 | strides=(1, 2, 2, 2, 1, 2, 1), 56 | dilations=(1, 1, 1, 1, 1, 1, 1), 57 | out_indices=(1, 2, 4, 6), 58 | frozen_stages=-1, 59 | conv_cfg=None, 60 | norm_cfg=dict(type='BN'), 61 | act_cfg=dict(type='ReLU6'), 62 | norm_eval=False, 63 | with_cp=False, 64 | pretrained=None, 65 | init_cfg=None): 66 | super(MobileNetV2, self).__init__(init_cfg) 67 | 68 | self.pretrained = pretrained 69 | assert not (init_cfg and pretrained), \ 70 | 'init_cfg and pretrained cannot be setting at the same time' 71 | if isinstance(pretrained, str): 72 | warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 73 | 'please use "init_cfg" instead') 74 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 75 | elif pretrained is None: 76 | if init_cfg is None: 77 | self.init_cfg = [ 78 | dict(type='Kaiming', layer='Conv2d'), 79 | dict( 80 | type='Constant', 81 | val=1, 82 | layer=['_BatchNorm', 'GroupNorm']) 83 | ] 84 | else: 85 | raise TypeError('pretrained must be a str or None') 86 | 87 | self.widen_factor = widen_factor 88 | self.strides = strides 89 | self.dilations = dilations 90 | assert len(strides) == len(dilations) == len(self.arch_settings) 91 | self.out_indices = out_indices 92 | for index in out_indices: 93 | if index not in range(0, 7): 94 | raise ValueError('the item in out_indices must in ' 95 | f'range(0, 7). But received {index}') 96 | 97 | if frozen_stages not in range(-1, 7): 98 | raise ValueError('frozen_stages must be in range(-1, 7). 
' 99 | f'But received {frozen_stages}') 100 | self.out_indices = out_indices 101 | self.frozen_stages = frozen_stages 102 | self.conv_cfg = conv_cfg 103 | self.norm_cfg = norm_cfg 104 | self.act_cfg = act_cfg 105 | self.norm_eval = norm_eval 106 | self.with_cp = with_cp 107 | 108 | self.in_channels = make_divisible(32 * widen_factor, 8) 109 | 110 | self.conv1 = ConvModule( 111 | in_channels=3, 112 | out_channels=self.in_channels, 113 | kernel_size=3, 114 | stride=2, 115 | padding=1, 116 | conv_cfg=self.conv_cfg, 117 | norm_cfg=self.norm_cfg, 118 | act_cfg=self.act_cfg) 119 | 120 | self.layers = [] 121 | 122 | for i, layer_cfg in enumerate(self.arch_settings): 123 | expand_ratio, channel, num_blocks = layer_cfg 124 | stride = self.strides[i] 125 | dilation = self.dilations[i] 126 | out_channels = make_divisible(channel * widen_factor, 8) 127 | inverted_res_layer = self.make_layer( 128 | out_channels=out_channels, 129 | num_blocks=num_blocks, 130 | stride=stride, 131 | dilation=dilation, 132 | expand_ratio=expand_ratio) 133 | layer_name = f'layer{i + 1}' 134 | self.add_module(layer_name, inverted_res_layer) 135 | self.layers.append(layer_name) 136 | 137 | def make_layer(self, out_channels, num_blocks, stride, dilation, 138 | expand_ratio): 139 | """Stack InvertedResidual blocks to build a layer for MobileNetV2. 140 | 141 | Args: 142 | out_channels (int): out_channels of block. 143 | num_blocks (int): Number of blocks. 144 | stride (int): Stride of the first block. 145 | dilation (int): Dilation of the first block. 146 | expand_ratio (int): Expand the number of channels of the 147 | hidden layer in InvertedResidual by this ratio. 148 | """ 149 | layers = [] 150 | for i in range(num_blocks): 151 | layers.append( 152 | InvertedResidual( 153 | self.in_channels, 154 | out_channels, 155 | stride if i == 0 else 1, 156 | expand_ratio=expand_ratio, 157 | dilation=dilation if i == 0 else 1, 158 | conv_cfg=self.conv_cfg, 159 | norm_cfg=self.norm_cfg, 160 | act_cfg=self.act_cfg, 161 | with_cp=self.with_cp)) 162 | self.in_channels = out_channels 163 | 164 | return nn.Sequential(*layers) 165 | 166 | def forward(self, x): 167 | x = self.conv1(x) 168 | 169 | outs = [] 170 | for i, layer_name in enumerate(self.layers): 171 | layer = getattr(self, layer_name) 172 | x = layer(x) 173 | if i in self.out_indices: 174 | outs.append(x) 175 | 176 | if len(outs) == 1: 177 | return outs[0] 178 | else: 179 | return tuple(outs) 180 | 181 | def _freeze_stages(self): 182 | if self.frozen_stages >= 0: 183 | for param in self.conv1.parameters(): 184 | param.requires_grad = False 185 | for i in range(1, self.frozen_stages + 1): 186 | layer = getattr(self, f'layer{i}') 187 | layer.eval() 188 | for param in layer.parameters(): 189 | param.requires_grad = False 190 | 191 | def train(self, mode=True): 192 | super(MobileNetV2, self).train(mode) 193 | self._freeze_stages() 194 | if mode and self.norm_eval: 195 | for m in self.modules(): 196 | if isinstance(m, _BatchNorm): 197 | m.eval() 198 | -------------------------------------------------------------------------------- /segmentation/backbones/mobilenet_v3.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | import mmcv 5 | from mmcv.cnn import ConvModule 6 | from mmcv.cnn.bricks import Conv2dAdaptivePadding 7 | from mmcv.runner import BaseModule 8 | from torch.nn.modules.batchnorm import _BatchNorm 9 | 10 | from ..builder import BACKBONES 11 | from ..utils import InvertedResidualV3 as InvertedResidual 12 | 13 | 14 | @BACKBONES.register_module() 15 | class MobileNetV3(BaseModule): 16 | """MobileNetV3 backbone. 17 | 18 | This backbone is the improved implementation of `Searching for MobileNetV3 19 | `_. 20 | 21 | Args: 22 | arch (str): Architecture of mobilnetv3, from {'small', 'large'}. 23 | Default: 'small'. 24 | conv_cfg (dict): Config dict for convolution layer. 25 | Default: None, which means using conv2d. 26 | norm_cfg (dict): Config dict for normalization layer. 27 | Default: dict(type='BN'). 28 | out_indices (tuple[int]): Output from which layer. 29 | Default: (0, 1, 12). 30 | frozen_stages (int): Stages to be frozen (all param fixed). 31 | Default: -1, which means not freezing any parameters. 32 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 33 | freeze running stats (mean and var). Note: Effect on Batch Norm 34 | and its variants only. Default: False. 35 | with_cp (bool): Use checkpoint or not. Using checkpoint will save 36 | some memory while slowing down the training speed. 37 | Default: False. 38 | pretrained (str, optional): model pretrained path. Default: None 39 | init_cfg (dict or list[dict], optional): Initialization config dict. 40 | Default: None 41 | """ 42 | # Parameters to build each block: 43 | # [kernel size, mid channels, out channels, with_se, act type, stride] 44 | arch_settings = { 45 | 'small': [[3, 16, 16, True, 'ReLU', 2], # block0 layer1 os=4 46 | [3, 72, 24, False, 'ReLU', 2], # block1 layer2 os=8 47 | [3, 88, 24, False, 'ReLU', 1], 48 | [5, 96, 40, True, 'HSwish', 2], # block2 layer4 os=16 49 | [5, 240, 40, True, 'HSwish', 1], 50 | [5, 240, 40, True, 'HSwish', 1], 51 | [5, 120, 48, True, 'HSwish', 1], # block3 layer7 os=16 52 | [5, 144, 48, True, 'HSwish', 1], 53 | [5, 288, 96, True, 'HSwish', 2], # block4 layer9 os=32 54 | [5, 576, 96, True, 'HSwish', 1], 55 | [5, 576, 96, True, 'HSwish', 1]], 56 | 'large': [[3, 16, 16, False, 'ReLU', 1], # block0 layer1 os=2 57 | [3, 64, 24, False, 'ReLU', 2], # block1 layer2 os=4 58 | [3, 72, 24, False, 'ReLU', 1], 59 | [5, 72, 40, True, 'ReLU', 2], # block2 layer4 os=8 60 | [5, 120, 40, True, 'ReLU', 1], 61 | [5, 120, 40, True, 'ReLU', 1], 62 | [3, 240, 80, False, 'HSwish', 2], # block3 layer7 os=16 63 | [3, 200, 80, False, 'HSwish', 1], 64 | [3, 184, 80, False, 'HSwish', 1], 65 | [3, 184, 80, False, 'HSwish', 1], 66 | [3, 480, 112, True, 'HSwish', 1], # block4 layer11 os=16 67 | [3, 672, 112, True, 'HSwish', 1], 68 | [5, 672, 160, True, 'HSwish', 2], # block5 layer13 os=32 69 | [5, 960, 160, True, 'HSwish', 1], 70 | [5, 960, 160, True, 'HSwish', 1]] 71 | } # yapf: disable 72 | 73 | def __init__(self, 74 | arch='small', 75 | conv_cfg=None, 76 | norm_cfg=dict(type='BN'), 77 | out_indices=(0, 1, 12), 78 | frozen_stages=-1, 79 | reduction_factor=1, 80 | norm_eval=False, 81 | with_cp=False, 82 | pretrained=None, 83 | init_cfg=None): 84 | super(MobileNetV3, self).__init__(init_cfg) 85 | 86 | self.pretrained = pretrained 87 | assert not (init_cfg and pretrained), \ 88 | 'init_cfg and pretrained cannot be setting at the same time' 89 | if isinstance(pretrained, str): 90 | warnings.warn('DeprecationWarning: pretrained is a deprecated, ' 91 | 'please use "init_cfg" instead') 92 | 
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 93 | elif pretrained is None: 94 | if init_cfg is None: 95 | self.init_cfg = [ 96 | dict(type='Kaiming', layer='Conv2d'), 97 | dict( 98 | type='Constant', 99 | val=1, 100 | layer=['_BatchNorm', 'GroupNorm']) 101 | ] 102 | else: 103 | raise TypeError('pretrained must be a str or None') 104 | 105 | assert arch in self.arch_settings 106 | assert isinstance(reduction_factor, int) and reduction_factor > 0 107 | assert mmcv.is_tuple_of(out_indices, int) 108 | for index in out_indices: 109 | if index not in range(0, len(self.arch_settings[arch]) + 2): 110 | raise ValueError( 111 | 'the item in out_indices must in ' 112 | f'range(0, {len(self.arch_settings[arch])+2}). ' 113 | f'But received {index}') 114 | 115 | if frozen_stages not in range(-1, len(self.arch_settings[arch]) + 2): 116 | raise ValueError('frozen_stages must be in range(-1, ' 117 | f'{len(self.arch_settings[arch])+2}). ' 118 | f'But received {frozen_stages}') 119 | self.arch = arch 120 | self.conv_cfg = conv_cfg 121 | self.norm_cfg = norm_cfg 122 | self.out_indices = out_indices 123 | self.frozen_stages = frozen_stages 124 | self.reduction_factor = reduction_factor 125 | self.norm_eval = norm_eval 126 | self.with_cp = with_cp 127 | self.layers = self._make_layer() 128 | 129 | def _make_layer(self): 130 | layers = [] 131 | 132 | # build the first layer (layer0) 133 | in_channels = 16 134 | layer = ConvModule( 135 | in_channels=3, 136 | out_channels=in_channels, 137 | kernel_size=3, 138 | stride=2, 139 | padding=1, 140 | conv_cfg=dict(type='Conv2dAdaptivePadding'), 141 | norm_cfg=self.norm_cfg, 142 | act_cfg=dict(type='HSwish')) 143 | self.add_module('layer0', layer) 144 | layers.append('layer0') 145 | 146 | layer_setting = self.arch_settings[self.arch] 147 | for i, params in enumerate(layer_setting): 148 | (kernel_size, mid_channels, out_channels, with_se, act, 149 | stride) = params 150 | 151 | if self.arch == 'large' and i >= 12 or self.arch == 'small' and \ 152 | i >= 8: 153 | mid_channels = mid_channels // self.reduction_factor 154 | out_channels = out_channels // self.reduction_factor 155 | 156 | if with_se: 157 | se_cfg = dict( 158 | channels=mid_channels, 159 | ratio=4, 160 | act_cfg=(dict(type='ReLU'), 161 | dict(type='HSigmoid', bias=3.0, divisor=6.0))) 162 | else: 163 | se_cfg = None 164 | 165 | layer = InvertedResidual( 166 | in_channels=in_channels, 167 | out_channels=out_channels, 168 | mid_channels=mid_channels, 169 | kernel_size=kernel_size, 170 | stride=stride, 171 | se_cfg=se_cfg, 172 | with_expand_conv=(in_channels != mid_channels), 173 | conv_cfg=self.conv_cfg, 174 | norm_cfg=self.norm_cfg, 175 | act_cfg=dict(type=act), 176 | with_cp=self.with_cp) 177 | in_channels = out_channels 178 | layer_name = 'layer{}'.format(i + 1) 179 | self.add_module(layer_name, layer) 180 | layers.append(layer_name) 181 | 182 | # build the last layer 183 | # block5 layer12 os=32 for small model 184 | # block6 layer16 os=32 for large model 185 | layer = ConvModule( 186 | in_channels=in_channels, 187 | out_channels=576 if self.arch == 'small' else 960, 188 | kernel_size=1, 189 | stride=1, 190 | dilation=4, 191 | padding=0, 192 | conv_cfg=self.conv_cfg, 193 | norm_cfg=self.norm_cfg, 194 | act_cfg=dict(type='HSwish')) 195 | layer_name = 'layer{}'.format(len(layer_setting) + 1) 196 | self.add_module(layer_name, layer) 197 | layers.append(layer_name) 198 | 199 | # next, convert backbone MobileNetV3 to a semantic segmentation version 200 | if self.arch == 'small': 201 | 
self.layer4.depthwise_conv.conv.stride = (1, 1) 202 | self.layer9.depthwise_conv.conv.stride = (1, 1) 203 | for i in range(4, len(layers)): 204 | layer = getattr(self, layers[i]) 205 | if isinstance(layer, InvertedResidual): 206 | modified_module = layer.depthwise_conv.conv 207 | else: 208 | modified_module = layer.conv 209 | 210 | if i < 9: 211 | modified_module.dilation = (2, 2) 212 | pad = 2 213 | else: 214 | modified_module.dilation = (4, 4) 215 | pad = 4 216 | 217 | if not isinstance(modified_module, Conv2dAdaptivePadding): 218 | # Adjust padding 219 | pad *= (modified_module.kernel_size[0] - 1) // 2 220 | modified_module.padding = (pad, pad) 221 | else: 222 | self.layer7.depthwise_conv.conv.stride = (1, 1) 223 | self.layer13.depthwise_conv.conv.stride = (1, 1) 224 | for i in range(7, len(layers)): 225 | layer = getattr(self, layers[i]) 226 | if isinstance(layer, InvertedResidual): 227 | modified_module = layer.depthwise_conv.conv 228 | else: 229 | modified_module = layer.conv 230 | 231 | if i < 13: 232 | modified_module.dilation = (2, 2) 233 | pad = 2 234 | else: 235 | modified_module.dilation = (4, 4) 236 | pad = 4 237 | 238 | if not isinstance(modified_module, Conv2dAdaptivePadding): 239 | # Adjust padding 240 | pad *= (modified_module.kernel_size[0] - 1) // 2 241 | modified_module.padding = (pad, pad) 242 | 243 | return layers 244 | 245 | def forward(self, x): 246 | outs = [] 247 | for i, layer_name in enumerate(self.layers): 248 | layer = getattr(self, layer_name) 249 | x = layer(x) 250 | if i in self.out_indices: 251 | outs.append(x) 252 | return outs 253 | 254 | def _freeze_stages(self): 255 | for i in range(self.frozen_stages + 1): 256 | layer = getattr(self, f'layer{i}') 257 | layer.eval() 258 | for param in layer.parameters(): 259 | param.requires_grad = False 260 | 261 | def train(self, mode=True): 262 | super(MobileNetV3, self).train(mode) 263 | self._freeze_stages() 264 | if mode and self.norm_eval: 265 | for m in self.modules(): 266 | if isinstance(m, _BatchNorm): 267 | m.eval() 268 | -------------------------------------------------------------------------------- /segmentation/backbones/resnest.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.utils.checkpoint as cp 8 | from mmcv.cnn import build_conv_layer, build_norm_layer 9 | 10 | from ..builder import BACKBONES 11 | from ..utils import ResLayer 12 | from .resnet import Bottleneck as _Bottleneck 13 | from .resnet import ResNetV1d 14 | 15 | 16 | class RSoftmax(nn.Module): 17 | """Radix Softmax module in ``SplitAttentionConv2d``. 18 | 19 | Args: 20 | radix (int): Radix of input. 21 | groups (int): Groups of input. 22 | """ 23 | 24 | def __init__(self, radix, groups): 25 | super().__init__() 26 | self.radix = radix 27 | self.groups = groups 28 | 29 | def forward(self, x): 30 | batch = x.size(0) 31 | if self.radix > 1: 32 | x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2) 33 | x = F.softmax(x, dim=1) 34 | x = x.reshape(batch, -1) 35 | else: 36 | x = torch.sigmoid(x) 37 | return x 38 | 39 | 40 | class SplitAttentionConv2d(nn.Module): 41 | """Split-Attention Conv2d in ResNeSt. 42 | 43 | Args: 44 | in_channels (int): Same as nn.Conv2d. 45 | out_channels (int): Same as nn.Conv2d. 46 | kernel_size (int | tuple[int]): Same as nn.Conv2d. 47 | stride (int | tuple[int]): Same as nn.Conv2d. 
48 | padding (int | tuple[int]): Same as nn.Conv2d. 49 | dilation (int | tuple[int]): Same as nn.Conv2d. 50 | groups (int): Same as nn.Conv2d. 51 | radix (int): Radix of SpltAtConv2d. Default: 2 52 | reduction_factor (int): Reduction factor of inter_channels. Default: 4. 53 | conv_cfg (dict): Config dict for convolution layer. Default: None, 54 | which means using conv2d. 55 | norm_cfg (dict): Config dict for normalization layer. Default: None. 56 | dcn (dict): Config dict for DCN. Default: None. 57 | """ 58 | 59 | def __init__(self, 60 | in_channels, 61 | channels, 62 | kernel_size, 63 | stride=1, 64 | padding=0, 65 | dilation=1, 66 | groups=1, 67 | radix=2, 68 | reduction_factor=4, 69 | conv_cfg=None, 70 | norm_cfg=dict(type='BN'), 71 | dcn=None): 72 | super(SplitAttentionConv2d, self).__init__() 73 | inter_channels = max(in_channels * radix // reduction_factor, 32) 74 | self.radix = radix 75 | self.groups = groups 76 | self.channels = channels 77 | self.with_dcn = dcn is not None 78 | self.dcn = dcn 79 | fallback_on_stride = False 80 | if self.with_dcn: 81 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False) 82 | if self.with_dcn and not fallback_on_stride: 83 | assert conv_cfg is None, 'conv_cfg must be None for DCN' 84 | conv_cfg = dcn 85 | self.conv = build_conv_layer( 86 | conv_cfg, 87 | in_channels, 88 | channels * radix, 89 | kernel_size, 90 | stride=stride, 91 | padding=padding, 92 | dilation=dilation, 93 | groups=groups * radix, 94 | bias=False) 95 | self.norm0_name, norm0 = build_norm_layer( 96 | norm_cfg, channels * radix, postfix=0) 97 | self.add_module(self.norm0_name, norm0) 98 | self.relu = nn.ReLU(inplace=True) 99 | self.fc1 = build_conv_layer( 100 | None, channels, inter_channels, 1, groups=self.groups) 101 | self.norm1_name, norm1 = build_norm_layer( 102 | norm_cfg, inter_channels, postfix=1) 103 | self.add_module(self.norm1_name, norm1) 104 | self.fc2 = build_conv_layer( 105 | None, inter_channels, channels * radix, 1, groups=self.groups) 106 | self.rsoftmax = RSoftmax(radix, groups) 107 | 108 | @property 109 | def norm0(self): 110 | """nn.Module: the normalization layer named "norm0" """ 111 | return getattr(self, self.norm0_name) 112 | 113 | @property 114 | def norm1(self): 115 | """nn.Module: the normalization layer named "norm1" """ 116 | return getattr(self, self.norm1_name) 117 | 118 | def forward(self, x): 119 | x = self.conv(x) 120 | x = self.norm0(x) 121 | x = self.relu(x) 122 | 123 | batch, rchannel = x.shape[:2] 124 | batch = x.size(0) 125 | if self.radix > 1: 126 | splits = x.view(batch, self.radix, -1, *x.shape[2:]) 127 | gap = splits.sum(dim=1) 128 | else: 129 | gap = x 130 | gap = F.adaptive_avg_pool2d(gap, 1) 131 | gap = self.fc1(gap) 132 | 133 | gap = self.norm1(gap) 134 | gap = self.relu(gap) 135 | 136 | atten = self.fc2(gap) 137 | atten = self.rsoftmax(atten).view(batch, -1, 1, 1) 138 | 139 | if self.radix > 1: 140 | attens = atten.view(batch, self.radix, -1, *atten.shape[2:]) 141 | out = torch.sum(attens * splits, dim=1) 142 | else: 143 | out = atten * x 144 | return out.contiguous() 145 | 146 | 147 | class Bottleneck(_Bottleneck): 148 | """Bottleneck block for ResNeSt. 149 | 150 | Args: 151 | inplane (int): Input planes of this block. 152 | planes (int): Middle planes of this block. 153 | groups (int): Groups of conv2. 154 | width_per_group (int): Width per group of conv2. 64x4d indicates 155 | ``groups=64, width_per_group=4`` and 32x8d indicates 156 | ``groups=32, width_per_group=8``. 157 | radix (int): Radix of SpltAtConv2d. 
Default: 2 158 | reduction_factor (int): Reduction factor of inter_channels in 159 | SplitAttentionConv2d. Default: 4. 160 | avg_down_stride (bool): Whether to use average pool for stride in 161 | Bottleneck. Default: True. 162 | kwargs (dict): Key word arguments for base class. 163 | """ 164 | expansion = 4 165 | 166 | def __init__(self, 167 | inplanes, 168 | planes, 169 | groups=1, 170 | base_width=4, 171 | base_channels=64, 172 | radix=2, 173 | reduction_factor=4, 174 | avg_down_stride=True, 175 | **kwargs): 176 | """Bottleneck block for ResNeSt.""" 177 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 178 | 179 | if groups == 1: 180 | width = self.planes 181 | else: 182 | width = math.floor(self.planes * 183 | (base_width / base_channels)) * groups 184 | 185 | self.avg_down_stride = avg_down_stride and self.conv2_stride > 1 186 | 187 | self.norm1_name, norm1 = build_norm_layer( 188 | self.norm_cfg, width, postfix=1) 189 | self.norm3_name, norm3 = build_norm_layer( 190 | self.norm_cfg, self.planes * self.expansion, postfix=3) 191 | 192 | self.conv1 = build_conv_layer( 193 | self.conv_cfg, 194 | self.inplanes, 195 | width, 196 | kernel_size=1, 197 | stride=self.conv1_stride, 198 | bias=False) 199 | self.add_module(self.norm1_name, norm1) 200 | self.with_modulated_dcn = False 201 | self.conv2 = SplitAttentionConv2d( 202 | width, 203 | width, 204 | kernel_size=3, 205 | stride=1 if self.avg_down_stride else self.conv2_stride, 206 | padding=self.dilation, 207 | dilation=self.dilation, 208 | groups=groups, 209 | radix=radix, 210 | reduction_factor=reduction_factor, 211 | conv_cfg=self.conv_cfg, 212 | norm_cfg=self.norm_cfg, 213 | dcn=self.dcn) 214 | delattr(self, self.norm2_name) 215 | 216 | if self.avg_down_stride: 217 | self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1) 218 | 219 | self.conv3 = build_conv_layer( 220 | self.conv_cfg, 221 | width, 222 | self.planes * self.expansion, 223 | kernel_size=1, 224 | bias=False) 225 | self.add_module(self.norm3_name, norm3) 226 | 227 | def forward(self, x): 228 | 229 | def _inner_forward(x): 230 | identity = x 231 | 232 | out = self.conv1(x) 233 | out = self.norm1(out) 234 | out = self.relu(out) 235 | 236 | if self.with_plugins: 237 | out = self.forward_plugin(out, self.after_conv1_plugin_names) 238 | 239 | out = self.conv2(out) 240 | 241 | if self.avg_down_stride: 242 | out = self.avd_layer(out) 243 | 244 | if self.with_plugins: 245 | out = self.forward_plugin(out, self.after_conv2_plugin_names) 246 | 247 | out = self.conv3(out) 248 | out = self.norm3(out) 249 | 250 | if self.with_plugins: 251 | out = self.forward_plugin(out, self.after_conv3_plugin_names) 252 | 253 | if self.downsample is not None: 254 | identity = self.downsample(x) 255 | 256 | out += identity 257 | 258 | return out 259 | 260 | if self.with_cp and x.requires_grad: 261 | out = cp.checkpoint(_inner_forward, x) 262 | else: 263 | out = _inner_forward(x) 264 | 265 | out = self.relu(out) 266 | 267 | return out 268 | 269 | 270 | @BACKBONES.register_module() 271 | class ResNeSt(ResNetV1d): 272 | """ResNeSt backbone. 273 | 274 | This backbone is the implementation of `ResNeSt: 275 | Split-Attention Networks `_. 276 | 277 | Args: 278 | groups (int): Number of groups of Bottleneck. Default: 1 279 | base_width (int): Base width of Bottleneck. Default: 4 280 | radix (int): Radix of SpltAtConv2d. Default: 2 281 | reduction_factor (int): Reduction factor of inter_channels in 282 | SplitAttentionConv2d. Default: 4. 
283 | avg_down_stride (bool): Whether to use average pool for stride in 284 | Bottleneck. Default: True. 285 | kwargs (dict): Keyword arguments for ResNet. 286 | """ 287 | 288 | arch_settings = { 289 | 50: (Bottleneck, (3, 4, 6, 3)), 290 | 101: (Bottleneck, (3, 4, 23, 3)), 291 | 152: (Bottleneck, (3, 8, 36, 3)), 292 | 200: (Bottleneck, (3, 24, 36, 3)) 293 | } 294 | 295 | def __init__(self, 296 | groups=1, 297 | base_width=4, 298 | radix=2, 299 | reduction_factor=4, 300 | avg_down_stride=True, 301 | **kwargs): 302 | self.groups = groups 303 | self.base_width = base_width 304 | self.radix = radix 305 | self.reduction_factor = reduction_factor 306 | self.avg_down_stride = avg_down_stride 307 | super(ResNeSt, self).__init__(**kwargs) 308 | 309 | def make_res_layer(self, **kwargs): 310 | """Pack all blocks in a stage into a ``ResLayer``.""" 311 | return ResLayer( 312 | groups=self.groups, 313 | base_width=self.base_width, 314 | base_channels=self.base_channels, 315 | radix=self.radix, 316 | reduction_factor=self.reduction_factor, 317 | avg_down_stride=self.avg_down_stride, 318 | **kwargs) 319 | -------------------------------------------------------------------------------- /segmentation/backbones/resnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | from mmcv.cnn import build_conv_layer, build_norm_layer 5 | 6 | from ..builder import BACKBONES 7 | from ..utils import ResLayer 8 | from .resnet import Bottleneck as _Bottleneck 9 | from .resnet import ResNet 10 | 11 | 12 | class Bottleneck(_Bottleneck): 13 | """Bottleneck block for ResNeXt. 14 | 15 | If style is "pytorch", the stride-two layer is the 3x3 conv layer, if it is 16 | "caffe", the stride-two layer is the first 1x1 conv layer. 
17 | """ 18 | 19 | def __init__(self, 20 | inplanes, 21 | planes, 22 | groups=1, 23 | base_width=4, 24 | base_channels=64, 25 | **kwargs): 26 | super(Bottleneck, self).__init__(inplanes, planes, **kwargs) 27 | 28 | if groups == 1: 29 | width = self.planes 30 | else: 31 | width = math.floor(self.planes * 32 | (base_width / base_channels)) * groups 33 | 34 | self.norm1_name, norm1 = build_norm_layer( 35 | self.norm_cfg, width, postfix=1) 36 | self.norm2_name, norm2 = build_norm_layer( 37 | self.norm_cfg, width, postfix=2) 38 | self.norm3_name, norm3 = build_norm_layer( 39 | self.norm_cfg, self.planes * self.expansion, postfix=3) 40 | 41 | self.conv1 = build_conv_layer( 42 | self.conv_cfg, 43 | self.inplanes, 44 | width, 45 | kernel_size=1, 46 | stride=self.conv1_stride, 47 | bias=False) 48 | self.add_module(self.norm1_name, norm1) 49 | fallback_on_stride = False 50 | self.with_modulated_dcn = False 51 | if self.with_dcn: 52 | fallback_on_stride = self.dcn.pop('fallback_on_stride', False) 53 | if not self.with_dcn or fallback_on_stride: 54 | self.conv2 = build_conv_layer( 55 | self.conv_cfg, 56 | width, 57 | width, 58 | kernel_size=3, 59 | stride=self.conv2_stride, 60 | padding=self.dilation, 61 | dilation=self.dilation, 62 | groups=groups, 63 | bias=False) 64 | else: 65 | assert self.conv_cfg is None, 'conv_cfg must be None for DCN' 66 | self.conv2 = build_conv_layer( 67 | self.dcn, 68 | width, 69 | width, 70 | kernel_size=3, 71 | stride=self.conv2_stride, 72 | padding=self.dilation, 73 | dilation=self.dilation, 74 | groups=groups, 75 | bias=False) 76 | 77 | self.add_module(self.norm2_name, norm2) 78 | self.conv3 = build_conv_layer( 79 | self.conv_cfg, 80 | width, 81 | self.planes * self.expansion, 82 | kernel_size=1, 83 | bias=False) 84 | self.add_module(self.norm3_name, norm3) 85 | 86 | 87 | @BACKBONES.register_module() 88 | class ResNeXt(ResNet): 89 | """ResNeXt backbone. 90 | 91 | This backbone is the implementation of `Aggregated 92 | Residual Transformations for Deep Neural 93 | Networks `_. 94 | 95 | Args: 96 | depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. 97 | in_channels (int): Number of input image channels. Normally 3. 98 | num_stages (int): Resnet stages, normally 4. 99 | groups (int): Group of resnext. 100 | base_width (int): Base width of resnext. 101 | strides (Sequence[int]): Strides of the first block of each stage. 102 | dilations (Sequence[int]): Dilation of each stage. 103 | out_indices (Sequence[int]): Output from which stages. 104 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 105 | layer is the 3x3 conv layer, otherwise the stride-two layer is 106 | the first 1x1 conv layer. 107 | frozen_stages (int): Stages to be frozen (all param fixed). -1 means 108 | not freezing any parameters. 109 | norm_cfg (dict): dictionary to construct and config norm layer. 110 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 111 | freeze running stats (mean and var). Note: Effect on Batch Norm 112 | and its variants only. 113 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 114 | memory while slowing down the training speed. 115 | zero_init_residual (bool): whether to use zero init for last norm layer 116 | in resblocks to let them behave as identity. 
117 | 118 | Example: 119 | >>> from mmseg.models import ResNeXt 120 | >>> import torch 121 | >>> self = ResNeXt(depth=50) 122 | >>> self.eval() 123 | >>> inputs = torch.rand(1, 3, 32, 32) 124 | >>> level_outputs = self.forward(inputs) 125 | >>> for level_out in level_outputs: 126 | ... print(tuple(level_out.shape)) 127 | (1, 256, 8, 8) 128 | (1, 512, 4, 4) 129 | (1, 1024, 2, 2) 130 | (1, 2048, 1, 1) 131 | """ 132 | 133 | arch_settings = { 134 | 50: (Bottleneck, (3, 4, 6, 3)), 135 | 101: (Bottleneck, (3, 4, 23, 3)), 136 | 152: (Bottleneck, (3, 8, 36, 3)) 137 | } 138 | 139 | def __init__(self, groups=1, base_width=4, **kwargs): 140 | self.groups = groups 141 | self.base_width = base_width 142 | super(ResNeXt, self).__init__(**kwargs) 143 | 144 | def make_res_layer(self, **kwargs): 145 | """Pack all blocks in a stage into a ``ResLayer``""" 146 | return ResLayer( 147 | groups=self.groups, 148 | base_width=self.base_width, 149 | base_channels=self.base_channels, 150 | **kwargs) 151 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = '/data/ADEChallengeData2016' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (512, 512) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 512), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=4, 36 | workers_per_gpu=4, 37 | train=dict( 38 | type=dataset_type, 39 | data_root=data_root, 40 | img_dir='images/training', 41 | ann_dir='annotations/training', 42 | pipeline=train_pipeline), 43 | val=dict( 44 | type=dataset_type, 45 | data_root=data_root, 46 | img_dir='images/validation', 47 | ann_dir='annotations/validation', 48 | pipeline=test_pipeline), 49 | test=dict( 50 | type=dataset_type, 51 | data_root=data_root, 52 | img_dir='images/validation', 53 | ann_dir='annotations/validation', 54 | pipeline=test_pipeline)) 55 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | 
cudnn_benchmark = True 15 | -------------------------------------------------------------------------------- /segmentation/configs/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU', pre_eval=True) 10 | -------------------------------------------------------------------------------- /segmentation/configs/semfpn_vit-s16_512x512_40k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/datasets/ade20k.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py' 4 | ] 5 | 6 | # model settings 7 | norm_cfg = dict(type='SyncBN', requires_grad=True) 8 | model = dict( 9 | type='EncoderDecoder', 10 | pretrained='pretrain/jx_vit_base_p16_224-80ecf9dd.pth', 11 | backbone=dict( 12 | type='SelfPatch_ViT', 13 | img_size=(512, 512), 14 | patch_size=16, 15 | in_channels=3, 16 | embed_dims=384, 17 | num_layers=12, 18 | num_heads=6, 19 | mlp_ratio=4, 20 | out_indices=(2, 5, 8, 11), 21 | qkv_bias=True, 22 | drop_rate=0.0, 23 | attn_drop_rate=0.0, 24 | drop_path_rate=0.1, 25 | with_cls_token=True, 26 | norm_cfg=dict(type='LN', eps=1e-6), 27 | act_cfg=dict(type='GELU'), 28 | norm_eval=False, 29 | interpolate_mode='bicubic'), 30 | neck=dict( 31 | type='FPN', 32 | in_channels=[384, 384, 384, 384], 33 | out_channels=256, 34 | num_outs=4), 35 | decode_head=dict( 36 | type='FPNHead', 37 | in_channels=[256, 256, 256, 256], 38 | in_index=[0, 1, 2, 3], 39 | feature_strides=[4, 8, 16, 32], 40 | channels=128, 41 | dropout_ratio=0.1, 42 | num_classes=150, 43 | norm_cfg=norm_cfg, 44 | align_corners=False, 45 | loss_decode=dict( 46 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 47 | # model training and testing settings 48 | train_cfg=dict(), 49 | test_cfg=dict(mode='whole')) 50 | 51 | 52 | # AdamW optimizer, no weight decay for position embedding & layer norm 53 | # in backbone 54 | optimizer = dict( 55 | _delete_=True, 56 | type='AdamW', 57 | lr=0.00006, 58 | betas=(0.9, 0.999), 59 | weight_decay=0.01, 60 | paramwise_cfg=dict( 61 | custom_keys={ 62 | 'pos_embed': dict(decay_mult=0.), 63 | 'cls_token': dict(decay_mult=0.), 64 | 'norm': dict(decay_mult=0.) 65 | })) 66 | 67 | lr_config = dict( 68 | _delete_=True, 69 | policy='poly', 70 | warmup='linear', 71 | warmup_iters=1500, 72 | warmup_ratio=1e-6, 73 | power=1.0, 74 | min_lr=0.0, 75 | by_epoch=False) 76 | 77 | # By default, models are trained on 8 GPUs with 2 images per GPU 78 | data = dict(samples_per_gpu=4) 79 | -------------------------------------------------------------------------------- /segmentation/tools/analyze_logs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | """Modified from https://github.com/open- 3 | mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" 4 | import argparse 5 | import json 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def plot_curve(log_dicts, args): 13 | if args.backend is not None: 14 | plt.switch_backend(args.backend) 15 | sns.set_style(args.style) 16 | # if legend is None, use {filename}_{key} as legend 17 | legend = args.legend 18 | if legend is None: 19 | legend = [] 20 | for json_log in args.json_logs: 21 | for metric in args.keys: 22 | legend.append(f'{json_log}_{metric}') 23 | assert len(legend) == (len(args.json_logs) * len(args.keys)) 24 | metrics = args.keys 25 | 26 | num_metrics = len(metrics) 27 | for i, log_dict in enumerate(log_dicts): 28 | epochs = list(log_dict.keys()) 29 | for j, metric in enumerate(metrics): 30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}') 31 | plot_epochs = [] 32 | plot_iters = [] 33 | plot_values = [] 34 | # In some log files, iters number is not correct, `pre_iter` is 35 | # used to prevent generate wrong lines. 36 | pre_iter = -1 37 | for epoch in epochs: 38 | epoch_logs = log_dict[epoch] 39 | if metric not in epoch_logs.keys(): 40 | continue 41 | if metric in ['mIoU', 'mAcc', 'aAcc']: 42 | plot_epochs.append(epoch) 43 | plot_values.append(epoch_logs[metric][0]) 44 | else: 45 | for idx in range(len(epoch_logs[metric])): 46 | if pre_iter > epoch_logs['iter'][idx]: 47 | continue 48 | pre_iter = epoch_logs['iter'][idx] 49 | plot_iters.append(epoch_logs['iter'][idx]) 50 | plot_values.append(epoch_logs[metric][idx]) 51 | ax = plt.gca() 52 | label = legend[i * num_metrics + j] 53 | if metric in ['mIoU', 'mAcc', 'aAcc']: 54 | ax.set_xticks(plot_epochs) 55 | plt.xlabel('epoch') 56 | plt.plot(plot_epochs, plot_values, label=label, marker='o') 57 | else: 58 | plt.xlabel('iter') 59 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) 60 | plt.legend() 61 | if args.title is not None: 62 | plt.title(args.title) 63 | if args.out is None: 64 | plt.show() 65 | else: 66 | print(f'save curve to: {args.out}') 67 | plt.savefig(args.out) 68 | plt.cla() 69 | 70 | 71 | def parse_args(): 72 | parser = argparse.ArgumentParser(description='Analyze Json Log') 73 | parser.add_argument( 74 | 'json_logs', 75 | type=str, 76 | nargs='+', 77 | help='path of train log in json format') 78 | parser.add_argument( 79 | '--keys', 80 | type=str, 81 | nargs='+', 82 | default=['mIoU'], 83 | help='the metric that you want to plot') 84 | parser.add_argument('--title', type=str, help='title of figure') 85 | parser.add_argument( 86 | '--legend', 87 | type=str, 88 | nargs='+', 89 | default=None, 90 | help='legend of each plot') 91 | parser.add_argument( 92 | '--backend', type=str, default=None, help='backend of plt') 93 | parser.add_argument( 94 | '--style', type=str, default='dark', help='style of plt') 95 | parser.add_argument('--out', type=str, default=None) 96 | args = parser.parse_args() 97 | return args 98 | 99 | 100 | def load_json_logs(json_logs): 101 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict 102 | # keys of sub dict is different metrics 103 | # value of sub dict is a list of corresponding values of all iterations 104 | log_dicts = [dict() for _ in json_logs] 105 | for json_log, log_dict in zip(json_logs, log_dicts): 106 | with open(json_log, 'r') as log_file: 107 | for line in log_file: 108 | log = json.loads(line.strip()) 109 | # skip lines without `epoch` 
field 110 | if 'epoch' not in log: 111 | continue 112 | epoch = log.pop('epoch') 113 | if epoch not in log_dict: 114 | log_dict[epoch] = defaultdict(list) 115 | for k, v in log.items(): 116 | log_dict[epoch][k].append(v) 117 | return log_dicts 118 | 119 | 120 | def main(): 121 | args = parse_args() 122 | json_logs = args.json_logs 123 | for json_log in json_logs: 124 | assert json_log.endswith('.json') 125 | log_dicts = load_json_logs(json_logs) 126 | plot_curve(log_dicts, args) 127 | 128 | 129 | if __name__ == '__main__': 130 | main() 131 | -------------------------------------------------------------------------------- /segmentation/tools/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import time 4 | 5 | import torch 6 | from mmcv import Config 7 | from mmcv.parallel import MMDataParallel 8 | from mmcv.runner import load_checkpoint, wrap_fp16_model 9 | 10 | from mmseg.datasets import build_dataloader, build_dataset 11 | from mmseg.models import build_segmentor 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='MMSeg benchmark a model') 16 | parser.add_argument('config', help='test config file path') 17 | parser.add_argument('checkpoint', help='checkpoint file') 18 | parser.add_argument( 19 | '--log-interval', type=int, default=50, help='interval of logging') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | 27 | cfg = Config.fromfile(args.config) 28 | # set cudnn_benchmark 29 | torch.backends.cudnn.benchmark = False 30 | cfg.model.pretrained = None 31 | cfg.data.test.test_mode = True 32 | 33 | # build the dataloader 34 | # TODO: support multiple images per gpu (only minor changes are needed) 35 | dataset = build_dataset(cfg.data.test) 36 | data_loader = build_dataloader( 37 | dataset, 38 | samples_per_gpu=1, 39 | workers_per_gpu=cfg.data.workers_per_gpu, 40 | dist=False, 41 | shuffle=False) 42 | 43 | # build the model and load checkpoint 44 | cfg.model.train_cfg = None 45 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) 46 | fp16_cfg = cfg.get('fp16', None) 47 | if fp16_cfg is not None: 48 | wrap_fp16_model(model) 49 | load_checkpoint(model, args.checkpoint, map_location='cpu') 50 | 51 | model = MMDataParallel(model, device_ids=[0]) 52 | 53 | model.eval() 54 | 55 | # the first several iterations may be very slow so skip them 56 | num_warmup = 5 57 | pure_inf_time = 0 58 | total_iters = 200 59 | 60 | # benchmark with 200 image and take the average 61 | for i, data in enumerate(data_loader): 62 | 63 | torch.cuda.synchronize() 64 | start_time = time.perf_counter() 65 | 66 | with torch.no_grad(): 67 | model(return_loss=False, rescale=True, **data) 68 | 69 | torch.cuda.synchronize() 70 | elapsed = time.perf_counter() - start_time 71 | 72 | if i >= num_warmup: 73 | pure_inf_time += elapsed 74 | if (i + 1) % args.log_interval == 0: 75 | fps = (i + 1 - num_warmup) / pure_inf_time 76 | print(f'Done image [{i + 1:<3}/ {total_iters}], ' 77 | f'fps: {fps:.2f} img / s') 78 | 79 | if (i + 1) == total_iters: 80 | fps = (i + 1 - num_warmup) / pure_inf_time 81 | print(f'Overall fps: {fps:.2f} img / s') 82 | break 83 | 84 | 85 | if __name__ == '__main__': 86 | main() 87 | -------------------------------------------------------------------------------- /segmentation/tools/browse_dataset.py: -------------------------------------------------------------------------------- 1 | 
import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | 6 | import mmcv 7 | import numpy as np 8 | from mmcv import Config 9 | 10 | from mmseg.datasets.builder import build_dataset 11 | 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description='Browse a dataset') 15 | parser.add_argument('config', help='train config file path') 16 | parser.add_argument( 17 | '--show-origin', 18 | default=False, 19 | action='store_true', 20 | help='if True, omit all augmentation in pipeline,' 21 | ' show origin image and seg map') 22 | parser.add_argument( 23 | '--skip-type', 24 | type=str, 25 | nargs='+', 26 | default=['DefaultFormatBundle', 'Normalize', 'Collect'], 27 | help='skip some useless pipeline,if `show-origin` is true, ' 28 | 'all pipeline except `Load` will be skipped') 29 | parser.add_argument( 30 | '--output-dir', 31 | default='./output', 32 | type=str, 33 | help='If there is no display interface, you can save it') 34 | parser.add_argument('--show', default=False, action='store_true') 35 | parser.add_argument( 36 | '--show-interval', 37 | type=int, 38 | default=999, 39 | help='the interval of show (ms)') 40 | parser.add_argument( 41 | '--opacity', 42 | type=float, 43 | default=0.5, 44 | help='the opacity of semantic map') 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def imshow_semantic(img, 50 | seg, 51 | class_names, 52 | palette=None, 53 | win_name='', 54 | show=False, 55 | wait_time=0, 56 | out_file=None, 57 | opacity=0.5): 58 | """Draw `result` over `img`. 59 | 60 | Args: 61 | img (str or Tensor): The image to be displayed. 62 | seg (Tensor): The semantic segmentation results to draw over 63 | `img`. 64 | class_names (list[str]): Names of each classes. 65 | palette (list[list[int]]] | np.ndarray | None): The palette of 66 | segmentation map. If None is given, random palette will be 67 | generated. Default: None 68 | win_name (str): The window name. 69 | wait_time (int): Value of waitKey param. 70 | Default: 0. 71 | show (bool): Whether to show the image. 72 | Default: False. 73 | out_file (str or None): The filename to write the image. 74 | Default: None. 75 | opacity(float): Opacity of painted segmentation map. 76 | Default 0.5. 77 | Must be in (0, 1] range. 
78 | Returns: 79 | img (Tensor): Only if not `show` or `out_file` 80 | """ 81 | img = mmcv.imread(img) 82 | img = img.copy() 83 | if palette is None: 84 | palette = np.random.randint(0, 255, size=(len(class_names), 3)) 85 | palette = np.array(palette) 86 | assert palette.shape[0] == len(class_names) 87 | assert palette.shape[1] == 3 88 | assert len(palette.shape) == 2 89 | assert 0 < opacity <= 1.0 90 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 91 | for label, color in enumerate(palette): 92 | color_seg[seg == label, :] = color 93 | # convert to BGR 94 | color_seg = color_seg[..., ::-1] 95 | 96 | img = img * (1 - opacity) + color_seg * opacity 97 | img = img.astype(np.uint8) 98 | # if out_file specified, do not show image in window 99 | if out_file is not None: 100 | show = False 101 | 102 | if show: 103 | mmcv.imshow(img, win_name, wait_time) 104 | if out_file is not None: 105 | mmcv.imwrite(img, out_file) 106 | 107 | if not (show or out_file): 108 | warnings.warn('show==False and out_file is not specified, only ' 109 | 'result image will be returned') 110 | return img 111 | 112 | 113 | def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): 114 | if show_origin is True: 115 | # only keep pipeline of Loading data and ann 116 | _data_cfg['pipeline'] = [ 117 | x for x in _data_cfg.pipeline if 'Load' in x['type'] 118 | ] 119 | else: 120 | _data_cfg['pipeline'] = [ 121 | x for x in _data_cfg.pipeline if x['type'] not in skip_type 122 | ] 123 | 124 | 125 | def retrieve_data_cfg(config_path, skip_type, show_origin=False): 126 | cfg = Config.fromfile(config_path) 127 | train_data_cfg = cfg.data.train 128 | if isinstance(train_data_cfg, list): 129 | for _data_cfg in train_data_cfg: 130 | if 'pipeline' in _data_cfg: 131 | _retrieve_data_cfg(_data_cfg, skip_type, show_origin) 132 | elif 'dataset' in _data_cfg: 133 | _retrieve_data_cfg(_data_cfg['dataset'], skip_type, 134 | show_origin) 135 | else: 136 | raise ValueError 137 | elif 'dataset' in train_data_cfg: 138 | _retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin) 139 | else: 140 | _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) 141 | return cfg 142 | 143 | 144 | def main(): 145 | args = parse_args() 146 | cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin) 147 | dataset = build_dataset(cfg.data.train) 148 | progress_bar = mmcv.ProgressBar(len(dataset)) 149 | for item in dataset: 150 | filename = os.path.join(args.output_dir, 151 | Path(item['filename']).name 152 | ) if args.output_dir is not None else None 153 | imshow_semantic( 154 | item['img'], 155 | item['gt_semantic_seg'], 156 | dataset.CLASSES, 157 | dataset.PALETTE, 158 | show=args.show, 159 | wait_time=args.show_interval, 160 | out_file=filename, 161 | opacity=args.opacity, 162 | ) 163 | progress_bar.update() 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/chase_db1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | CHASE_DB1_LEN = 28 * 3 11 | TRAINING_LEN = 60 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert CHASE_DB1 dataset to mmsegmentation format') 17 | parser.add_argument('dataset_path', help='path of CHASEDB1.zip') 18 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 19 | parser.add_argument('-o', '--out_dir', help='output path') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | dataset_path = args.dataset_path 27 | if args.out_dir is None: 28 | out_dir = osp.join('data', 'CHASE_DB1') 29 | else: 30 | out_dir = args.out_dir 31 | 32 | print('Making directories...') 33 | mmcv.mkdir_or_exist(out_dir) 34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 40 | 41 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 42 | print('Extracting CHASEDB1.zip...') 43 | zip_file = zipfile.ZipFile(dataset_path) 44 | zip_file.extractall(tmp_dir) 45 | 46 | print('Generating training dataset...') 47 | 48 | assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ 49 | 'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN) 50 | 51 | for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 52 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 53 | if osp.splitext(img_name)[1] == '.jpg': 54 | mmcv.imwrite( 55 | img, 56 | osp.join(out_dir, 'images', 'training', 57 | osp.splitext(img_name)[0] + '.png')) 58 | else: 59 | # The annotation img should be divided by 128, because some of 60 | # the annotation imgs are not standard. We should set a 61 | # threshold to convert the nonstandard annotation imgs. The 62 | # value divided by 128 is equivalent to '1 if value >= 128 63 | # else 0' 64 | mmcv.imwrite( 65 | img[:, :, 0] // 128, 66 | osp.join(out_dir, 'annotations', 'training', 67 | osp.splitext(img_name)[0] + '.png')) 68 | 69 | for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 70 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 71 | if osp.splitext(img_name)[1] == '.jpg': 72 | mmcv.imwrite( 73 | img, 74 | osp.join(out_dir, 'images', 'validation', 75 | osp.splitext(img_name)[0] + '.png')) 76 | else: 77 | mmcv.imwrite( 78 | img[:, :, 0] // 128, 79 | osp.join(out_dir, 'annotations', 'validation', 80 | osp.splitext(img_name)[0] + '.png')) 81 | 82 | print('Removing the temporary files...') 83 | 84 | print('Done!') 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 7 | 8 | 9 | def convert_json_to_label(json_file): 10 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 11 | json2labelImg(json_file, label_file, 'trainIds') 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert Cityscapes annotations to TrainIds') 17 | parser.add_argument('cityscapes_path', help='cityscapes data path') 18 | parser.add_argument('--gt-dir', default='gtFine', type=str) 19 | parser.add_argument('-o', '--out-dir', help='output path') 20 | parser.add_argument( 21 | '--nproc', default=1, type=int, help='number of process') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def main(): 27 | args = parse_args() 28 | cityscapes_path = args.cityscapes_path 29 | out_dir = args.out_dir if args.out_dir else cityscapes_path 30 | mmcv.mkdir_or_exist(out_dir) 31 | 32 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 33 | 34 | poly_files = [] 35 | for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True): 36 | poly_file = osp.join(gt_dir, poly) 37 | poly_files.append(poly_file) 38 | if args.nproc > 1: 39 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, 40 | args.nproc) 41 | else: 42 | mmcv.track_progress(convert_json_to_label, poly_files) 43 | 44 | split_names = ['train', 'val', 'test'] 45 | 46 | for split in split_names: 47 | filenames = [] 48 | for poly in mmcv.scandir( 49 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 50 | filenames.append(poly.replace('_gtFine_polygons.json', '')) 51 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 52 | f.writelines(f + '\n' for f in filenames) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/coco_stuff10k.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import shutil 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | 11 | COCO_LEN = 10000 12 | 13 | clsID_to_trID = { 14 | 0: 0, 15 | 1: 1, 16 | 2: 2, 17 | 3: 3, 18 | 4: 4, 19 | 5: 5, 20 | 6: 6, 21 | 7: 7, 22 | 8: 8, 23 | 9: 9, 24 | 10: 10, 25 | 11: 11, 26 | 13: 12, 27 | 14: 13, 28 | 15: 14, 29 | 16: 15, 30 | 17: 16, 31 | 18: 17, 32 | 19: 18, 33 | 20: 19, 34 | 21: 20, 35 | 22: 21, 36 | 23: 22, 37 | 24: 23, 38 | 25: 24, 39 | 27: 25, 40 | 28: 26, 41 | 31: 27, 42 | 32: 28, 43 | 33: 29, 44 | 34: 30, 45 | 35: 31, 46 | 36: 32, 47 | 37: 33, 48 | 38: 34, 49 | 39: 35, 50 | 40: 36, 51 | 41: 37, 52 | 42: 38, 53 | 43: 39, 54 | 44: 40, 55 | 46: 41, 56 | 47: 42, 57 | 48: 43, 58 | 49: 44, 59 | 50: 45, 60 | 51: 46, 61 | 52: 47, 62 | 53: 48, 63 | 54: 49, 64 | 55: 50, 65 | 56: 51, 66 | 57: 52, 67 | 58: 53, 68 | 59: 54, 69 | 60: 55, 70 | 61: 56, 71 | 62: 57, 72 | 63: 58, 73 | 64: 59, 74 | 65: 60, 75 | 67: 61, 76 | 70: 62, 77 | 72: 63, 78 | 73: 64, 79 | 74: 65, 80 | 75: 66, 81 | 76: 67, 82 | 77: 68, 83 | 78: 69, 84 | 79: 70, 85 | 80: 71, 86 | 81: 72, 87 | 82: 73, 88 | 84: 74, 89 | 85: 75, 90 | 86: 76, 91 | 87: 77, 92 | 88: 78, 93 | 89: 79, 94 | 90: 80, 95 | 92: 81, 96 | 93: 82, 97 | 94: 83, 98 | 95: 84, 99 | 96: 85, 100 | 97: 86, 101 | 98: 87, 102 | 99: 88, 103 | 100: 89, 104 | 101: 90, 105 | 102: 91, 106 | 103: 92, 107 | 104: 93, 108 | 105: 94, 109 | 106: 95, 
110 | 107: 96, 111 | 108: 97, 112 | 109: 98, 113 | 110: 99, 114 | 111: 100, 115 | 112: 101, 116 | 113: 102, 117 | 114: 103, 118 | 115: 104, 119 | 116: 105, 120 | 117: 106, 121 | 118: 107, 122 | 119: 108, 123 | 120: 109, 124 | 121: 110, 125 | 122: 111, 126 | 123: 112, 127 | 124: 113, 128 | 125: 114, 129 | 126: 115, 130 | 127: 116, 131 | 128: 117, 132 | 129: 118, 133 | 130: 119, 134 | 131: 120, 135 | 132: 121, 136 | 133: 122, 137 | 134: 123, 138 | 135: 124, 139 | 136: 125, 140 | 137: 126, 141 | 138: 127, 142 | 139: 128, 143 | 140: 129, 144 | 141: 130, 145 | 142: 131, 146 | 143: 132, 147 | 144: 133, 148 | 145: 134, 149 | 146: 135, 150 | 147: 136, 151 | 148: 137, 152 | 149: 138, 153 | 150: 139, 154 | 151: 140, 155 | 152: 141, 156 | 153: 142, 157 | 154: 143, 158 | 155: 144, 159 | 156: 145, 160 | 157: 146, 161 | 158: 147, 162 | 159: 148, 163 | 160: 149, 164 | 161: 150, 165 | 162: 151, 166 | 163: 152, 167 | 164: 153, 168 | 165: 154, 169 | 166: 155, 170 | 167: 156, 171 | 168: 157, 172 | 169: 158, 173 | 170: 159, 174 | 171: 160, 175 | 172: 161, 176 | 173: 162, 177 | 174: 163, 178 | 175: 164, 179 | 176: 165, 180 | 177: 166, 181 | 178: 167, 182 | 179: 168, 183 | 180: 169, 184 | 181: 170, 185 | 182: 171 186 | } 187 | 188 | 189 | def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir, 190 | out_mask_dir, is_train): 191 | imgpath, maskpath = tuple_path 192 | shutil.copyfile( 193 | osp.join(in_img_dir, imgpath), 194 | osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join( 195 | out_img_dir, 'test2014', imgpath)) 196 | annotate = loadmat(osp.join(in_ann_dir, maskpath)) 197 | mask = annotate['S'].astype(np.uint8) 198 | mask_copy = mask.copy() 199 | for clsID, trID in clsID_to_trID.items(): 200 | mask_copy[mask == clsID] = trID 201 | seg_filename = osp.join(out_mask_dir, 'train2014', 202 | maskpath.split('.')[0] + 203 | '_labelTrainIds.png') if is_train else osp.join( 204 | out_mask_dir, 'test2014', 205 | maskpath.split('.')[0] + '_labelTrainIds.png') 206 | Image.fromarray(mask_copy).save(seg_filename, 'PNG') 207 | 208 | 209 | def generate_coco_list(folder): 210 | train_list = osp.join(folder, 'imageLists', 'train.txt') 211 | test_list = osp.join(folder, 'imageLists', 'test.txt') 212 | train_paths = [] 213 | test_paths = [] 214 | 215 | with open(train_list) as f: 216 | for filename in f: 217 | basename = filename.strip() 218 | imgpath = basename + '.jpg' 219 | maskpath = basename + '.mat' 220 | train_paths.append((imgpath, maskpath)) 221 | 222 | with open(test_list) as f: 223 | for filename in f: 224 | basename = filename.strip() 225 | imgpath = basename + '.jpg' 226 | maskpath = basename + '.mat' 227 | test_paths.append((imgpath, maskpath)) 228 | 229 | return train_paths, test_paths 230 | 231 | 232 | def parse_args(): 233 | parser = argparse.ArgumentParser( 234 | description=\ 235 | 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa 236 | parser.add_argument('coco_path', help='coco stuff path') 237 | parser.add_argument('-o', '--out_dir', help='output path') 238 | parser.add_argument( 239 | '--nproc', default=16, type=int, help='number of process') 240 | args = parser.parse_args() 241 | return args 242 | 243 | 244 | def main(): 245 | args = parse_args() 246 | coco_path = args.coco_path 247 | nproc = args.nproc 248 | 249 | out_dir = args.out_dir or coco_path 250 | out_img_dir = osp.join(out_dir, 'images') 251 | out_mask_dir = osp.join(out_dir, 'annotations') 252 | 253 | mmcv.mkdir_or_exist(osp.join(out_img_dir, 'train2014')) 254 | 
mmcv.mkdir_or_exist(osp.join(out_img_dir, 'test2014')) 255 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2014')) 256 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'test2014')) 257 | 258 | train_list, test_list = generate_coco_list(coco_path) 259 | assert (len(train_list) + 260 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 261 | len(train_list), len(test_list)) 262 | 263 | if args.nproc > 1: 264 | mmcv.track_parallel_progress( 265 | partial( 266 | convert_to_trainID, 267 | in_img_dir=osp.join(coco_path, 'images'), 268 | in_ann_dir=osp.join(coco_path, 'annotations'), 269 | out_img_dir=out_img_dir, 270 | out_mask_dir=out_mask_dir, 271 | is_train=True), 272 | train_list, 273 | nproc=nproc) 274 | mmcv.track_parallel_progress( 275 | partial( 276 | convert_to_trainID, 277 | in_img_dir=osp.join(coco_path, 'images'), 278 | in_ann_dir=osp.join(coco_path, 'annotations'), 279 | out_img_dir=out_img_dir, 280 | out_mask_dir=out_mask_dir, 281 | is_train=False), 282 | test_list, 283 | nproc=nproc) 284 | else: 285 | mmcv.track_progress( 286 | partial( 287 | convert_to_trainID, 288 | in_img_dir=osp.join(coco_path, 'images'), 289 | in_ann_dir=osp.join(coco_path, 'annotations'), 290 | out_img_dir=out_img_dir, 291 | out_mask_dir=out_mask_dir, 292 | is_train=True), train_list) 293 | mmcv.track_progress( 294 | partial( 295 | convert_to_trainID, 296 | in_img_dir=osp.join(coco_path, 'images'), 297 | in_ann_dir=osp.join(coco_path, 'annotations'), 298 | out_img_dir=out_img_dir, 299 | out_mask_dir=out_mask_dir, 300 | is_train=False), test_list) 301 | 302 | print('Done!') 303 | 304 | 305 | if __name__ == '__main__': 306 | main() 307 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/coco_stuff164k.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os.path as osp 3 | import shutil 4 | from functools import partial 5 | from glob import glob 6 | 7 | import mmcv 8 | import numpy as np 9 | from PIL import Image 10 | 11 | COCO_LEN = 123287 12 | 13 | clsID_to_trID = { 14 | 0: 0, 15 | 1: 1, 16 | 2: 2, 17 | 3: 3, 18 | 4: 4, 19 | 5: 5, 20 | 6: 6, 21 | 7: 7, 22 | 8: 8, 23 | 9: 9, 24 | 10: 10, 25 | 12: 11, 26 | 13: 12, 27 | 14: 13, 28 | 15: 14, 29 | 16: 15, 30 | 17: 16, 31 | 18: 17, 32 | 19: 18, 33 | 20: 19, 34 | 21: 20, 35 | 22: 21, 36 | 23: 22, 37 | 24: 23, 38 | 26: 24, 39 | 27: 25, 40 | 30: 26, 41 | 31: 27, 42 | 32: 28, 43 | 33: 29, 44 | 34: 30, 45 | 35: 31, 46 | 36: 32, 47 | 37: 33, 48 | 38: 34, 49 | 39: 35, 50 | 40: 36, 51 | 41: 37, 52 | 42: 38, 53 | 43: 39, 54 | 45: 40, 55 | 46: 41, 56 | 47: 42, 57 | 48: 43, 58 | 49: 44, 59 | 50: 45, 60 | 51: 46, 61 | 52: 47, 62 | 53: 48, 63 | 54: 49, 64 | 55: 50, 65 | 56: 51, 66 | 57: 52, 67 | 58: 53, 68 | 59: 54, 69 | 60: 55, 70 | 61: 56, 71 | 62: 57, 72 | 63: 58, 73 | 64: 59, 74 | 66: 60, 75 | 69: 61, 76 | 71: 62, 77 | 72: 63, 78 | 73: 64, 79 | 74: 65, 80 | 75: 66, 81 | 76: 67, 82 | 77: 68, 83 | 78: 69, 84 | 79: 70, 85 | 80: 71, 86 | 81: 72, 87 | 83: 73, 88 | 84: 74, 89 | 85: 75, 90 | 86: 76, 91 | 87: 77, 92 | 88: 78, 93 | 89: 79, 94 | 91: 80, 95 | 92: 81, 96 | 93: 82, 97 | 94: 83, 98 | 95: 84, 99 | 96: 85, 100 | 97: 86, 101 | 98: 87, 102 | 99: 88, 103 | 100: 89, 104 | 101: 90, 105 | 102: 91, 106 | 103: 92, 107 | 104: 93, 108 | 105: 94, 109 | 106: 95, 110 | 107: 96, 111 | 108: 97, 112 | 109: 98, 113 | 110: 99, 114 | 111: 100, 115 | 112: 101, 116 | 113: 102, 117 | 114: 103, 118 | 115: 104, 119 | 116: 105, 120 | 117: 106, 121 | 
118: 107, 122 | 119: 108, 123 | 120: 109, 124 | 121: 110, 125 | 122: 111, 126 | 123: 112, 127 | 124: 113, 128 | 125: 114, 129 | 126: 115, 130 | 127: 116, 131 | 128: 117, 132 | 129: 118, 133 | 130: 119, 134 | 131: 120, 135 | 132: 121, 136 | 133: 122, 137 | 134: 123, 138 | 135: 124, 139 | 136: 125, 140 | 137: 126, 141 | 138: 127, 142 | 139: 128, 143 | 140: 129, 144 | 141: 130, 145 | 142: 131, 146 | 143: 132, 147 | 144: 133, 148 | 145: 134, 149 | 146: 135, 150 | 147: 136, 151 | 148: 137, 152 | 149: 138, 153 | 150: 139, 154 | 151: 140, 155 | 152: 141, 156 | 153: 142, 157 | 154: 143, 158 | 155: 144, 159 | 156: 145, 160 | 157: 146, 161 | 158: 147, 162 | 159: 148, 163 | 160: 149, 164 | 161: 150, 165 | 162: 151, 166 | 163: 152, 167 | 164: 153, 168 | 165: 154, 169 | 166: 155, 170 | 167: 156, 171 | 168: 157, 172 | 169: 158, 173 | 170: 159, 174 | 171: 160, 175 | 172: 161, 176 | 173: 162, 177 | 174: 163, 178 | 175: 164, 179 | 176: 165, 180 | 177: 166, 181 | 178: 167, 182 | 179: 168, 183 | 180: 169, 184 | 181: 170, 185 | 255: 255 186 | } 187 | 188 | 189 | def convert_to_trainID(maskpath, out_mask_dir, is_train): 190 | mask = np.array(Image.open(maskpath)) 191 | mask_copy = mask.copy() 192 | for clsID, trID in clsID_to_trID.items(): 193 | mask_copy[mask == clsID] = trID 194 | seg_filename = osp.join( 195 | out_mask_dir, 'train2017', 196 | osp.basename(maskpath).split('.')[0] + 197 | '_labelTrainIds.png') if is_train else osp.join( 198 | out_mask_dir, 'val2017', 199 | osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png') 200 | Image.fromarray(mask_copy).save(seg_filename, 'PNG') 201 | 202 | 203 | def parse_args(): 204 | parser = argparse.ArgumentParser( 205 | description=\ 206 | 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa 207 | parser.add_argument('coco_path', help='coco stuff path') 208 | parser.add_argument('-o', '--out_dir', help='output path') 209 | parser.add_argument( 210 | '--nproc', default=16, type=int, help='number of process') 211 | args = parser.parse_args() 212 | return args 213 | 214 | 215 | def main(): 216 | args = parse_args() 217 | coco_path = args.coco_path 218 | nproc = args.nproc 219 | 220 | out_dir = args.out_dir or coco_path 221 | out_img_dir = osp.join(out_dir, 'images') 222 | out_mask_dir = osp.join(out_dir, 'annotations') 223 | 224 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) 225 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) 226 | 227 | if out_dir != coco_path: 228 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) 229 | 230 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) 231 | train_list = [file for file in train_list if '_labelTrainIds' not in file] 232 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) 233 | test_list = [file for file in test_list if '_labelTrainIds' not in file] 234 | assert (len(train_list) + 235 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 236 | len(train_list), len(test_list)) 237 | 238 | if args.nproc > 1: 239 | mmcv.track_parallel_progress( 240 | partial( 241 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 242 | train_list, 243 | nproc=nproc) 244 | mmcv.track_parallel_progress( 245 | partial( 246 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 247 | test_list, 248 | nproc=nproc) 249 | else: 250 | mmcv.track_progress( 251 | partial( 252 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 253 | train_list) 254 | mmcv.track_progress( 255 | partial( 256 | 
convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 257 | test_list) 258 | 259 | print('Done!') 260 | 261 | 262 | if __name__ == '__main__': 263 | main() 264 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import cv2 9 | import mmcv 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Convert DRIVE dataset to mmsegmentation format') 15 | parser.add_argument( 16 | 'training_path', help='the training part of DRIVE dataset') 17 | parser.add_argument( 18 | 'testing_path', help='the testing part of DRIVE dataset') 19 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 20 | parser.add_argument('-o', '--out_dir', help='output path') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | training_path = args.training_path 28 | testing_path = args.testing_path 29 | if args.out_dir is None: 30 | out_dir = osp.join('data', 'DRIVE') 31 | else: 32 | out_dir = args.out_dir 33 | 34 | print('Making directories...') 35 | mmcv.mkdir_or_exist(out_dir) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 40 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 41 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 42 | 43 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 44 | print('Extracting training.zip...') 45 | zip_file = zipfile.ZipFile(training_path) 46 | zip_file.extractall(tmp_dir) 47 | 48 | print('Generating training dataset...') 49 | now_dir = osp.join(tmp_dir, 'training', 'images') 50 | for img_name in os.listdir(now_dir): 51 | img = mmcv.imread(osp.join(now_dir, img_name)) 52 | mmcv.imwrite( 53 | img, 54 | osp.join( 55 | out_dir, 'images', 'training', 56 | osp.splitext(img_name)[0].replace('_training', '') + 57 | '.png')) 58 | 59 | now_dir = osp.join(tmp_dir, 'training', '1st_manual') 60 | for img_name in os.listdir(now_dir): 61 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 62 | ret, img = cap.read() 63 | mmcv.imwrite( 64 | img[:, :, 0] // 128, 65 | osp.join(out_dir, 'annotations', 'training', 66 | osp.splitext(img_name)[0] + '.png')) 67 | 68 | print('Extracting test.zip...') 69 | zip_file = zipfile.ZipFile(testing_path) 70 | zip_file.extractall(tmp_dir) 71 | 72 | print('Generating validation dataset...') 73 | now_dir = osp.join(tmp_dir, 'test', 'images') 74 | for img_name in os.listdir(now_dir): 75 | img = mmcv.imread(osp.join(now_dir, img_name)) 76 | mmcv.imwrite( 77 | img, 78 | osp.join( 79 | out_dir, 'images', 'validation', 80 | osp.splitext(img_name)[0].replace('_test', '') + '.png')) 81 | 82 | now_dir = osp.join(tmp_dir, 'test', '1st_manual') 83 | if osp.exists(now_dir): 84 | for img_name in os.listdir(now_dir): 85 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 86 | ret, img = cap.read() 87 | # The annotation img should be divided by 128, because some of 88 | # the annotation imgs are not standard. We should set a 89 | # threshold to convert the nonstandard annotation imgs. 
The 90 | # value divided by 128 is equivalent to '1 if value >= 128 91 | # else 0' 92 | mmcv.imwrite( 93 | img[:, :, 0] // 128, 94 | osp.join(out_dir, 'annotations', 'validation', 95 | osp.splitext(img_name)[0] + '.png')) 96 | 97 | now_dir = osp.join(tmp_dir, 'test', '2nd_manual') 98 | if osp.exists(now_dir): 99 | for img_name in os.listdir(now_dir): 100 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 101 | ret, img = cap.read() 102 | mmcv.imwrite( 103 | img[:, :, 0] // 128, 104 | osp.join(out_dir, 'annotations', 'validation', 105 | osp.splitext(img_name)[0] + '.png')) 106 | 107 | print('Removing the temporary files...') 108 | 109 | print('Done!') 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/hrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | HRF_LEN = 15 11 | TRAINING_LEN = 5 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert HRF dataset to mmsegmentation format') 17 | parser.add_argument('healthy_path', help='the path of healthy.zip') 18 | parser.add_argument( 19 | 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip') 20 | parser.add_argument('glaucoma_path', help='the path of glaucoma.zip') 21 | parser.add_argument( 22 | 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip') 23 | parser.add_argument( 24 | 'diabetic_retinopathy_path', 25 | help='the path of diabetic_retinopathy.zip') 26 | parser.add_argument( 27 | 'diabetic_retinopathy_manualsegm_path', 28 | help='the path of diabetic_retinopathy_manualsegm.zip') 29 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | def main(): 36 | args = parse_args() 37 | images_path = [ 38 | args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path 39 | ] 40 | annotations_path = [ 41 | args.healthy_manualsegm_path, args.glaucoma_manualsegm_path, 42 | args.diabetic_retinopathy_manualsegm_path 43 | ] 44 | if args.out_dir is None: 45 | out_dir = osp.join('data', 'HRF') 46 | else: 47 | out_dir = args.out_dir 48 | 49 | print('Making directories...') 50 | mmcv.mkdir_or_exist(out_dir) 51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 52 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 53 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 54 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 55 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 56 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 57 | 58 | print('Generating images...') 59 | for now_path in images_path: 60 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 61 | zip_file = zipfile.ZipFile(now_path) 62 | zip_file.extractall(tmp_dir) 63 | 64 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \ 65 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) 66 | 67 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 68 | img = mmcv.imread(osp.join(tmp_dir, filename)) 69 | mmcv.imwrite( 70 | img, 71 | osp.join(out_dir, 'images', 'training', 72 | osp.splitext(filename)[0] + '.png')) 73 | for filename in 
sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 74 | img = mmcv.imread(osp.join(tmp_dir, filename)) 75 | mmcv.imwrite( 76 | img, 77 | osp.join(out_dir, 'images', 'validation', 78 | osp.splitext(filename)[0] + '.png')) 79 | 80 | print('Generating annotations...') 81 | for now_path in annotations_path: 82 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 83 | zip_file = zipfile.ZipFile(now_path) 84 | zip_file.extractall(tmp_dir) 85 | 86 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \ 87 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) 88 | 89 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 90 | img = mmcv.imread(osp.join(tmp_dir, filename)) 91 | # The annotation img should be divided by 128, because some of 92 | # the annotation imgs are not standard. We should set a 93 | # threshold to convert the nonstandard annotation imgs. The 94 | # value divided by 128 is equivalent to '1 if value >= 128 95 | # else 0' 96 | mmcv.imwrite( 97 | img[:, :, 0] // 128, 98 | osp.join(out_dir, 'annotations', 'training', 99 | osp.splitext(filename)[0] + '.png')) 100 | for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 101 | img = mmcv.imread(osp.join(tmp_dir, filename)) 102 | mmcv.imwrite( 103 | img[:, :, 0] // 128, 104 | osp.join(out_dir, 'annotations', 'validation', 105 | osp.splitext(filename)[0] + '.png')) 106 | 107 | print('Done!') 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from detail import Detail 9 | from PIL import Image 10 | 11 | _mapping = np.sort( 12 | np.array([ 13 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, 14 | 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, 15 | 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, 16 | 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 17 | ])) 18 | _key = np.array(range(len(_mapping))).astype('uint8') 19 | 20 | 21 | def generate_labels(img_id, detail, out_dir): 22 | 23 | def _class_to_index(mask, _mapping, _key): 24 | # assert the values 25 | values = np.unique(mask) 26 | for i in range(len(values)): 27 | assert (values[i] in _mapping) 28 | index = np.digitize(mask.ravel(), _mapping, right=True) 29 | return _key[index].reshape(mask.shape) 30 | 31 | mask = Image.fromarray( 32 | _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) 33 | filename = img_id['file_name'] 34 | mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) 35 | return osp.splitext(osp.basename(filename))[0] 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser( 40 | description='Convert PASCAL VOC annotations to mmsegmentation format') 41 | parser.add_argument('devkit_path', help='pascal voc devkit path') 42 | parser.add_argument('json_path', help='annoation json filepath') 43 | parser.add_argument('-o', '--out_dir', help='output path') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | devkit_path = args.devkit_path 51 | if args.out_dir is None: 52 | out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') 53 | else: 54 | out_dir = args.out_dir 55 | json_path = 
args.json_path 56 | mmcv.mkdir_or_exist(out_dir) 57 | img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') 58 | 59 | train_detail = Detail(json_path, img_dir, 'train') 60 | train_ids = train_detail.getImgs() 61 | 62 | val_detail = Detail(json_path, img_dir, 'val') 63 | val_ids = val_detail.getImgs() 64 | 65 | mmcv.mkdir_or_exist( 66 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) 67 | 68 | train_list = mmcv.track_progress( 69 | partial(generate_labels, detail=train_detail, out_dir=out_dir), 70 | train_ids) 71 | with open( 72 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 73 | 'train.txt'), 'w') as f: 74 | f.writelines(line + '\n' for line in sorted(train_list)) 75 | 76 | val_list = mmcv.track_progress( 77 | partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) 78 | with open( 79 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 80 | 'val.txt'), 'w') as f: 81 | f.writelines(line + '\n' for line in sorted(val_list)) 82 | 83 | print('Done!') 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /segmentation/tools/convert_datasets/stare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import gzip 4 | import os 5 | import os.path as osp 6 | import tarfile 7 | import tempfile 8 | 9 | import mmcv 10 | 11 | STARE_LEN = 20 12 | TRAINING_LEN = 10 13 | 14 | 15 | def un_gz(src, dst): 16 | g_file = gzip.GzipFile(src) 17 | with open(dst, 'wb+') as f: 18 | f.write(g_file.read()) 19 | g_file.close() 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser( 24 | description='Convert STARE dataset to mmsegmentation format') 25 | parser.add_argument('image_path', help='the path of stare-images.tar') 26 | parser.add_argument('labels_ah', help='the path of labels-ah.tar') 27 | parser.add_argument('labels_vk', help='the path of labels-vk.tar') 28 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 29 | parser.add_argument('-o', '--out_dir', help='output path') 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | image_path = args.image_path 37 | labels_ah = args.labels_ah 38 | labels_vk = args.labels_vk 39 | if args.out_dir is None: 40 | out_dir = osp.join('data', 'STARE') 41 | else: 42 | out_dir = args.out_dir 43 | 44 | print('Making directories...') 45 | mmcv.mkdir_or_exist(out_dir) 46 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 47 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 48 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 49 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 50 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 52 | 53 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 54 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) 55 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) 56 | 57 | print('Extracting stare-images.tar...') 58 | with tarfile.open(image_path) as f: 59 | f.extractall(osp.join(tmp_dir, 'gz')) 60 | 61 | for filename in os.listdir(osp.join(tmp_dir, 'gz')): 62 | un_gz( 63 | osp.join(tmp_dir, 'gz', filename), 64 | osp.join(tmp_dir, 'files', 65 | osp.splitext(filename)[0])) 66 | 67 | now_dir = osp.join(tmp_dir, 'files') 68 | 69 | assert len(os.listdir(now_dir)) == STARE_LEN, \ 70 | 
'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) 71 | 72 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: 73 | img = mmcv.imread(osp.join(now_dir, filename)) 74 | mmcv.imwrite( 75 | img, 76 | osp.join(out_dir, 'images', 'training', 77 | osp.splitext(filename)[0] + '.png')) 78 | 79 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: 80 | img = mmcv.imread(osp.join(now_dir, filename)) 81 | mmcv.imwrite( 82 | img, 83 | osp.join(out_dir, 'images', 'validation', 84 | osp.splitext(filename)[0] + '.png')) 85 | 86 | print('Removing the temporary files...') 87 | 88 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 89 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) 90 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) 91 | 92 | print('Extracting labels-ah.tar...') 93 | with tarfile.open(labels_ah) as f: 94 | f.extractall(osp.join(tmp_dir, 'gz')) 95 | 96 | for filename in os.listdir(osp.join(tmp_dir, 'gz')): 97 | un_gz( 98 | osp.join(tmp_dir, 'gz', filename), 99 | osp.join(tmp_dir, 'files', 100 | osp.splitext(filename)[0])) 101 | 102 | now_dir = osp.join(tmp_dir, 'files') 103 | 104 | assert len(os.listdir(now_dir)) == STARE_LEN, \ 105 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) 106 | 107 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: 108 | img = mmcv.imread(osp.join(now_dir, filename)) 109 | # The annotation img should be divided by 128, because some of 110 | # the annotation imgs are not standard. We should set a threshold 111 | # to convert the nonstandard annotation imgs. The value divided by 112 | # 128 equivalent to '1 if value >= 128 else 0' 113 | mmcv.imwrite( 114 | img[:, :, 0] // 128, 115 | osp.join(out_dir, 'annotations', 'training', 116 | osp.splitext(filename)[0] + '.png')) 117 | 118 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: 119 | img = mmcv.imread(osp.join(now_dir, filename)) 120 | mmcv.imwrite( 121 | img[:, :, 0] // 128, 122 | osp.join(out_dir, 'annotations', 'validation', 123 | osp.splitext(filename)[0] + '.png')) 124 | 125 | print('Removing the temporary files...') 126 | 127 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 128 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) 129 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) 130 | 131 | print('Extracting labels-vk.tar...') 132 | with tarfile.open(labels_vk) as f: 133 | f.extractall(osp.join(tmp_dir, 'gz')) 134 | 135 | for filename in os.listdir(osp.join(tmp_dir, 'gz')): 136 | un_gz( 137 | osp.join(tmp_dir, 'gz', filename), 138 | osp.join(tmp_dir, 'files', 139 | osp.splitext(filename)[0])) 140 | 141 | now_dir = osp.join(tmp_dir, 'files') 142 | 143 | assert len(os.listdir(now_dir)) == STARE_LEN, \ 144 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) 145 | 146 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: 147 | img = mmcv.imread(osp.join(now_dir, filename)) 148 | mmcv.imwrite( 149 | img[:, :, 0] // 128, 150 | osp.join(out_dir, 'annotations', 'training', 151 | osp.splitext(filename)[0] + '.png')) 152 | 153 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: 154 | img = mmcv.imread(osp.join(now_dir, filename)) 155 | mmcv.imwrite( 156 | img[:, :, 0] // 128, 157 | osp.join(out_dir, 'annotations', 'validation', 158 | osp.splitext(filename)[0] + '.png')) 159 | 160 | print('Removing the temporary files...') 161 | 162 | print('Done!') 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- 
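Usage sketch for the STARE converter above (archive names and the output directory are illustrative placeholders, not paths shipped with this repo):

    python tools/convert_datasets/stare.py stare-images.tar labels-ah.tar labels-vk.tar -o data/STARE

The three tar archives are the positional arguments defined in parse_args(); when -o/--out_dir is omitted the script falls back to data/STARE.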
/segmentation/tools/convert_datasets/voc_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | 11 | AUG_LEN = 10582 12 | 13 | 14 | def convert_mat(mat_file, in_dir, out_dir): 15 | data = loadmat(osp.join(in_dir, mat_file)) 16 | mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) 17 | seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) 18 | Image.fromarray(mask).save(seg_filename, 'PNG') 19 | 20 | 21 | def generate_aug_list(merged_list, excluded_list): 22 | return list(set(merged_list) - set(excluded_list)) 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser( 27 | description='Convert PASCAL VOC annotations to mmsegmentation format') 28 | parser.add_argument('devkit_path', help='pascal voc devkit path') 29 | parser.add_argument('aug_path', help='pascal voc aug path') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | parser.add_argument( 32 | '--nproc', default=1, type=int, help='number of process') 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | devkit_path = args.devkit_path 40 | aug_path = args.aug_path 41 | nproc = args.nproc 42 | if args.out_dir is None: 43 | out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') 44 | else: 45 | out_dir = args.out_dir 46 | mmcv.mkdir_or_exist(out_dir) 47 | in_dir = osp.join(aug_path, 'dataset', 'cls') 48 | 49 | mmcv.track_parallel_progress( 50 | partial(convert_mat, in_dir=in_dir, out_dir=out_dir), 51 | list(mmcv.scandir(in_dir, suffix='.mat')), 52 | nproc=nproc) 53 | 54 | full_aug_list = [] 55 | with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: 56 | full_aug_list += [line.strip() for line in f] 57 | with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: 58 | full_aug_list += [line.strip() for line in f] 59 | 60 | with open( 61 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 62 | 'train.txt')) as f: 63 | ori_train_list = [line.strip() for line in f] 64 | with open( 65 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 66 | 'val.txt')) as f: 67 | val_list = [line.strip() for line in f] 68 | 69 | aug_train_list = generate_aug_list(ori_train_list + full_aug_list, 70 | val_list) 71 | assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( 72 | AUG_LEN) 73 | 74 | with open( 75 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 76 | 'trainaug.txt'), 'w') as f: 77 | f.writelines(line + '\n' for line in aug_train_list) 78 | 79 | aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) 80 | assert len(aug_list) == AUG_LEN - len( 81 | ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - 82 | len(ori_train_list)) 83 | with open( 84 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), 85 | 'w') as f: 86 | f.writelines(line + '\n' for line in aug_list) 87 | 88 | print('Done!') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /segmentation/tools/deploy_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
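# Usage sketch: evaluate an exported backend model with the ONNX Runtime backend.
# The .onnx path below is an illustrative placeholder, e.g.
#   python tools/deploy_test.py configs/semfpn_vit-s16_512x512_40k_ade20k.py \
#       model.onnx --backend onnxruntime --eval mIoU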
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import shutil 6 | import warnings 7 | from typing import Any, Iterable 8 | 9 | import mmcv 10 | import numpy as np 11 | import torch 12 | from mmcv.parallel import MMDataParallel 13 | from mmcv.runner import get_dist_info 14 | from mmcv.utils import DictAction 15 | 16 | from mmseg.apis import single_gpu_test 17 | from mmseg.datasets import build_dataloader, build_dataset 18 | from mmseg.models.segmentors.base import BaseSegmentor 19 | from mmseg.ops import resize 20 | 21 | 22 | class ONNXRuntimeSegmentor(BaseSegmentor): 23 | 24 | def __init__(self, onnx_file: str, cfg: Any, device_id: int): 25 | super(ONNXRuntimeSegmentor, self).__init__() 26 | import onnxruntime as ort 27 | 28 | # get the custom op path 29 | ort_custom_op_path = '' 30 | try: 31 | from mmcv.ops import get_onnxruntime_op_path 32 | ort_custom_op_path = get_onnxruntime_op_path() 33 | except (ImportError, ModuleNotFoundError): 34 | warnings.warn('If input model has custom op from mmcv, \ 35 | you may have to build mmcv with ONNXRuntime from source.') 36 | session_options = ort.SessionOptions() 37 | # register custom op for onnxruntime 38 | if osp.exists(ort_custom_op_path): 39 | session_options.register_custom_ops_library(ort_custom_op_path) 40 | sess = ort.InferenceSession(onnx_file, session_options) 41 | providers = ['CPUExecutionProvider'] 42 | options = [{}] 43 | is_cuda_available = ort.get_device() == 'GPU' 44 | if is_cuda_available: 45 | providers.insert(0, 'CUDAExecutionProvider') 46 | options.insert(0, {'device_id': device_id}) 47 | 48 | sess.set_providers(providers, options) 49 | 50 | self.sess = sess 51 | self.device_id = device_id 52 | self.io_binding = sess.io_binding() 53 | self.output_names = [_.name for _ in sess.get_outputs()] 54 | for name in self.output_names: 55 | self.io_binding.bind_output(name) 56 | self.cfg = cfg 57 | self.test_mode = cfg.model.test_cfg.mode 58 | self.is_cuda_available = is_cuda_available 59 | 60 | def extract_feat(self, imgs): 61 | raise NotImplementedError('This method is not implemented.') 62 | 63 | def encode_decode(self, img, img_metas): 64 | raise NotImplementedError('This method is not implemented.') 65 | 66 | def forward_train(self, imgs, img_metas, **kwargs): 67 | raise NotImplementedError('This method is not implemented.') 68 | 69 | def simple_test(self, img: torch.Tensor, img_meta: Iterable, 70 | **kwargs) -> list: 71 | if not self.is_cuda_available: 72 | img = img.detach().cpu() 73 | elif self.device_id >= 0: 74 | img = img.cuda(self.device_id) 75 | device_type = img.device.type 76 | self.io_binding.bind_input( 77 | name='input', 78 | device_type=device_type, 79 | device_id=self.device_id, 80 | element_type=np.float32, 81 | shape=img.shape, 82 | buffer_ptr=img.data_ptr()) 83 | self.sess.run_with_iobinding(self.io_binding) 84 | seg_pred = self.io_binding.copy_outputs_to_cpu()[0] 85 | # whole might support dynamic reshape 86 | ori_shape = img_meta[0]['ori_shape'] 87 | if not (ori_shape[0] == seg_pred.shape[-2] 88 | and ori_shape[1] == seg_pred.shape[-1]): 89 | seg_pred = torch.from_numpy(seg_pred).float() 90 | seg_pred = resize( 91 | seg_pred, size=tuple(ori_shape[:2]), mode='nearest') 92 | seg_pred = seg_pred.long().detach().cpu().numpy() 93 | seg_pred = seg_pred[0] 94 | seg_pred = list(seg_pred) 95 | return seg_pred 96 | 97 | def aug_test(self, imgs, img_metas, **kwargs): 98 | raise NotImplementedError('This method is not implemented.') 99 | 100 | 101 | class TensorRTSegmentor(BaseSegmentor): 102 | 103 | def 
__init__(self, trt_file: str, cfg: Any, device_id: int): 104 | super(TensorRTSegmentor, self).__init__() 105 | from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin 106 | try: 107 | load_tensorrt_plugin() 108 | except (ImportError, ModuleNotFoundError): 109 | warnings.warn('If input model has custom op from mmcv, \ 110 | you may have to build mmcv with TensorRT from source.') 111 | model = TRTWraper( 112 | trt_file, input_names=['input'], output_names=['output']) 113 | 114 | self.model = model 115 | self.device_id = device_id 116 | self.cfg = cfg 117 | self.test_mode = cfg.model.test_cfg.mode 118 | 119 | def extract_feat(self, imgs): 120 | raise NotImplementedError('This method is not implemented.') 121 | 122 | def encode_decode(self, img, img_metas): 123 | raise NotImplementedError('This method is not implemented.') 124 | 125 | def forward_train(self, imgs, img_metas, **kwargs): 126 | raise NotImplementedError('This method is not implemented.') 127 | 128 | def simple_test(self, img: torch.Tensor, img_meta: Iterable, 129 | **kwargs) -> list: 130 | with torch.cuda.device(self.device_id), torch.no_grad(): 131 | seg_pred = self.model({'input': img})['output'] 132 | seg_pred = seg_pred.detach().cpu().numpy() 133 | # whole might support dynamic reshape 134 | ori_shape = img_meta[0]['ori_shape'] 135 | if not (ori_shape[0] == seg_pred.shape[-2] 136 | and ori_shape[1] == seg_pred.shape[-1]): 137 | seg_pred = torch.from_numpy(seg_pred).float() 138 | seg_pred = resize( 139 | seg_pred, size=tuple(ori_shape[:2]), mode='nearest') 140 | seg_pred = seg_pred.long().detach().cpu().numpy() 141 | seg_pred = seg_pred[0] 142 | seg_pred = list(seg_pred) 143 | return seg_pred 144 | 145 | def aug_test(self, imgs, img_metas, **kwargs): 146 | raise NotImplementedError('This method is not implemented.') 147 | 148 | 149 | def parse_args() -> argparse.Namespace: 150 | parser = argparse.ArgumentParser( 151 | description='mmseg backend test (and eval)') 152 | parser.add_argument('config', help='test config file path') 153 | parser.add_argument('model', help='Input model file') 154 | parser.add_argument( 155 | '--backend', 156 | help='Backend of the model.', 157 | choices=['onnxruntime', 'tensorrt']) 158 | parser.add_argument('--out', help='output result file in pickle format') 159 | parser.add_argument( 160 | '--format-only', 161 | action='store_true', 162 | help='Format the output results without performing evaluation. It is ' 163 | 'useful when you want to format the result to a specific format and ' 164 | 'submit it to the test server') 165 | parser.add_argument( 166 | '--eval', 167 | type=str, 168 | nargs='+', 169 | help='evaluation metrics, which depend on the dataset, e.g., "mIoU"' 170 | ' for generic datasets, and "cityscapes" for Cityscapes') 171 | parser.add_argument('--show', action='store_true', help='show results') 172 | parser.add_argument( 173 | '--show-dir', help='directory where painted images will be saved') 174 | parser.add_argument( 175 | '--options', nargs='+', action=DictAction, help='custom options') 176 | parser.add_argument( 177 | '--eval-options', 178 | nargs='+', 179 | action=DictAction, 180 | help='custom options for evaluation') 181 | parser.add_argument( 182 | '--opacity', 183 | type=float, 184 | default=0.5, 185 | help='Opacity of painted segmentation map. 
In (0, 1] range.') 186 | parser.add_argument('--local_rank', type=int, default=0) 187 | args = parser.parse_args() 188 | if 'LOCAL_RANK' not in os.environ: 189 | os.environ['LOCAL_RANK'] = str(args.local_rank) 190 | return args 191 | 192 | 193 | def main(): 194 | args = parse_args() 195 | 196 | assert args.out or args.eval or args.format_only or args.show \ 197 | or args.show_dir, \ 198 | ('Please specify at least one operation (save/eval/format/show the ' 199 | 'results / save the results) with the argument "--out", "--eval"' 200 | ', "--format-only", "--show" or "--show-dir"') 201 | 202 | if args.eval and args.format_only: 203 | raise ValueError('--eval and --format_only cannot be both specified') 204 | 205 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 206 | raise ValueError('The output file must be a pkl file.') 207 | 208 | cfg = mmcv.Config.fromfile(args.config) 209 | if args.options is not None: 210 | cfg.merge_from_dict(args.options) 211 | cfg.model.pretrained = None 212 | cfg.data.test.test_mode = True 213 | 214 | # init distributed env first, since logger depends on the dist info. 215 | distributed = False 216 | 217 | # build the dataloader 218 | # TODO: support multiple images per gpu (only minor changes are needed) 219 | dataset = build_dataset(cfg.data.test) 220 | data_loader = build_dataloader( 221 | dataset, 222 | samples_per_gpu=1, 223 | workers_per_gpu=cfg.data.workers_per_gpu, 224 | dist=distributed, 225 | shuffle=False) 226 | 227 | # load onnx config and meta 228 | cfg.model.train_cfg = None 229 | 230 | if args.backend == 'onnxruntime': 231 | model = ONNXRuntimeSegmentor(args.model, cfg=cfg, device_id=0) 232 | elif args.backend == 'tensorrt': 233 | model = TensorRTSegmentor(args.model, cfg=cfg, device_id=0) 234 | 235 | model.CLASSES = dataset.CLASSES 236 | model.PALETTE = dataset.PALETTE 237 | 238 | # clean gpu memory when starting a new evaluation. 
239 | torch.cuda.empty_cache() 240 | eval_kwargs = {} if args.eval_options is None else args.eval_options 241 | 242 | # Deprecated 243 | efficient_test = eval_kwargs.get('efficient_test', False) 244 | if efficient_test: 245 | warnings.warn( 246 | '``efficient_test=True`` does not have effect in tools/test.py, ' 247 | 'the evaluation and format results are CPU memory efficient by ' 248 | 'default') 249 | 250 | eval_on_format_results = ( 251 | args.eval is not None and 'cityscapes' in args.eval) 252 | if eval_on_format_results: 253 | assert len(args.eval) == 1, 'eval on format results is not ' \ 254 | 'applicable for metrics other than ' \ 255 | 'cityscapes' 256 | if args.format_only or eval_on_format_results: 257 | if 'imgfile_prefix' in eval_kwargs: 258 | tmpdir = eval_kwargs['imgfile_prefix'] 259 | else: 260 | tmpdir = '.format_cityscapes' 261 | eval_kwargs.setdefault('imgfile_prefix', tmpdir) 262 | mmcv.mkdir_or_exist(tmpdir) 263 | else: 264 | tmpdir = None 265 | 266 | model = MMDataParallel(model, device_ids=[0]) 267 | results = single_gpu_test( 268 | model, 269 | data_loader, 270 | args.show, 271 | args.show_dir, 272 | False, 273 | args.opacity, 274 | pre_eval=args.eval is not None and not eval_on_format_results, 275 | format_only=args.format_only or eval_on_format_results, 276 | format_args=eval_kwargs) 277 | 278 | rank, _ = get_dist_info() 279 | if rank == 0: 280 | if args.out: 281 | warnings.warn( 282 | 'The behavior of ``args.out`` has been changed since MMSeg ' 283 | 'v0.16, the pickled outputs could be seg map as type of ' 284 | 'np.array, pre-eval results or file paths for ' 285 | '``dataset.format_results()``.') 286 | print(f'\nwriting results to {args.out}') 287 | mmcv.dump(results, args.out) 288 | if args.eval: 289 | dataset.evaluate(results, args.eval, **eval_kwargs) 290 | if tmpdir is not None and eval_on_format_results: 291 | # remove tmp dir when cityscapes evaluation 292 | shutil.rmtree(tmpdir) 293 | 294 | 295 | if __name__ == '__main__': 296 | main() 297 | -------------------------------------------------------------------------------- /segmentation/tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 10 | -------------------------------------------------------------------------------- /segmentation/tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /segmentation/tools/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
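# Usage sketch (the --shape value is chosen here to match the 512x512 ADE20K
# setting; adjust it to your own config):
#   python tools/get_flops.py configs/semfpn_vit-s16_512x512_40k_ade20k.py --shape 512 512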
2 | import argparse 3 | 4 | from mmcv import Config 5 | from mmcv.cnn import get_model_complexity_info 6 | 7 | from mmseg.models import build_segmentor 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Get the FLOPs of a segmentor') 12 | parser.add_argument('config', help='train config file path') 13 | parser.add_argument( 14 | '--shape', 15 | type=int, 16 | nargs='+', 17 | default=[2048, 1024], 18 | help='input image size') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | def main(): 24 | 25 | args = parse_args() 26 | 27 | if len(args.shape) == 1: 28 | input_shape = (3, args.shape[0], args.shape[0]) 29 | elif len(args.shape) == 2: 30 | input_shape = (3, ) + tuple(args.shape) 31 | else: 32 | raise ValueError('invalid input shape') 33 | 34 | cfg = Config.fromfile(args.config) 35 | cfg.model.pretrained = None 36 | model = build_segmentor( 37 | cfg.model, 38 | train_cfg=cfg.get('train_cfg'), 39 | test_cfg=cfg.get('test_cfg')).cuda() 40 | model.eval() 41 | 42 | if hasattr(model, 'forward_dummy'): 43 | model.forward = model.forward_dummy 44 | else: 45 | raise NotImplementedError( 46 | 'FLOPs counter is currently not supported with {}'. 47 | format(model.__class__.__name__)) 48 | 49 | flops, params = get_model_complexity_info(model, input_shape) 50 | split_line = '=' * 30 51 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 52 | split_line, input_shape, flops, params)) 53 | print('!!!Please be cautious if you use the results in papers. ' 54 | 'You may need to check if all ops are supported and verify that the ' 55 | 'flops computation is correct.') 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | -------------------------------------------------------------------------------- /segmentation/tools/model_converters/mit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_mit(ckpt): 12 | new_ckpt = OrderedDict() 13 | # Process the concat between q linear weights and kv linear weights 14 | for k, v in ckpt.items(): 15 | if k.startswith('head'): 16 | continue 17 | # patch embedding conversion 18 | elif k.startswith('patch_embed'): 19 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 20 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 21 | new_v = v 22 | if 'proj.' in new_k: 23 | new_k = new_k.replace('proj.', 'projection.') 24 | # transformer encoder layer conversion 25 | elif k.startswith('block'): 26 | stage_i = int(k.split('.')[0].replace('block', '')) 27 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 28 | new_v = v 29 | if 'attn.q.' in new_k: 30 | sub_item_k = k.replace('q.', 'kv.') 31 | new_k = new_k.replace('q.', 'attn.in_proj_') 32 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 33 | elif 'attn.kv.' in new_k: 34 | continue 35 | elif 'attn.proj.' in new_k: 36 | new_k = new_k.replace('proj.', 'attn.out_proj.') 37 | elif 'attn.sr.' in new_k: 38 | new_k = new_k.replace('sr.', 'sr.') 39 | elif 'mlp.' 
in new_k: 40 | string = f'{new_k}-' 41 | new_k = new_k.replace('mlp.', 'ffn.layers.') 42 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 43 | new_v = v.reshape((*v.shape, 1, 1)) 44 | new_k = new_k.replace('fc1.', '0.') 45 | new_k = new_k.replace('dwconv.dwconv.', '1.') 46 | new_k = new_k.replace('fc2.', '4.') 47 | string += f'{new_k} {v.shape}-{new_v.shape}' 48 | # norm layer convertion 49 | elif k.startswith('norm'): 50 | stage_i = int(k.split('.')[0].replace('norm', '')) 51 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 52 | new_v = v 53 | else: 54 | new_k = k 55 | new_v = v 56 | new_ckpt[new_k] = new_v 57 | return new_ckpt 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser( 62 | description='Convert keys in official pretrained segformer to ' 63 | 'MMSegmentation style.') 64 | parser.add_argument('src', help='src model path or url') 65 | # The dst path must be a full path of the new checkpoint. 66 | parser.add_argument('dst', help='save path') 67 | args = parser.parse_args() 68 | 69 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 70 | if 'state_dict' in checkpoint: 71 | state_dict = checkpoint['state_dict'] 72 | elif 'model' in checkpoint: 73 | state_dict = checkpoint['model'] 74 | else: 75 | state_dict = checkpoint 76 | weight = convert_mit(state_dict) 77 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 78 | torch.save(weight, args.dst) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /segmentation/tools/model_converters/swin2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_swin(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | def correct_unfold_reduction_order(x): 15 | out_channel, in_channel = x.shape 16 | x = x.reshape(out_channel, 4, in_channel // 4) 17 | x = x[:, [0, 2, 1, 3], :].transpose(1, 18 | 2).reshape(out_channel, in_channel) 19 | return x 20 | 21 | def correct_unfold_norm_order(x): 22 | in_channel = x.shape[0] 23 | x = x.reshape(4, in_channel // 4) 24 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 25 | return x 26 | 27 | for k, v in ckpt.items(): 28 | if k.startswith('head'): 29 | continue 30 | elif k.startswith('layers'): 31 | new_v = v 32 | if 'attn.' in k: 33 | new_k = k.replace('attn.', 'attn.w_msa.') 34 | elif 'mlp.' in k: 35 | if 'mlp.fc1.' in k: 36 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 37 | elif 'mlp.fc2.' in k: 38 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 39 | else: 40 | new_k = k.replace('mlp.', 'ffn.') 41 | elif 'downsample' in k: 42 | new_k = k 43 | if 'reduction.' in k: 44 | new_v = correct_unfold_reduction_order(v) 45 | elif 'norm.' 
in k: 46 | new_v = correct_unfold_norm_order(v) 47 | else: 48 | new_k = k 49 | new_k = new_k.replace('layers', 'stages', 1) 50 | elif k.startswith('patch_embed'): 51 | new_v = v 52 | if 'proj' in k: 53 | new_k = k.replace('proj', 'projection') 54 | else: 55 | new_k = k 56 | else: 57 | new_v = v 58 | new_k = k 59 | 60 | new_ckpt[new_k] = new_v 61 | 62 | return new_ckpt 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser( 67 | description='Convert keys in official pretrained swin models to' 68 | 'MMSegmentation style.') 69 | parser.add_argument('src', help='src model path or url') 70 | # The dst path must be a full path of the new checkpoint. 71 | parser.add_argument('dst', help='save path') 72 | args = parser.parse_args() 73 | 74 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 75 | if 'state_dict' in checkpoint: 76 | state_dict = checkpoint['state_dict'] 77 | elif 'model' in checkpoint: 78 | state_dict = checkpoint['model'] 79 | else: 80 | state_dict = checkpoint 81 | weight = convert_swin(state_dict) 82 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 83 | torch.save(weight, args.dst) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /segmentation/tools/model_converters/vit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_vit(ckpt): 12 | 13 | new_ckpt = OrderedDict() 14 | 15 | for k, v in ckpt.items(): 16 | # if k.startswith('class_token'): 17 | # new_k = k.replace('class_token.', 'cls_token.') 18 | # if k.startswith('pos_embedding'): 19 | # new_k = k.replace('pos_embedding.', 'pos_embed.') 20 | if k.startswith('head'): 21 | continue 22 | if k.startswith('norm'): 23 | new_k = k.replace('norm.', 'ln1.') 24 | elif k.startswith('patch_embed'): 25 | if 'proj' in k: 26 | new_k = k.replace('proj', 'projection') 27 | else: 28 | new_k = k 29 | elif k.startswith('blocks'): 30 | if 'norm' in k: 31 | new_k = k.replace('norm', 'ln') 32 | elif 'mlp.fc1' in k: 33 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 34 | elif 'mlp.fc2' in k: 35 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 36 | elif 'attn.qkv' in k: 37 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') 38 | elif 'attn.proj' in k: 39 | new_k = k.replace('attn.proj', 'attn.attn.out_proj') 40 | else: 41 | new_k = k 42 | new_k = new_k.replace('blocks.', 'layers.') 43 | else: 44 | new_k = k 45 | new_ckpt[new_k] = v 46 | 47 | return new_ckpt 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser( 52 | description='Convert keys in timm pretrained vit models to ' 53 | 'MMSegmentation style.') 54 | parser.add_argument('src', help='src model path or url') 55 | # The dst path must be a full path of the new checkpoint. 
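# Checkpoints that store a 'teacher' entry (as DINO/SelfPatch-style pre-training
# does) are handled below by taking the teacher weights and stripping the
# 'backbone.' prefix before the key conversion. Example invocation with
# illustrative file names:
#   python tools/model_converters/vit2mmseg.py selfpatch_pretrain.pth vit_small_mmseg.pth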
56 | parser.add_argument('dst', help='save path') 57 | args = parser.parse_args() 58 | 59 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 60 | if 'state_dict' in checkpoint: 61 | # timm checkpoint 62 | state_dict = checkpoint['state_dict'] 63 | elif 'model' in checkpoint: 64 | # deit checkpoint 65 | state_dict = checkpoint['model'] 66 | elif 'teacher' in checkpoint: 67 | # deit checkpoint 68 | state_dict = checkpoint['teacher'] 69 | state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} 70 | else: 71 | state_dict = checkpoint 72 | weight = convert_vit(state_dict) 73 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 74 | torch.save(weight, args.dst) 75 | 76 | 77 | if __name__ == '__main__': 78 | main() 79 | -------------------------------------------------------------------------------- /segmentation/tools/onnx2tensorrt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | from typing import Iterable, Optional, Union 6 | 7 | import matplotlib.pyplot as plt 8 | import mmcv 9 | import numpy as np 10 | import onnxruntime as ort 11 | import torch 12 | from mmcv.ops import get_onnxruntime_op_path 13 | from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, 14 | save_trt_engine) 15 | 16 | from mmseg.apis.inference import LoadImage 17 | from mmseg.datasets import DATASETS 18 | from mmseg.datasets.pipelines import Compose 19 | 20 | 21 | def get_GiB(x: int): 22 | """return x GiB.""" 23 | return x * (1 << 30) 24 | 25 | 26 | def _prepare_input_img(img_path: str, 27 | test_pipeline: Iterable[dict], 28 | shape: Optional[Iterable] = None, 29 | rescale_shape: Optional[Iterable] = None) -> dict: 30 | # build the data pipeline 31 | if shape is not None: 32 | test_pipeline[1]['img_scale'] = (shape[1], shape[0]) 33 | test_pipeline[1]['transforms'][0]['keep_ratio'] = False 34 | test_pipeline = [LoadImage()] + test_pipeline[1:] 35 | test_pipeline = Compose(test_pipeline) 36 | # prepare data 37 | data = dict(img=img_path) 38 | data = test_pipeline(data) 39 | imgs = data['img'] 40 | img_metas = [i.data for i in data['img_metas']] 41 | 42 | if rescale_shape is not None: 43 | for img_meta in img_metas: 44 | img_meta['ori_shape'] = tuple(rescale_shape) + (3, ) 45 | 46 | mm_inputs = {'imgs': imgs, 'img_metas': img_metas} 47 | 48 | return mm_inputs 49 | 50 | 51 | def _update_input_img(img_list: Iterable, img_meta_list: Iterable): 52 | # update img and its meta list 53 | N = img_list[0].size(0) 54 | img_meta = img_meta_list[0][0] 55 | img_shape = img_meta['img_shape'] 56 | ori_shape = img_meta['ori_shape'] 57 | pad_shape = img_meta['pad_shape'] 58 | new_img_meta_list = [[{ 59 | 'img_shape': 60 | img_shape, 61 | 'ori_shape': 62 | ori_shape, 63 | 'pad_shape': 64 | pad_shape, 65 | 'filename': 66 | img_meta['filename'], 67 | 'scale_factor': 68 | (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2, 69 | 'flip': 70 | False, 71 | } for _ in range(N)]] 72 | 73 | return img_list, new_img_meta_list 74 | 75 | 76 | def show_result_pyplot(img: Union[str, np.ndarray], 77 | result: np.ndarray, 78 | palette: Optional[Iterable] = None, 79 | fig_size: Iterable[int] = (15, 10), 80 | opacity: float = 0.5, 81 | title: str = '', 82 | block: bool = True): 83 | img = mmcv.imread(img) 84 | img = img.copy() 85 | seg = result[0] 86 | seg = mmcv.imresize(seg, img.shape[:2][::-1]) 87 | palette = np.array(palette) 88 | assert 
palette.shape[1] == 3 89 | assert len(palette.shape) == 2 90 | assert 0 < opacity <= 1.0 91 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 92 | for label, color in enumerate(palette): 93 | color_seg[seg == label, :] = color 94 | # convert to BGR 95 | color_seg = color_seg[..., ::-1] 96 | 97 | img = img * (1 - opacity) + color_seg * opacity 98 | img = img.astype(np.uint8) 99 | 100 | plt.figure(figsize=fig_size) 101 | plt.imshow(mmcv.bgr2rgb(img)) 102 | plt.title(title) 103 | plt.tight_layout() 104 | plt.show(block=block) 105 | 106 | 107 | def onnx2tensorrt(onnx_file: str, 108 | trt_file: str, 109 | config: dict, 110 | input_config: dict, 111 | fp16: bool = False, 112 | verify: bool = False, 113 | show: bool = False, 114 | dataset: str = 'CityscapesDataset', 115 | workspace_size: int = 1, 116 | verbose: bool = False): 117 | import tensorrt as trt 118 | min_shape = input_config['min_shape'] 119 | max_shape = input_config['max_shape'] 120 | # create trt engine and wraper 121 | opt_shape_dict = {'input': [min_shape, min_shape, max_shape]} 122 | max_workspace_size = get_GiB(workspace_size) 123 | trt_engine = onnx2trt( 124 | onnx_file, 125 | opt_shape_dict, 126 | log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR, 127 | fp16_mode=fp16, 128 | max_workspace_size=max_workspace_size) 129 | save_dir, _ = osp.split(trt_file) 130 | if save_dir: 131 | os.makedirs(save_dir, exist_ok=True) 132 | save_trt_engine(trt_engine, trt_file) 133 | print(f'Successfully created TensorRT engine: {trt_file}') 134 | 135 | if verify: 136 | inputs = _prepare_input_img( 137 | input_config['input_path'], 138 | config.data.test.pipeline, 139 | shape=min_shape[2:]) 140 | 141 | imgs = inputs['imgs'] 142 | img_metas = inputs['img_metas'] 143 | img_list = [img[None, :] for img in imgs] 144 | img_meta_list = [[img_meta] for img_meta in img_metas] 145 | # update img_meta 146 | img_list, img_meta_list = _update_input_img(img_list, img_meta_list) 147 | 148 | if max_shape[0] > 1: 149 | # concate flip image for batch test 150 | flip_img_list = [_.flip(-1) for _ in img_list] 151 | img_list = [ 152 | torch.cat((ori_img, flip_img), 0) 153 | for ori_img, flip_img in zip(img_list, flip_img_list) 154 | ] 155 | 156 | # Get results from ONNXRuntime 157 | ort_custom_op_path = get_onnxruntime_op_path() 158 | session_options = ort.SessionOptions() 159 | if osp.exists(ort_custom_op_path): 160 | session_options.register_custom_ops_library(ort_custom_op_path) 161 | sess = ort.InferenceSession(onnx_file, session_options) 162 | sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode 163 | onnx_output = sess.run(['output'], 164 | {'input': img_list[0].detach().numpy()})[0][0] 165 | 166 | # Get results from TensorRT 167 | trt_model = TRTWraper(trt_file, ['input'], ['output']) 168 | with torch.no_grad(): 169 | trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()}) 170 | trt_output = trt_outputs['output'][0].cpu().detach().numpy() 171 | 172 | if show: 173 | dataset = DATASETS.get(dataset) 174 | assert dataset is not None 175 | palette = dataset.PALETTE 176 | 177 | show_result_pyplot( 178 | input_config['input_path'], 179 | (onnx_output[0].astype(np.uint8), ), 180 | palette=palette, 181 | title='ONNXRuntime', 182 | block=False) 183 | show_result_pyplot( 184 | input_config['input_path'], (trt_output[0].astype(np.uint8), ), 185 | palette=palette, 186 | title='TensorRT') 187 | 188 | np.testing.assert_allclose( 189 | onnx_output, trt_output, rtol=1e-03, atol=1e-05) 190 | print('TensorRT and 
ONNXRuntime output all close.') 191 | 192 | 193 | def parse_args(): 194 | parser = argparse.ArgumentParser( 195 | description='Convert MMSegmentation models from ONNX to TensorRT') 196 | parser.add_argument('config', help='Config file of the model') 197 | parser.add_argument('model', help='Path to the input ONNX model') 198 | parser.add_argument( 199 | '--trt-file', type=str, help='Path to the output TensorRT engine') 200 | parser.add_argument( 201 | '--max-shape', 202 | type=int, 203 | nargs=4, 204 | default=[1, 3, 400, 600], 205 | help='Maximum shape of model input.') 206 | parser.add_argument( 207 | '--min-shape', 208 | type=int, 209 | nargs=4, 210 | default=[1, 3, 400, 600], 211 | help='Minimum shape of model input.') 212 | parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode') 213 | parser.add_argument( 214 | '--workspace-size', 215 | type=int, 216 | default=1, 217 | help='Max workspace size in GiB') 218 | parser.add_argument( 219 | '--input-img', type=str, default='', help='Image for test') 220 | parser.add_argument( 221 | '--show', action='store_true', help='Whether to show output results') 222 | parser.add_argument( 223 | '--dataset', 224 | type=str, 225 | default='CityscapesDataset', 226 | help='Dataset name') 227 | parser.add_argument( 228 | '--verify', 229 | action='store_true', 230 | help='Verify the outputs of ONNXRuntime and TensorRT') 231 | parser.add_argument( 232 | '--verbose', 233 | action='store_true', 234 | help='Whether to show verbose logging messages while creating \ 235 | TensorRT engine.') 236 | args = parser.parse_args() 237 | return args 238 | 239 | 240 | if __name__ == '__main__': 241 | 242 | assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.' 243 | args = parse_args() 244 | 245 | if not args.input_img: 246 | args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png') 247 | 248 | # check arguments 249 | assert osp.exists(args.config), 'Config {} not found.'.format(args.config) 250 | assert osp.exists(args.model), \ 251 | 'ONNX model {} not found.'.format(args.model) 252 | assert args.workspace_size >= 0, 'Workspace size less than 0.' 253 | assert DATASETS.get(args.dataset) is not None, \ 254 | 'Dataset {} not found.'.format(args.dataset) 255 | for max_value, min_value in zip(args.max_shape, args.min_shape): 256 | assert max_value >= min_value, \ 257 | 'max_shape should not be smaller than min_shape' 258 | 259 | input_config = { 260 | 'min_shape': args.min_shape, 261 | 'max_shape': args.max_shape, 262 | 'input_path': args.input_img 263 | } 264 | 265 | cfg = mmcv.Config.fromfile(args.config) 266 | onnx2tensorrt( 267 | args.model, 268 | args.trt_file, 269 | cfg, 270 | input_config, 271 | fp16=args.fp16, 272 | verify=args.verify, 273 | show=args.show, 274 | dataset=args.dataset, 275 | workspace_size=args.workspace_size, 276 | verbose=args.verbose) 277 | -------------------------------------------------------------------------------- /segmentation/tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
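# Usage sketch (the config path is a placeholder for any config in this repo):
#   python tools/print_config.py <path/to/config.py> --graph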
2 | import argparse 3 | 4 | from mmcv import Config, DictAction 5 | 6 | from mmseg.apis import init_segmentor 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='Print the whole config') 11 | parser.add_argument('config', help='config file path') 12 | parser.add_argument( 13 | '--graph', action='store_true', help='print the models graph') 14 | parser.add_argument( 15 | '--options', nargs='+', action=DictAction, help='arguments in dict') 16 | args = parser.parse_args() 17 | 18 | return args 19 | 20 | 21 | def main(): 22 | args = parse_args() 23 | 24 | cfg = Config.fromfile(args.config) 25 | if args.options is not None: 26 | cfg.merge_from_dict(args.options) 27 | print(f'Config:\n{cfg.pretty_text}') 28 | # dump config 29 | cfg.dump('example.py') 30 | # dump models graph 31 | if args.graph: 32 | model = init_segmentor(args.config, device='cpu') 33 | print(f'Model graph:\n{str(model)}') 34 | with open('example-graph.txt', 'w') as f: 35 | f.writelines(str(model)) 36 | 37 | 38 | if __name__ == '__main__': 39 | main() 40 | -------------------------------------------------------------------------------- /segmentation/tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import subprocess 4 | 5 | import torch 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Process a checkpoint to be published') 11 | parser.add_argument('in_file', help='input checkpoint filename') 12 | parser.add_argument('out_file', help='output checkpoint filename') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def process_checkpoint(in_file, out_file): 18 | checkpoint = torch.load(in_file, map_location='cpu') 19 | # remove optimizer for smaller file size 20 | if 'optimizer' in checkpoint: 21 | del checkpoint['optimizer'] 22 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 23 | # add the code here. 24 | torch.save(checkpoint, out_file) 25 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 26 | final_file = out_file.rstrip('.pth') + '-{}.pth'.format(sha[:8]) 27 | subprocess.Popen(['mv', out_file, final_file]) 28 | 29 | 30 | def main(): 31 | args = parse_args() 32 | process_checkpoint(args.in_file, args.out_file) 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /segmentation/tools/pytorch2torchscript.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
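# Usage sketch (checkpoint and output file names are illustrative placeholders):
#   python tools/pytorch2torchscript.py configs/semfpn_vit-s16_512x512_40k_ade20k.py \
#       --checkpoint work_dirs/latest.pth --output-file semfpn_selfpatch.pt --shape 512 512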
2 | import argparse 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | import torch._C 8 | import torch.serialization 9 | from mmcv.runner import load_checkpoint 10 | from torch import nn 11 | 12 | from mmseg.models import build_segmentor 13 | 14 | torch.manual_seed(3) 15 | 16 | 17 | def digit_version(version_str): 18 | digit_version = [] 19 | for x in version_str.split('.'): 20 | if x.isdigit(): 21 | digit_version.append(int(x)) 22 | elif x.find('rc') != -1: 23 | patch_version = x.split('rc') 24 | digit_version.append(int(patch_version[0]) - 1) 25 | digit_version.append(int(patch_version[1])) 26 | return digit_version 27 | 28 | 29 | def check_torch_version(): 30 | torch_minimum_version = '1.8.0' 31 | torch_version = digit_version(torch.__version__) 32 | 33 | assert (torch_version >= digit_version(torch_minimum_version)), \ 34 | f'Torch=={torch.__version__} is not supported for converting to ' \ 35 | f'torchscript. Please install pytorch>={torch_minimum_version}.' 36 | 37 | 38 | def _convert_batchnorm(module): 39 | module_output = module 40 | if isinstance(module, torch.nn.SyncBatchNorm): 41 | module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, 42 | module.momentum, module.affine, 43 | module.track_running_stats) 44 | if module.affine: 45 | module_output.weight.data = module.weight.data.clone().detach() 46 | module_output.bias.data = module.bias.data.clone().detach() 47 | # keep requires_grad unchanged 48 | module_output.weight.requires_grad = module.weight.requires_grad 49 | module_output.bias.requires_grad = module.bias.requires_grad 50 | module_output.running_mean = module.running_mean 51 | module_output.running_var = module.running_var 52 | module_output.num_batches_tracked = module.num_batches_tracked 53 | for name, child in module.named_children(): 54 | module_output.add_module(name, _convert_batchnorm(child)) 55 | del module 56 | return module_output 57 | 58 | 59 | def _demo_mm_inputs(input_shape, num_classes): 60 | """Create a superset of inputs needed to run test or train batches. 61 | 62 | Args: 63 | input_shape (tuple): 64 | input batch dimensions 65 | num_classes (int): 66 | number of semantic classes 67 | """ 68 | (N, C, H, W) = input_shape 69 | rng = np.random.RandomState(0) 70 | imgs = rng.rand(*input_shape) 71 | segs = rng.randint( 72 | low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) 73 | img_metas = [{ 74 | 'img_shape': (H, W, C), 75 | 'ori_shape': (H, W, C), 76 | 'pad_shape': (H, W, C), 77 | 'filename': '.png', 78 | 'scale_factor': 1.0, 79 | 'flip': False, 80 | } for _ in range(N)] 81 | mm_inputs = { 82 | 'imgs': torch.FloatTensor(imgs).requires_grad_(True), 83 | 'img_metas': img_metas, 84 | 'gt_semantic_seg': torch.LongTensor(segs) 85 | } 86 | return mm_inputs 87 | 88 | 89 | def pytorch2libtorch(model, 90 | input_shape, 91 | show=False, 92 | output_file='tmp.pt', 93 | verify=False): 94 | """Export Pytorch model to TorchScript model and verify that the outputs are 95 | the same between Pytorch and TorchScript. 96 | 97 | Args: 98 | model (nn.Module): Pytorch model we want to export. 99 | input_shape (tuple): Use this input shape to construct 100 | the corresponding dummy input and execute the model. 101 | show (bool): Whether to print the computation graph. Default: False. 102 | output_file (string): The path to where we store the 103 | output TorchScript model. Default: `tmp.pt`. 104 | verify (bool): Whether to compare the outputs between 105 | Pytorch and TorchScript. Default: False.
106 | """ 107 | if isinstance(model.decode_head, nn.ModuleList): 108 | num_classes = model.decode_head[-1].num_classes 109 | else: 110 | num_classes = model.decode_head.num_classes 111 | 112 | mm_inputs = _demo_mm_inputs(input_shape, num_classes) 113 | 114 | imgs = mm_inputs.pop('imgs') 115 | 116 | # replace the original forward with forward_dummy 117 | model.forward = model.forward_dummy 118 | model.eval() 119 | traced_model = torch.jit.trace( 120 | model, 121 | example_inputs=imgs, 122 | check_trace=verify, 123 | ) 124 | 125 | if show: 126 | print(traced_model.graph) 127 | 128 | traced_model.save(output_file) 129 | print('Successfully exported TorchScript model: {}'.format(output_file)) 130 | 131 | 132 | def parse_args(): 133 | parser = argparse.ArgumentParser( 134 | description='Convert MMSeg to TorchScript') 135 | parser.add_argument('config', help='test config file path') 136 | parser.add_argument('--checkpoint', help='checkpoint file', default=None) 137 | parser.add_argument( 138 | '--show', action='store_true', help='show TorchScript graph') 139 | parser.add_argument( 140 | '--verify', action='store_true', help='verify the TorchScript model') 141 | parser.add_argument('--output-file', type=str, default='tmp.pt') 142 | parser.add_argument( 143 | '--shape', 144 | type=int, 145 | nargs='+', 146 | default=[512, 512], 147 | help='input image size (height, width)') 148 | args = parser.parse_args() 149 | return args 150 | 151 | 152 | if __name__ == '__main__': 153 | args = parse_args() 154 | check_torch_version() 155 | 156 | if len(args.shape) == 1: 157 | input_shape = (1, 3, args.shape[0], args.shape[0]) 158 | elif len(args.shape) == 2: 159 | input_shape = ( 160 | 1, 161 | 3, 162 | ) + tuple(args.shape) 163 | else: 164 | raise ValueError('invalid input shape') 165 | 166 | cfg = mmcv.Config.fromfile(args.config) 167 | cfg.model.pretrained = None 168 | 169 | # build the model and load checkpoint 170 | cfg.model.train_cfg = None 171 | segmentor = build_segmentor( 172 | cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) 173 | # convert SyncBN to BN 174 | segmentor = _convert_batchnorm(segmentor) 175 | 176 | if args.checkpoint: 177 | load_checkpoint(segmentor, args.checkpoint, map_location='cpu') 178 | 179 | # convert the PyTorch model to LibTorch model 180 | pytorch2libtorch( 181 | segmentor, 182 | input_shape, 183 | show=args.show, 184 | output_file=args.output_file, 185 | verify=args.verify) 186 | -------------------------------------------------------------------------------- /segmentation/tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-4} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /segmentation/tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 |
JOB_NAME=$2 7 | CONFIG=$3 8 | GPUS=${GPUS:-4} 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | PY_ARGS=${@:4} 13 | 14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | ${SRUN_ARGS} \ 23 | python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} 24 | -------------------------------------------------------------------------------- /segmentation/tools/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import shutil 6 | import time 7 | import warnings 8 | 9 | import mmcv 10 | import torch 11 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 12 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 13 | wrap_fp16_model) 14 | from mmcv.utils import DictAction 15 | 16 | from mmseg.apis import multi_gpu_test, single_gpu_test 17 | from mmseg.datasets import build_dataloader, build_dataset 18 | from mmseg.models import build_segmentor 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser( 23 | description='mmseg test (and eval) a model') 24 | parser.add_argument('config', help='test config file path') 25 | parser.add_argument('checkpoint', help='checkpoint file') 26 | parser.add_argument( 27 | '--work-dir', 28 | help=('if specified, the evaluation metric results will be dumped ' 29 | 'into the directory as json')) 30 | parser.add_argument( 31 | '--aug-test', action='store_true', help='Use flip and multi-scale augmentation') 32 | parser.add_argument('--out', help='output result file in pickle format') 33 | parser.add_argument( 34 | '--format-only', 35 | action='store_true', 36 | help='Format the output results without performing evaluation. It is ' 37 | 'useful when you want to format the result to a specific format and ' 38 | 'submit it to the test server') 39 | parser.add_argument( 40 | '--eval', 41 | type=str, 42 | nargs='+', 43 | help='evaluation metrics, which depend on the dataset, e.g., "mIoU"' 44 | ' for generic datasets, and "cityscapes" for Cityscapes') 45 | parser.add_argument('--show', action='store_true', help='show results') 46 | parser.add_argument( 47 | '--show-dir', help='directory where painted images will be saved') 48 | parser.add_argument( 49 | '--gpu-collect', 50 | action='store_true', 51 | help='whether to use gpu to collect results.') 52 | parser.add_argument( 53 | '--tmpdir', 54 | help='tmp directory used for collecting results from multiple ' 55 | 'workers, available when gpu_collect is not specified') 56 | parser.add_argument( 57 | '--options', nargs='+', action=DictAction, help='custom options') 58 | parser.add_argument( 59 | '--eval-options', 60 | nargs='+', 61 | action=DictAction, 62 | help='custom options for evaluation') 63 | parser.add_argument( 64 | '--launcher', 65 | choices=['none', 'pytorch', 'slurm', 'mpi'], 66 | default='none', 67 | help='job launcher') 68 | parser.add_argument( 69 | '--opacity', 70 | type=float, 71 | default=0.5, 72 | help='Opacity of painted segmentation map.
In (0, 1] range.') 73 | parser.add_argument('--local_rank', type=int, default=0) 74 | args = parser.parse_args() 75 | if 'LOCAL_RANK' not in os.environ: 76 | os.environ['LOCAL_RANK'] = str(args.local_rank) 77 | return args 78 | 79 | 80 | def main(): 81 | args = parse_args() 82 | 83 | assert args.out or args.eval or args.format_only or args.show \ 84 | or args.show_dir, \ 85 | ('Please specify at least one operation (save/eval/format/show the ' 86 | 'results / save the results) with the argument "--out", "--eval"' 87 | ', "--format-only", "--show" or "--show-dir"') 88 | 89 | if args.eval and args.format_only: 90 | raise ValueError('--eval and --format_only cannot be both specified') 91 | 92 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 93 | raise ValueError('The output file must be a pkl file.') 94 | 95 | cfg = mmcv.Config.fromfile(args.config) 96 | if args.options is not None: 97 | cfg.merge_from_dict(args.options) 98 | # set cudnn_benchmark 99 | if cfg.get('cudnn_benchmark', False): 100 | torch.backends.cudnn.benchmark = True 101 | if args.aug_test: 102 | # hard code index 103 | cfg.data.test.pipeline[1].img_ratios = [ 104 | 0.5, 0.75, 1.0, 1.25, 1.5, 1.75 105 | ] 106 | cfg.data.test.pipeline[1].flip = True 107 | cfg.model.pretrained = None 108 | cfg.data.test.test_mode = True 109 | 110 | # init distributed env first, since logger depends on the dist info. 111 | if args.launcher == 'none': 112 | distributed = False 113 | else: 114 | distributed = True 115 | init_dist(args.launcher, **cfg.dist_params) 116 | 117 | rank, _ = get_dist_info() 118 | # allows not to create 119 | if args.work_dir is not None and rank == 0: 120 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 121 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 122 | json_file = osp.join(args.work_dir, f'eval_{timestamp}.json') 123 | 124 | # build the dataloader 125 | # TODO: support multiple images per gpu (only minor changes are needed) 126 | dataset = build_dataset(cfg.data.test) 127 | data_loader = build_dataloader( 128 | dataset, 129 | samples_per_gpu=1, 130 | workers_per_gpu=cfg.data.workers_per_gpu, 131 | dist=distributed, 132 | shuffle=False) 133 | 134 | # build the model and load checkpoint 135 | cfg.model.train_cfg = None 136 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) 137 | fp16_cfg = cfg.get('fp16', None) 138 | if fp16_cfg is not None: 139 | wrap_fp16_model(model) 140 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 141 | if 'CLASSES' in checkpoint.get('meta', {}): 142 | model.CLASSES = checkpoint['meta']['CLASSES'] 143 | else: 144 | print('"CLASSES" not found in meta, use dataset.CLASSES instead') 145 | model.CLASSES = dataset.CLASSES 146 | if 'PALETTE' in checkpoint.get('meta', {}): 147 | model.PALETTE = checkpoint['meta']['PALETTE'] 148 | else: 149 | print('"PALETTE" not found in meta, use dataset.PALETTE instead') 150 | model.PALETTE = dataset.PALETTE 151 | 152 | # clean gpu memory when starting a new evaluation. 
153 | torch.cuda.empty_cache() 154 | eval_kwargs = {} if args.eval_options is None else args.eval_options 155 | 156 | # Deprecated 157 | efficient_test = eval_kwargs.get('efficient_test', False) 158 | if efficient_test: 159 | warnings.warn( 160 | '``efficient_test=True`` does not have effect in tools/test.py, ' 161 | 'the evaluation and format results are CPU memory efficient by ' 162 | 'default') 163 | 164 | eval_on_format_results = ( 165 | args.eval is not None and 'cityscapes' in args.eval) 166 | if eval_on_format_results: 167 | assert len(args.eval) == 1, 'eval on format results is not ' \ 168 | 'applicable for metrics other than ' \ 169 | 'cityscapes' 170 | if args.format_only or eval_on_format_results: 171 | if 'imgfile_prefix' in eval_kwargs: 172 | tmpdir = eval_kwargs['imgfile_prefix'] 173 | else: 174 | tmpdir = '.format_cityscapes' 175 | eval_kwargs.setdefault('imgfile_prefix', tmpdir) 176 | mmcv.mkdir_or_exist(tmpdir) 177 | else: 178 | tmpdir = None 179 | 180 | if not distributed: 181 | model = MMDataParallel(model, device_ids=[0]) 182 | results = single_gpu_test( 183 | model, 184 | data_loader, 185 | args.show, 186 | args.show_dir, 187 | False, 188 | args.opacity, 189 | pre_eval=args.eval is not None and not eval_on_format_results, 190 | format_only=args.format_only or eval_on_format_results, 191 | format_args=eval_kwargs) 192 | else: 193 | model = MMDistributedDataParallel( 194 | model.cuda(), 195 | device_ids=[torch.cuda.current_device()], 196 | broadcast_buffers=False) 197 | results = multi_gpu_test( 198 | model, 199 | data_loader, 200 | args.tmpdir, 201 | args.gpu_collect, 202 | False, 203 | pre_eval=args.eval is not None and not eval_on_format_results, 204 | format_only=args.format_only or eval_on_format_results, 205 | format_args=eval_kwargs) 206 | 207 | rank, _ = get_dist_info() 208 | if rank == 0: 209 | if args.out: 210 | warnings.warn( 211 | 'The behavior of ``args.out`` has been changed since MMSeg ' 212 | 'v0.16, the pickled outputs could be seg map as type of ' 213 | 'np.array, pre-eval results or file paths for ' 214 | '``dataset.format_results()``.') 215 | print(f'\nwriting results to {args.out}') 216 | mmcv.dump(results, args.out) 217 | if args.eval: 218 | eval_kwargs.update(metric=args.eval) 219 | metric = dataset.evaluate(results, **eval_kwargs) 220 | metric_dict = dict(config=args.config, metric=metric) 221 | if args.work_dir is not None and rank == 0: 222 | mmcv.dump(metric_dict, json_file, indent=4) 223 | if tmpdir is not None and eval_on_format_results: 224 | # remove tmp dir when cityscapes evaluation 225 | shutil.rmtree(tmpdir) 226 | 227 | 228 | if __name__ == '__main__': 229 | main() 230 | -------------------------------------------------------------------------------- /segmentation/tools/torchserve/mmseg2torchserve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
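# Example usage of this script (illustrative only; paths and the model name in angle brackets are placeholders):
#   python tools/torchserve/mmseg2torchserve.py <path/to/config.py> <path/to/checkpoint.pth> --output-folder <output_dir> --model-name <model_name>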
2 | from argparse import ArgumentParser, Namespace 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | 6 | import mmcv 7 | 8 | try: 9 | from model_archiver.model_packaging import package_model 10 | from model_archiver.model_packaging_utils import ModelExportUtils 11 | except ImportError: 12 | package_model = None 13 | 14 | 15 | def mmseg2torchserve( 16 | config_file: str, 17 | checkpoint_file: str, 18 | output_folder: str, 19 | model_name: str, 20 | model_version: str = '1.0', 21 | force: bool = False, 22 | ): 23 | """Converts mmsegmentation model (config + checkpoint) to TorchServe 24 | `.mar`. 25 | 26 | Args: 27 | config_file: 28 | In MMSegmentation config format. 29 | The contents vary for each task repository. 30 | checkpoint_file: 31 | In MMSegmentation checkpoint format. 32 | The contents vary for each task repository. 33 | output_folder: 34 | Folder where `{model_name}.mar` will be created. 35 | The file created will be in TorchServe archive format. 36 | model_name: 37 | If not None, used for naming the `{model_name}.mar` file 38 | that will be created under `output_folder`. 39 | If None, `{Path(checkpoint_file).stem}` will be used. 40 | model_version: 41 | Model's version. 42 | force: 43 | If True, if there is an existing `{model_name}.mar` 44 | file under `output_folder` it will be overwritten. 45 | """ 46 | mmcv.mkdir_or_exist(output_folder) 47 | 48 | config = mmcv.Config.fromfile(config_file) 49 | 50 | with TemporaryDirectory() as tmpdir: 51 | config.dump(f'{tmpdir}/config.py') 52 | 53 | args = Namespace( 54 | **{ 55 | 'model_file': f'{tmpdir}/config.py', 56 | 'serialized_file': checkpoint_file, 57 | 'handler': f'{Path(__file__).parent}/mmseg_handler.py', 58 | 'model_name': model_name or Path(checkpoint_file).stem, 59 | 'version': model_version, 60 | 'export_path': output_folder, 61 | 'force': force, 62 | 'requirements_file': None, 63 | 'extra_files': None, 64 | 'runtime': 'python', 65 | 'archive_format': 'default' 66 | }) 67 | manifest = ModelExportUtils.generate_manifest_json(args) 68 | package_model(args, manifest) 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser( 73 | description='Convert mmseg models to TorchServe `.mar` format.') 74 | parser.add_argument('config', type=str, help='config file path') 75 | parser.add_argument('checkpoint', type=str, help='checkpoint file path') 76 | parser.add_argument( 77 | '--output-folder', 78 | type=str, 79 | required=True, 80 | help='Folder where `{model_name}.mar` will be created.') 81 | parser.add_argument( 82 | '--model-name', 83 | type=str, 84 | default=None, 85 | help='If not None, used for naming the `{model_name}.mar`' 86 | 'file that will be created under `output_folder`.' 87 | 'If None, `{Path(checkpoint_file).stem}` will be used.') 88 | parser.add_argument( 89 | '--model-version', 90 | type=str, 91 | default='1.0', 92 | help='Number used for versioning.') 93 | parser.add_argument( 94 | '-f', 95 | '--force', 96 | action='store_true', 97 | help='overwrite the existing `{model_name}.mar`') 98 | args = parser.parse_args() 99 | 100 | return args 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parse_args() 105 | 106 | if package_model is None: 107 | raise ImportError('`torch-model-archiver` is required.' 
108 | 'Try: pip install torch-model-archiver') 109 | 110 | mmseg2torchserve(args.config, args.checkpoint, args.output_folder, 111 | args.model_name, args.model_version, args.force) 112 | -------------------------------------------------------------------------------- /segmentation/tools/torchserve/mmseg_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import base64 3 | import os 4 | 5 | import cv2 6 | import mmcv 7 | import torch 8 | from mmcv.cnn.utils.sync_bn import revert_sync_batchnorm 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmseg.apis import inference_segmentor, init_segmentor 12 | 13 | 14 | class MMsegHandler(BaseHandler): 15 | 16 | def initialize(self, context): 17 | properties = context.system_properties 18 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.device = torch.device(self.map_location + ':' + 20 | str(properties.get('gpu_id')) if torch.cuda. 21 | is_available() else self.map_location) 22 | self.manifest = context.manifest 23 | 24 | model_dir = properties.get('model_dir') 25 | serialized_file = self.manifest['model']['serializedFile'] 26 | checkpoint = os.path.join(model_dir, serialized_file) 27 | self.config_file = os.path.join(model_dir, 'config.py') 28 | 29 | self.model = init_segmentor(self.config_file, checkpoint, self.device) 30 | self.model = revert_sync_batchnorm(self.model) 31 | self.initialized = True 32 | 33 | def preprocess(self, data): 34 | images = [] 35 | 36 | for row in data: 37 | image = row.get('data') or row.get('body') 38 | if isinstance(image, str): 39 | image = base64.b64decode(image) 40 | image = mmcv.imfrombytes(image) 41 | images.append(image) 42 | 43 | return images 44 | 45 | def inference(self, data, *args, **kwargs): 46 | results = [inference_segmentor(self.model, img) for img in data] 47 | return results 48 | 49 | def postprocess(self, data): 50 | output = [] 51 | 52 | for image_result in data: 53 | _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) 54 | bast64_data = base64.b64encode(buffer.tobytes()) 55 | bast64_str = str(bast64_data, 'utf-8') 56 | output.append(bast64_str) 57 | return output 58 | -------------------------------------------------------------------------------- /segmentation/tools/torchserve/test_torchserve.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from argparse import ArgumentParser 3 | from io import BytesIO 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import requests 8 | 9 | from mmseg.apis import inference_segmentor, init_segmentor 10 | 11 | 12 | def parse_args(): 13 | parser = ArgumentParser( 14 | description='Compare result of torchserve and pytorch,' 15 | 'and visualize them.') 16 | parser.add_argument('img', help='Image file') 17 | parser.add_argument('config', help='Config file') 18 | parser.add_argument('checkpoint', help='Checkpoint file') 19 | parser.add_argument('model_name', help='The model name in the server') 20 | parser.add_argument( 21 | '--inference-addr', 22 | default='127.0.0.1:8080', 23 | help='Address and port of the inference server') 24 | parser.add_argument( 25 | '--result-image', 26 | type=str, 27 | default=None, 28 | help='save server output in result-image') 29 | parser.add_argument( 30 | '--device', default='cuda:0', help='Device used for inference') 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def main(args): 37 | url = 
'http://' + args.inference_addr + '/predictions/' + args.model_name 38 | with open(args.img, 'rb') as image: 39 | tmp_res = requests.post(url, image) 40 | base64_str = tmp_res.content 41 | buffer = base64.b64decode(base64_str) 42 | if args.result_image: 43 | with open(args.result_image, 'wb') as out_image: 44 | out_image.write(buffer) 45 | plt.imshow(mmcv.imread(args.result_image, 'grayscale')) 46 | plt.show() 47 | else: 48 | plt.imshow(plt.imread(BytesIO(buffer))) 49 | plt.show() 50 | model = init_segmentor(args.config, args.checkpoint, args.device) 51 | image = mmcv.imread(args.img) 52 | result = inference_segmentor(model, image) 53 | plt.imshow(result[0]) 54 | plt.show() 55 | 56 | 57 | if __name__ == '__main__': 58 | args = parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /segmentation/tools/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import copy 4 | import os 5 | import os.path as osp 6 | import time 7 | import warnings 8 | 9 | import mmcv 10 | import torch 11 | from mmcv.cnn.utils import revert_sync_batchnorm 12 | from mmcv.runner import get_dist_info, init_dist 13 | from mmcv.utils import Config, DictAction, get_git_hash 14 | 15 | from mmseg import __version__ 16 | from mmseg.apis import set_random_seed, train_segmentor 17 | from mmseg.datasets import build_dataset 18 | from mmseg.models import build_segmentor 19 | from mmseg.utils import collect_env, get_root_logger 20 | from backbones import vit_SelfPatch 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='Train a segmentor') 24 | parser.add_argument('config', help='train config file path') 25 | parser.add_argument('--work-dir', help='the dir to save logs and models') 26 | parser.add_argument( 27 | '--load-from', help='the checkpoint file to load weights from') 28 | parser.add_argument( 29 | '--resume-from', help='the checkpoint file to resume from') 30 | parser.add_argument( 31 | '--no-validate', 32 | action='store_true', 33 | help='whether not to evaluate the checkpoint during training') 34 | group_gpus = parser.add_mutually_exclusive_group() 35 | group_gpus.add_argument( 36 | '--gpus', 37 | type=int, 38 | help='number of gpus to use ' 39 | '(only applicable to non-distributed training)') 40 | group_gpus.add_argument( 41 | '--gpu-ids', 42 | type=int, 43 | nargs='+', 44 | help='ids of gpus to use ' 45 | '(only applicable to non-distributed training)') 46 | parser.add_argument('--seed', type=int, default=None, help='random seed') 47 | parser.add_argument( 48 | '--deterministic', 49 | action='store_true', 50 | help='whether to set deterministic options for CUDNN backend.') 51 | parser.add_argument( 52 | '--options', nargs='+', action=DictAction, help='custom options') 53 | parser.add_argument( 54 | '--launcher', 55 | choices=['none', 'pytorch', 'slurm', 'mpi'], 56 | default='none', 57 | help='job launcher') 58 | parser.add_argument('--local_rank', type=int, default=0) 59 | args = parser.parse_args() 60 | if 'LOCAL_RANK' not in os.environ: 61 | os.environ['LOCAL_RANK'] = str(args.local_rank) 62 | 63 | return args 64 | 65 | 66 | def main(): 67 | args = parse_args() 68 | 69 | cfg = Config.fromfile(args.config) 70 | if args.options is not None: 71 | cfg.merge_from_dict(args.options) 72 | # set cudnn_benchmark 73 | if cfg.get('cudnn_benchmark', False): 74 | torch.backends.cudnn.benchmark = True 75 | 76 | # work_dir is determined in
this priority: CLI > work_dir in config file > config filename 77 | if args.work_dir is not None: 78 | # update configs according to CLI args if args.work_dir is not None 79 | cfg.work_dir = args.work_dir 80 | elif cfg.get('work_dir', None) is None: 81 | # use config filename as default work_dir if cfg.work_dir is None 82 | cfg.work_dir = osp.join('./work_dirs', 83 | osp.splitext(osp.basename(args.config))[0]) 84 | if args.load_from is not None: 85 | cfg.load_from = args.load_from 86 | if args.resume_from is not None: 87 | cfg.resume_from = args.resume_from 88 | if args.gpu_ids is not None: 89 | cfg.gpu_ids = args.gpu_ids 90 | else: 91 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 92 | 93 | # init distributed env first, since logger depends on the dist info. 94 | if args.launcher == 'none': 95 | distributed = False 96 | else: 97 | distributed = True 98 | init_dist(args.launcher, **cfg.dist_params) 99 | # gpu_ids is used to calculate the iteration when resuming a checkpoint 100 | _, world_size = get_dist_info() 101 | cfg.gpu_ids = range(world_size) 102 | 103 | # create work_dir 104 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 105 | # dump config 106 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 107 | # init the logger before other steps 108 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 109 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 110 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 111 | 112 | # init the meta dict to record some important information such as 113 | # environment info and seed, which will be logged 114 | meta = dict() 115 | # log env info 116 | env_info_dict = collect_env() 117 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 118 | dash_line = '-' * 60 + '\n' 119 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 120 | dash_line) 121 | meta['env_info'] = env_info 122 | 123 | # log some basic info 124 | logger.info(f'Distributed training: {distributed}') 125 | logger.info(f'Config:\n{cfg.pretty_text}') 126 | 127 | # set random seeds 128 | if args.seed is not None: 129 | logger.info(f'Set random seed to {args.seed}, deterministic: ' 130 | f'{args.deterministic}') 131 | set_random_seed(args.seed, deterministic=args.deterministic) 132 | cfg.seed = args.seed 133 | meta['seed'] = args.seed 134 | meta['exp_name'] = osp.basename(args.config) 135 | 136 | model = build_segmentor( 137 | cfg.model, 138 | train_cfg=cfg.get('train_cfg'), 139 | test_cfg=cfg.get('test_cfg')) 140 | model.init_weights() 141 | 142 | # SyncBN is not supported for DP 143 | if not distributed: 144 | warnings.warn( 145 | 'SyncBN is only supported with DDP. To be compatible with DP, ' 146 | 'we convert SyncBN to BN.
Please use dist_train.sh which can ' 147 | 'avoid this error.') 148 | model = revert_sync_batchnorm(model) 149 | 150 | logger.info(model) 151 | 152 | datasets = [build_dataset(cfg.data.train)] 153 | if len(cfg.workflow) == 2: 154 | val_dataset = copy.deepcopy(cfg.data.val) 155 | val_dataset.pipeline = cfg.data.train.pipeline 156 | datasets.append(build_dataset(val_dataset)) 157 | if cfg.checkpoint_config is not None: 158 | # save mmseg version, config file content and class names in 159 | # checkpoints as meta data 160 | cfg.checkpoint_config.meta = dict( 161 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 162 | config=cfg.pretty_text, 163 | CLASSES=datasets[0].CLASSES, 164 | PALETTE=datasets[0].PALETTE) 165 | # add an attribute for visualization convenience 166 | model.CLASSES = datasets[0].CLASSES 167 | # passing checkpoint meta for saving best checkpoint 168 | meta.update(cfg.checkpoint_config.meta) 169 | train_segmentor( 170 | model, 171 | datasets, 172 | cfg, 173 | distributed=distributed, 174 | validate=(not args.no_validate), 175 | timestamp=timestamp, 176 | meta=meta) 177 | 178 | 179 | if __name__ == '__main__': 180 | main() 181 | --------------------------------------------------------------------------------