├── .gitignore ├── LICENSE ├── README.md ├── configs ├── _base_ │ ├── datasets │ │ └── ade20k.py │ ├── default_runtime.py │ ├── models │ │ ├── fpn_van.py │ │ └── upernet_van.py │ └── schedules │ │ ├── schedule_160k.py │ │ ├── schedule_20k.py │ │ ├── schedule_40k.py │ │ └── schedule_80k.py ├── sem_fpn │ ├── fpn_van_b0_ade20k_40k.py │ ├── fpn_van_b1_ade20k_40k.py │ ├── fpn_van_b2_ade20k_40k.py │ ├── fpn_van_b3_ade20k_40k.py │ └── fpn_van_b4_ade20k_40k.py └── upernet │ ├── 1k_pretrained │ ├── upernet_van_b0_512x512_160k_ade20k.py │ ├── upernet_van_b1_512x512_160k_ade20k.py │ ├── upernet_van_b2_512x512_160k_ade20k.py │ ├── upernet_van_b3_512x512_160k_ade20k.py │ └── upernet_van_b4_512x512_160k_ade20k.py │ └── 22k_pretrained │ ├── upernet_van_b4_512x512_160k_ade20k_22k.py │ ├── upernet_van_b5_512x512_160k_ade20k_22k.py │ └── upernet_van_b6_512x512_160k_ade20k_22k.py ├── dist_test.sh ├── dist_train.sh ├── test.py ├── tools ├── analyze_logs.py ├── benchmark.py ├── browse_dataset.py ├── confusion_matrix.py ├── convert_datasets │ ├── chase_db1.py │ ├── cityscapes.py │ ├── coco_stuff10k.py │ ├── coco_stuff164k.py │ ├── drive.py │ ├── hrf.py │ ├── isaid.py │ ├── loveda.py │ ├── pascal_context.py │ ├── potsdam.py │ ├── stare.py │ ├── vaihingen.py │ └── voc_aug.py ├── deploy_test.py ├── flops.sh ├── get_flops.py ├── model_converters │ ├── beit2mmseg.py │ ├── mit2mmseg.py │ ├── stdc2mmseg.py │ ├── swin2mmseg.py │ ├── twins2mmseg.py │ ├── vit2mmseg.py │ └── vitjax2mmseg.py ├── onnx2tensorrt.py ├── print_config.py ├── publish_model.py ├── pytorch2onnx.py ├── pytorch2torchscript.py ├── slurm_test.sh ├── slurm_train.sh ├── torchserve │ ├── mmseg2torchserve.py │ ├── mmseg_handler.py │ └── test_torchserve.py └── van.py ├── train.py └── van.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | pretrained -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Visual Attention Network (VAN) for Segmentation 2 | 3 | This repo is a PyTorch implementation of applying **VAN** (**Visual Attention Network**) to semantic segmentation. 4 | The code is based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0). 5 | 6 | More details can be found in [**Visual Attention Network**](https://arxiv.org/abs/2202.09741).
7 | 8 | ## Citation 9 | 10 | ```bib 11 | @article{guo2022visual, 12 | title={Visual Attention Network}, 13 | author={Guo, Meng-Hao and Lu, Cheng-Ze and Liu, Zheng-Ning and Cheng, Ming-Ming and Hu, Shi-Min}, 14 | journal={arXiv preprint arXiv:2202.09741}, 15 | year={2022} 16 | } 17 | ``` 18 | 19 | ## Results 20 | 21 | **Notes**: Pre-trained models can be found in [TsingHua Cloud](https://cloud.tsinghua.edu.cn/d/0100f0cea37d41ba8d08/). 22 | 23 | ### VAN + UperNet 24 | 25 | | Method | Backbone | Pretrained | Iters | mIoU(ms) | Params | FLOPs | Config | Download | 26 | | :-------: | :-------------: | :-----: | :---: | :--: | :----: | :----: | :----: | :-------: | 27 | | UperNet | VAN-B0 | IN-1K | 160K | 41.1 | 32M | - | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/1k_pretrained/upernet_van_b0_512x512_160k_ade20k.py) | - | 28 | | UperNet | VAN-B1 | IN-1K | 160K | 44.9 | 44M | - | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/1k_pretrained/upernet_van_b1_512x512_160k_ade20k.py) | - | 29 | | UperNet | VAN-B2 | IN-1K | 160K | 50.1 | 57M | 948G | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/1k_pretrained/upernet_van_b2_512x512_160k_ade20k.py) | [TsingHua Cloud](https://cloud.tsinghua.edu.cn/f/68c8b494f3824d30bf07/?dl=1) | 30 | | UperNet | VAN-B3 | IN-1K | 160K | 50.6 | 75M | 1030G | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/1k_pretrained/upernet_van_b3_512x512_160k_ade20k.py) | [TsingHua Cloud](https://cloud.tsinghua.edu.cn/f/97bde65fbe334b358797/?dl=1) | 31 | | UperNet | VAN-B4 | IN-1K | 160K | 52.2 | 90M | 1098G | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/1k_pretrained/upernet_van_b4_512x512_160k_ade20k.py) | [TsingHua Cloud](https://cloud.tsinghua.edu.cn/f/5273f92c77a94395b804/?dl=1) | 32 | | UperNet | VAN-B4 | IN-22K | 160K | 53.5 | 90M | 1098G | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/22k_pretrained/upernet_van_b4_512x512_160k_ade20k_22k.py) | [TsingHua Cloud](https://cloud.tsinghua.edu.cn/f/8f1f0a9c4c71478fa43b/?dl=1) | 33 | | UperNet | VAN-B5 | IN-22K | 160K | 53.9 | 117M | 1208G | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/22k_pretrained/upernet_van_b5_512x512_160k_ade20k_22k.py) | [TsingHua Cloud](https://cloud.tsinghua.edu.cn/f/2175bdc39d094e5f8f99/?dl=1) | 34 | | UperNet | VAN-B6 | IN-22K | 160K | 54.7 | 231M | 1658G | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/upernet/22k_pretrained/upernet_van_b6_512x512_160k_ade20k_22k.py) | [TsingHua Cloud](https://cloud.tsinghua.edu.cn/f/853d9d0ea0f44c2aa090/?dl=1) | 35 | 36 | **Notes**: In this scheme, we use multi-scale validation following Swin-Transformer. FLOPs are tested under the input size of 2048 $\times$ 512 using [torchprofile](https://github.com/zhijian-liu/torchprofile) (recommended, highly accurate and automatic MACs/FLOPs statistics). 
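For reference, the multi-scale (ms) numbers above come from test-time augmentation in the `MultiScaleFlipAug` step of the test pipeline. The base dataset config shipped in this repo (`configs/_base_/datasets/ade20k.py`, shown further below) leaves `img_ratios` commented out and sets `flip=False`, i.e. single-scale testing. Below is a minimal sketch of an override config that enables ms + flip evaluation; the `_base_` path and the ratio list are assumptions for illustration, not files or settings shipped with the repo.

```python
# Hypothetical ms + flip evaluation config (illustration only, not part of this repo).
_base_ = ['./upernet_van_b2_512x512_160k_ade20k.py']  # example base config (assumed path)

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2048, 512),
        # Assumed ratios: the same list the base config leaves commented out.
        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,  # also average over horizontally flipped inputs
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='ResizeToMultiple', size_divisor=32),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
# Point both the val and test splits at the augmented pipeline.
data = dict(val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline))
```

Single-scale results correspond to the pipeline as shipped (no `img_ratios`, `flip=False`).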
37 | 38 | ### VAN + Semantic FPN 39 | 40 | | Backbone | Iters | mIoU | Config | Download | 41 | | :-------------: | :-----: | :------: | :------------: | :----: | 42 | | VAN-Tiny | 40K | 38.5 | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/sem_fpn/fpn_van_b0_ade20k_40k.py) | [Google Drive](https://drive.google.com/file/d/1Jl8LtyvOl6xeNMKCjpK2Rp_tGRfua8LJ/view?usp=sharing) | 43 | | VAN-Small | 40K | 42.9 | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/sem_fpn/fpn_van_b1_ade20k_40k.py) | [Google Drive](https://drive.google.com/file/d/1Xfuo9D3Fo7b6zSCLTWE77k2jgYSHVSb8/view?usp=sharing) | 44 | | VAN-Base | 40K | 46.7 | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/sem_fpn/fpn_van_b2_ade20k_40k.py) | [Google Drive](https://drive.google.com/file/d/1Ar4Hq9DjgaULQKfwM-jJvSO-D6gendpf/view?usp=sharing) | 45 | | VAN-Large | 40K | 48.1 | [config](https://github.com/Visual-Attention-Network/VAN-Segmentation/blob/main/configs/sem_fpn/fpn_van_b3_ade20k_40k.py) | [Google Drive](https://drive.google.com/file/d/1v61uCi07IC6eyVHn3xbJqz4nOiGa1POY/view?usp=sharing) | 46 | 47 | ## Preparation 48 | 49 | Install MMSegmentation and download ADE20K according to the guidelines in MMSegmentation. 50 | 51 | ## Requirement 52 | 53 | ```bash 54 | pip install mmsegmentation==0.26.0  # https://github.com/open-mmlab/mmsegmentation/tree/v0.26.0 55 | ``` 56 | 57 | ## Training 58 | 59 | We use 8 GPUs for training by default. Run: 60 | 61 | ```bash 62 | ./dist_train.sh /path/to/config 8 63 | ``` 64 | 65 | ## Evaluation 66 | 67 | To evaluate the model, run: 68 | 69 | ```bash 70 | ./dist_test.sh /path/to/config /path/to/checkpoint_file 8 --eval mIoU 71 | ``` 72 | 73 | ## FLOPs 74 | 75 | Install torchprofile with: 76 | 77 | ```bash 78 | pip install torchprofile 79 | ``` 80 | 81 | To calculate FLOPs for a model, run: 82 | 83 | ```bash 84 | bash tools/flops.sh /path/to/config --shape 512 512 85 | ``` 86 | 87 | A minimal Python sketch of what this command measures is given at the end of this README. 88 | ## Acknowledgment 89 | 90 | Our implementation is mainly based on [mmsegmentation](https://github.com/open-mmlab/mmsegmentation/tree/v0.12.0), [Swin-Transformer](https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation), [PoolFormer](https://github.com/sail-sg/poolformer), and [Enjoy-Hamburger](https://github.com/Gsunshine/Enjoy-Hamburger). Thanks to their authors. 91 | 92 | ## LICENSE 93 | 94 | This repo is under the Apache-2.0 license. For commercial use, please contact the authors.
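As referenced in the FLOPs section above, here is a minimal sketch of what the FLOPs measurement boils down to. It is not the repository's `tools/get_flops.py`; it assumes mmseg's `Config`/`build_segmentor` API (used the same way by `tools/benchmark.py` further below) and torchprofile's `profile_macs`, and the config path is only an example.

```python
# Hedged sketch of MACs/FLOPs counting with torchprofile; assumptions are noted inline.
import torch
from mmcv import Config
from mmseg.models import build_segmentor
from torchprofile import profile_macs

# Example config path; any config under configs/ should work the same way.
cfg = Config.fromfile(
    'configs/upernet/1k_pretrained/upernet_van_b2_512x512_160k_ade20k.py')
cfg.model.pretrained = None
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')).eval()

# Assumption: the EncoderDecoder exposes forward_dummy() for complexity analysis,
# so routing forward through it lets the tracer run on a plain image tensor.
if hasattr(model, 'forward_dummy'):
    model.forward = model.forward_dummy

with torch.no_grad():
    macs = profile_macs(model, torch.randn(1, 3, 512, 512))  # MACs for one 512x512 image
print(f'MACs at 512x512: {macs / 1e9:.1f} G')
```

Note that the UperNet table above reports FLOPs at an input size of 2048 $\times$ 512, while the `--shape 512 512` example measures a single 512 $\times$ 512 crop, so the two numbers are not directly comparable.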
95 | -------------------------------------------------------------------------------- /configs/_base_/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = '/home/gmh/dataset/ADEChallengeData2016' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (512, 512) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 512), 24 | # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=False, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='ResizeToMultiple', size_divisor=32), 29 | dict(type='RandomFlip'), 30 | dict(type='Normalize', **img_norm_cfg), 31 | dict(type='ImageToTensor', keys=['img']), 32 | dict(type='Collect', keys=['img']), 33 | ]) 34 | ] 35 | data = dict( 36 | samples_per_gpu=4, 37 | workers_per_gpu=4, 38 | train=dict( 39 | type='RepeatDataset', 40 | times=50, 41 | dataset=dict( 42 | type=dataset_type, 43 | data_root=data_root, 44 | img_dir='images/training', 45 | ann_dir='annotations/training', 46 | pipeline=train_pipeline)), 47 | val=dict( 48 | type=dataset_type, 49 | data_root=data_root, 50 | img_dir='images/validation', 51 | ann_dir='annotations/validation', 52 | pipeline=test_pipeline), 53 | test=dict( 54 | type=dataset_type, 55 | data_root=data_root, 56 | img_dir='images/validation', 57 | ann_dir='annotations/validation', 58 | pipeline=test_pipeline)) 59 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=50, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=False), 6 | # dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True 15 | -------------------------------------------------------------------------------- /configs/_base_/models/fpn_van.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='VAN', 8 | embed_dims=[32, 64, 160, 256], 9 | drop_rate=0.0, 10 | drop_path_rate=0.1, 11 | depths=[3, 3, 5, 2], 12 | norm_cfg=norm_cfg), 13 | neck=dict( 14 | type='FPN', 15 | in_channels=[32, 64, 160, 256], 16 | out_channels=256, 17 | num_outs=4), 18 | decode_head=dict( 19 | type='FPNHead', 20 | in_channels=[256, 256, 256, 256], 21 | in_index=[0, 1, 2, 3], 22 | feature_strides=[4, 8, 16, 32], 23 | channels=128, 24 | dropout_ratio=0.1, 25 | num_classes=150, 26 | norm_cfg=norm_cfg, 27 | 
align_corners=False, 28 | loss_decode=dict( 29 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 30 | # model training and testing settings 31 | train_cfg=dict(), 32 | test_cfg=dict(mode='whole')) 33 | -------------------------------------------------------------------------------- /configs/_base_/models/upernet_van.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='SyncBN', requires_grad=True) 3 | model = dict( 4 | type='EncoderDecoder', 5 | pretrained=None, 6 | backbone=dict( 7 | type='VAN', 8 | embed_dims=[32, 64, 160, 256], 9 | drop_rate=0.0, 10 | drop_path_rate=0.1, 11 | depths=[3, 3, 5, 2], 12 | norm_cfg=norm_cfg), 13 | decode_head=dict( 14 | type='UPerHead', 15 | in_channels=[32, 64, 160, 256], 16 | in_index=[0, 1, 2, 3], 17 | pool_scales=(1, 2, 3, 6), 18 | channels=512, 19 | dropout_ratio=0.1, 20 | num_classes=150, 21 | norm_cfg=norm_cfg, 22 | align_corners=False, 23 | loss_decode=dict( 24 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), 25 | auxiliary_head=dict( 26 | type='FCNHead', 27 | in_channels=160, 28 | in_index=2, 29 | channels=256, 30 | num_convs=1, 31 | concat_input=False, 32 | dropout_ratio=0.1, 33 | num_classes=150, 34 | norm_cfg=norm_cfg, 35 | align_corners=False, 36 | loss_decode=dict( 37 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), 38 | # model training and testing settings 39 | train_cfg=dict(), 40 | test_cfg=dict(mode='whole')) 41 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_160k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=160000) 8 | checkpoint_config = dict(by_epoch=False, interval=16000) 9 | evaluation = dict(interval=16000, metric='mIoU', pre_eval=True) 10 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_20k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=20000) 8 | checkpoint_config = dict(by_epoch=False, interval=2000) 9 | evaluation = dict(interval=2000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_40k.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=40000) 8 | checkpoint_config = dict(by_epoch=False, interval=4000) 9 | evaluation = dict(interval=4000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule_80k.py: 
-------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) 6 | # runtime settings 7 | runner = dict(type='IterBasedRunner', max_iters=80000) 8 | checkpoint_config = dict(by_epoch=False, interval=8000) 9 | evaluation = dict(interval=8000, metric='mIoU') 10 | -------------------------------------------------------------------------------- /configs/sem_fpn/fpn_van_b0_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = './fpn_van_b2_ade20k_40k.py' 2 | 3 | # model settings 4 | model = dict( 5 | backbone=dict( 6 | embed_dims=[32, 64, 160, 256], 7 | depths=[3, 3, 5, 2], 8 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b0.pth')), 9 | neck=dict(in_channels=[32, 64, 160, 256])) 10 | -------------------------------------------------------------------------------- /configs/sem_fpn/fpn_van_b1_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = './fpn_van_b2_ade20k_40k.py' 2 | 3 | # model settings 4 | model = dict( 5 | type='EncoderDecoder', 6 | backbone=dict( 7 | embed_dims=[64, 128, 320, 512], 8 | depths=[2, 2, 4, 2], 9 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b1.pth')), 10 | neck=dict(in_channels=[64, 128, 320, 512])) 11 | -------------------------------------------------------------------------------- /configs/sem_fpn/fpn_van_b2_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/fpn_van.py', 3 | '../_base_/datasets/ade20k.py', 4 | '../_base_/default_runtime.py' 5 | ] 6 | # model settings 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[64, 128, 320, 512], 11 | depths=[3, 3, 12, 3], 12 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b2.pth'), 13 | drop_path_rate=0.2), 14 | neck=dict(in_channels=[64, 128, 320, 512]), 15 | decode_head=dict(num_classes=150)) 16 | 17 | 18 | gpu_multiples = 2 # we use 8 gpu instead of 4 in mmsegmentation, so lr*2 and max_iters/2 19 | # optimizer 20 | optimizer = dict(type='AdamW', lr=0.0001*gpu_multiples, weight_decay=0.0001) 21 | optimizer_config = dict() 22 | # learning policy 23 | lr_config = dict(policy='poly', power=0.9, min_lr=0.0, by_epoch=False) 24 | # runtime settings 25 | runner = dict(type='IterBasedRunner', max_iters=80000//gpu_multiples) 26 | checkpoint_config = dict(by_epoch=False, interval=8000//gpu_multiples) 27 | evaluation = dict(interval=8000//gpu_multiples, metric='mIoU') 28 | data = dict(samples_per_gpu=4) 29 | -------------------------------------------------------------------------------- /configs/sem_fpn/fpn_van_b3_ade20k_40k.py: -------------------------------------------------------------------------------- 1 | _base_ = './fpn_van_b2_ade20k_40k.py' 2 | 3 | # model settings 4 | model = dict( 5 | type='EncoderDecoder', 6 | backbone=dict( 7 | embed_dims=[64, 128, 320, 512], 8 | depths=[3, 5, 27, 3], 9 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b3.pth'), 10 | drop_path_rate=0.3), 11 | neck=dict(in_channels=[64, 128, 320, 512])) 12 | data = dict(samples_per_gpu=4) 13 | -------------------------------------------------------------------------------- /configs/sem_fpn/fpn_van_b4_ade20k_40k.py: 
-------------------------------------------------------------------------------- 1 | _base_ = './fpn_van_b2_ade20k_40k.py' 2 | 3 | # model settings 4 | model = dict( 5 | type='EncoderDecoder', 6 | backbone=dict( 7 | embed_dims=[64, 128, 320, 512], 8 | depths=[3, 6, 40, 3], 9 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b4.pth'), 10 | drop_path_rate=0.4), 11 | neck=dict(in_channels=[64, 128, 320, 512])) 12 | data = dict(samples_per_gpu=4) 13 | -------------------------------------------------------------------------------- /configs/upernet/1k_pretrained/upernet_van_b0_512x512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b0.pth')), 11 | decode_head=dict( 12 | in_channels=[32, 64, 160, 256], 13 | num_classes=150 14 | ), 15 | auxiliary_head=dict( 16 | in_channels=160, 17 | num_classes=150 18 | )) 19 | 20 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 21 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 22 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 23 | 'relative_position_bias_table': dict(decay_mult=0.), 24 | 'norm': dict(decay_mult=0.)})) 25 | 26 | lr_config = dict(_delete_=True, policy='poly', 27 | warmup='linear', 28 | warmup_iters=1500, 29 | warmup_ratio=1e-6, 30 | power=1.0, min_lr=0.0, by_epoch=False) 31 | 32 | # By default, models are trained on 8 GPUs with 2 images per GPU 33 | data = dict(samples_per_gpu=2) 34 | -------------------------------------------------------------------------------- /configs/upernet/1k_pretrained/upernet_van_b1_512x512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[64, 128, 320, 512], 11 | depths=[2, 2, 4, 2], 12 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b1.pth')), 13 | decode_head=dict( 14 | in_channels=[64, 128, 320, 512], 15 | num_classes=150 16 | ), 17 | auxiliary_head=dict( 18 | in_channels=320, 19 | num_classes=150 20 | )) 21 | 22 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 23 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 24 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 25 | 'relative_position_bias_table': dict(decay_mult=0.), 26 | 'norm': dict(decay_mult=0.)})) 27 | 28 | lr_config = dict(_delete_=True, policy='poly', 29 | warmup='linear', 30 | warmup_iters=1500, 31 | warmup_ratio=1e-6, 32 | power=1.0, min_lr=0.0, by_epoch=False) 33 | 34 | # By default, models are trained on 8 GPUs with 2 images per GPU 35 | data = dict(samples_per_gpu=2) 36 | -------------------------------------------------------------------------------- /configs/upernet/1k_pretrained/upernet_van_b2_512x512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | 
'../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[64, 128, 320, 512], 11 | depths=[3, 3, 12, 3], 12 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b2.pth')), 13 | decode_head=dict( 14 | in_channels=[64, 128, 320, 512], 15 | num_classes=150 16 | ), 17 | auxiliary_head=dict( 18 | in_channels=320, 19 | num_classes=150 20 | )) 21 | 22 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 23 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 24 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 25 | 'relative_position_bias_table': dict(decay_mult=0.), 26 | 'norm': dict(decay_mult=0.)})) 27 | 28 | lr_config = dict(_delete_=True, policy='poly', 29 | warmup='linear', 30 | warmup_iters=1500, 31 | warmup_ratio=1e-6, 32 | power=1.0, min_lr=0.0, by_epoch=False) 33 | 34 | # By default, models are trained on 8 GPUs with 2 images per GPU 35 | data = dict(samples_per_gpu=2) 36 | -------------------------------------------------------------------------------- /configs/upernet/1k_pretrained/upernet_van_b3_512x512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[64, 128, 320, 512], 11 | depths=[3, 5, 27, 3], 12 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b3.pth'), 13 | drop_path_rate=0.3), 14 | decode_head=dict( 15 | in_channels=[64, 128, 320, 512], 16 | num_classes=150 17 | ), 18 | auxiliary_head=dict( 19 | in_channels=320, 20 | num_classes=150 21 | )) 22 | 23 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 24 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 25 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 26 | 'relative_position_bias_table': dict(decay_mult=0.), 27 | 'norm': dict(decay_mult=0.)})) 28 | 29 | lr_config = dict(_delete_=True, policy='poly', 30 | warmup='linear', 31 | warmup_iters=1500, 32 | warmup_ratio=1e-6, 33 | power=1.0, min_lr=0.0, by_epoch=False) 34 | 35 | # By default, models are trained on 8 GPUs with 2 images per GPU 36 | data = dict(samples_per_gpu=2) 37 | -------------------------------------------------------------------------------- /configs/upernet/1k_pretrained/upernet_van_b4_512x512_160k_ade20k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[64, 128, 320, 512], 11 | depths=[3, 6, 40, 3], 12 | init_cfg=dict(type='Pretrained', checkpoint='pretrained/van_b4.pth'), 13 | drop_path_rate=0.4), 14 | decode_head=dict( 15 | in_channels=[64, 128, 320, 512], 16 | num_classes=150 17 | ), 18 | auxiliary_head=dict( 19 | in_channels=320, 20 | num_classes=150 21 | )) 22 | 23 | 24 | # AdamW optimizer, no weight decay for position embedding 
& layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | # By default, models are trained on 8 GPUs with 2 images per GPU 37 | data = dict(samples_per_gpu=4) 38 | -------------------------------------------------------------------------------- /configs/upernet/22k_pretrained/upernet_van_b4_512x512_160k_ade20k_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[64, 128, 320, 512], 11 | depths=[3, 6, 40, 3], 12 | init_cfg=dict(type='Pretrained', 13 | checkpoint='pretrained/van_b4_22k.pth'), 14 | drop_path_rate=0.4), 15 | decode_head=dict( 16 | in_channels=[64, 128, 320, 512], 17 | num_classes=150 18 | ), 19 | auxiliary_head=dict( 20 | in_channels=320, 21 | num_classes=150 22 | )) 23 | 24 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | # By default, models are trained on 8 GPUs with 2 images per GPU 37 | data = dict(samples_per_gpu=4) 38 | -------------------------------------------------------------------------------- /configs/upernet/22k_pretrained/upernet_van_b5_512x512_160k_ade20k_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[96, 192, 480, 768], 11 | depths=[3, 3, 24, 3], 12 | init_cfg=dict(type='Pretrained', 13 | checkpoint='pretrained/van_b5_22k.pth'), 14 | drop_path_rate=0.4), 15 | decode_head=dict( 16 | in_channels=[96, 192, 480, 768], 17 | num_classes=150 18 | ), 19 | auxiliary_head=dict( 20 | in_channels=480, 21 | num_classes=150 22 | )) 23 | 24 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | # By default, models are trained on 8 GPUs with 2 images per GPU 37 | data = dict(samples_per_gpu=2) 
38 | -------------------------------------------------------------------------------- /configs/upernet/22k_pretrained/upernet_van_b6_512x512_160k_ade20k_22k.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/models/upernet_van.py', 3 | '../../_base_/datasets/ade20k.py', 4 | '../../_base_/default_runtime.py', 5 | '../../_base_/schedules/schedule_160k.py' 6 | ] 7 | model = dict( 8 | type='EncoderDecoder', 9 | backbone=dict( 10 | embed_dims=[96, 192, 384, 768], 11 | depths=[6, 6, 90, 6], 12 | init_cfg=dict(type='Pretrained', 13 | checkpoint='pretrained/van_b6_22k.pth'), 14 | drop_path_rate=0.5), 15 | decode_head=dict( 16 | in_channels=[96, 192, 384, 768], 17 | num_classes=150 18 | ), 19 | auxiliary_head=dict( 20 | in_channels=384, 21 | num_classes=150 22 | )) 23 | 24 | # AdamW optimizer, no weight decay for position embedding & layer norm in backbone 25 | optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, 26 | paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), 27 | 'relative_position_bias_table': dict(decay_mult=0.), 28 | 'norm': dict(decay_mult=0.)})) 29 | 30 | lr_config = dict(_delete_=True, policy='poly', 31 | warmup='linear', 32 | warmup_iters=1500, 33 | warmup_ratio=1e-6, 34 | power=1.0, min_lr=0.0, by_epoch=False) 35 | 36 | # By default, models are trained on 8 GPUs with 2 images per GPU 37 | data = dict(samples_per_gpu=2) 38 | -------------------------------------------------------------------------------- /dist_test.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | CHECKPOINT=$2 3 | GPUS=$3 4 | NNODES=${NNODES:-1} 5 | NODE_RANK=${NODE_RANK:-0} 6 | PORT=${PORT:-29500} 7 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 8 | 9 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 10 | python -m torch.distributed.launch \ 11 | --nnodes=$NNODES \ 12 | --node_rank=$NODE_RANK \ 13 | --master_addr=$MASTER_ADDR \ 14 | --nproc_per_node=$GPUS \ 15 | --master_port=$PORT \ 16 | $(dirname "$0")/test.py \ 17 | $CONFIG \ 18 | $CHECKPOINT \ 19 | --launcher pytorch \ 20 | ${@:4} 21 | -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | CONFIG=$1 2 | GPUS=$2 3 | NNODES=${NNODES:-1} 4 | NODE_RANK=${NODE_RANK:-0} 5 | PORT=${PORT:-29500} 6 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python -m torch.distributed.launch \ 10 | --nnodes=$NNODES \ 11 | --node_rank=$NODE_RANK \ 12 | --master_addr=$MASTER_ADDR \ 13 | --nproc_per_node=$GPUS \ 14 | --master_port=$PORT \ 15 | $(dirname "$0")/train.py \ 16 | $CONFIG \ 17 | --launcher pytorch ${@:3} 18 | -------------------------------------------------------------------------------- /tools/analyze_logs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | """Modified from https://github.com/open- 3 | mmlab/mmdetection/blob/master/tools/analysis_tools/analyze_logs.py.""" 4 | import argparse 5 | import json 6 | from collections import defaultdict 7 | 8 | import matplotlib.pyplot as plt 9 | import seaborn as sns 10 | 11 | 12 | def plot_curve(log_dicts, args): 13 | if args.backend is not None: 14 | plt.switch_backend(args.backend) 15 | sns.set_style(args.style) 16 | # if legend is None, use {filename}_{key} as legend 17 | legend = args.legend 18 | if legend is None: 19 | legend = [] 20 | for json_log in args.json_logs: 21 | for metric in args.keys: 22 | legend.append(f'{json_log}_{metric}') 23 | assert len(legend) == (len(args.json_logs) * len(args.keys)) 24 | metrics = args.keys 25 | 26 | num_metrics = len(metrics) 27 | for i, log_dict in enumerate(log_dicts): 28 | epochs = list(log_dict.keys()) 29 | for j, metric in enumerate(metrics): 30 | print(f'plot curve of {args.json_logs[i]}, metric is {metric}') 31 | plot_epochs = [] 32 | plot_iters = [] 33 | plot_values = [] 34 | # In some log files exist lines of validation, 35 | # `mode` list is used to only collect iter number 36 | # of training line. 37 | for epoch in epochs: 38 | epoch_logs = log_dict[epoch] 39 | if metric not in epoch_logs.keys(): 40 | continue 41 | if metric in ['mIoU', 'mAcc', 'aAcc']: 42 | plot_epochs.append(epoch) 43 | plot_values.append(epoch_logs[metric][0]) 44 | else: 45 | for idx in range(len(epoch_logs[metric])): 46 | if epoch_logs['mode'][idx] == 'train': 47 | plot_iters.append(epoch_logs['iter'][idx]) 48 | plot_values.append(epoch_logs[metric][idx]) 49 | ax = plt.gca() 50 | label = legend[i * num_metrics + j] 51 | if metric in ['mIoU', 'mAcc', 'aAcc']: 52 | ax.set_xticks(plot_epochs) 53 | plt.xlabel('epoch') 54 | plt.plot(plot_epochs, plot_values, label=label, marker='o') 55 | else: 56 | plt.xlabel('iter') 57 | plt.plot(plot_iters, plot_values, label=label, linewidth=0.5) 58 | plt.legend() 59 | if args.title is not None: 60 | plt.title(args.title) 61 | if args.out is None: 62 | plt.show() 63 | else: 64 | print(f'save curve to: {args.out}') 65 | plt.savefig(args.out) 66 | plt.cla() 67 | 68 | 69 | def parse_args(): 70 | parser = argparse.ArgumentParser(description='Analyze Json Log') 71 | parser.add_argument( 72 | 'json_logs', 73 | type=str, 74 | nargs='+', 75 | help='path of train log in json format') 76 | parser.add_argument( 77 | '--keys', 78 | type=str, 79 | nargs='+', 80 | default=['mIoU'], 81 | help='the metric that you want to plot') 82 | parser.add_argument('--title', type=str, help='title of figure') 83 | parser.add_argument( 84 | '--legend', 85 | type=str, 86 | nargs='+', 87 | default=None, 88 | help='legend of each plot') 89 | parser.add_argument( 90 | '--backend', type=str, default=None, help='backend of plt') 91 | parser.add_argument( 92 | '--style', type=str, default='dark', help='style of plt') 93 | parser.add_argument('--out', type=str, default=None) 94 | args = parser.parse_args() 95 | return args 96 | 97 | 98 | def load_json_logs(json_logs): 99 | # load and convert json_logs to log_dict, key is epoch, value is a sub dict 100 | # keys of sub dict is different metrics 101 | # value of sub dict is a list of corresponding values of all iterations 102 | log_dicts = [dict() for _ in json_logs] 103 | for json_log, log_dict in zip(json_logs, log_dicts): 104 | with open(json_log, 'r') as log_file: 105 | for line in log_file: 106 | log = json.loads(line.strip()) 107 | # skip lines without `epoch` field 108 | if 'epoch' not in log: 109 | continue 110 | epoch 
= log.pop('epoch') 111 | if epoch not in log_dict: 112 | log_dict[epoch] = defaultdict(list) 113 | for k, v in log.items(): 114 | log_dict[epoch][k].append(v) 115 | return log_dicts 116 | 117 | 118 | def main(): 119 | args = parse_args() 120 | json_logs = args.json_logs 121 | for json_log in json_logs: 122 | assert json_log.endswith('.json') 123 | log_dicts = load_json_logs(json_logs) 124 | plot_curve(log_dicts, args) 125 | 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | import time 5 | 6 | import mmcv 7 | import numpy as np 8 | import torch 9 | from mmcv import Config 10 | from mmcv.parallel import MMDataParallel 11 | from mmcv.runner import load_checkpoint, wrap_fp16_model 12 | 13 | from mmseg.datasets import build_dataloader, build_dataset 14 | from mmseg.models import build_segmentor 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='MMSeg benchmark a model') 19 | parser.add_argument('config', help='test config file path') 20 | parser.add_argument('checkpoint', help='checkpoint file') 21 | parser.add_argument( 22 | '--log-interval', type=int, default=50, help='interval of logging') 23 | parser.add_argument( 24 | '--work-dir', 25 | help=('if specified, the results will be dumped ' 26 | 'into the directory as json')) 27 | parser.add_argument('--repeat-times', type=int, default=1) 28 | args = parser.parse_args() 29 | return args 30 | 31 | 32 | def main(): 33 | args = parse_args() 34 | 35 | cfg = Config.fromfile(args.config) 36 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 37 | if args.work_dir is not None: 38 | mmcv.mkdir_or_exist(osp.abspath(args.work_dir)) 39 | json_file = osp.join(args.work_dir, f'fps_{timestamp}.json') 40 | else: 41 | # use config filename as default work_dir if cfg.work_dir is None 42 | work_dir = osp.join('./work_dirs', 43 | osp.splitext(osp.basename(args.config))[0]) 44 | mmcv.mkdir_or_exist(osp.abspath(work_dir)) 45 | json_file = osp.join(work_dir, f'fps_{timestamp}.json') 46 | 47 | repeat_times = args.repeat_times 48 | # set cudnn_benchmark 49 | torch.backends.cudnn.benchmark = False 50 | cfg.model.pretrained = None 51 | cfg.data.test.test_mode = True 52 | 53 | benchmark_dict = dict(config=args.config, unit='img / s') 54 | overall_fps_list = [] 55 | for time_index in range(repeat_times): 56 | print(f'Run {time_index + 1}:') 57 | # build the dataloader 58 | # TODO: support multiple images per gpu (only minor changes are needed) 59 | dataset = build_dataset(cfg.data.test) 60 | data_loader = build_dataloader( 61 | dataset, 62 | samples_per_gpu=1, 63 | workers_per_gpu=cfg.data.workers_per_gpu, 64 | dist=False, 65 | shuffle=False) 66 | 67 | # build the model and load checkpoint 68 | cfg.model.train_cfg = None 69 | model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) 70 | fp16_cfg = cfg.get('fp16', None) 71 | if fp16_cfg is not None: 72 | wrap_fp16_model(model) 73 | if 'checkpoint' in args and osp.exists(args.checkpoint): 74 | load_checkpoint(model, args.checkpoint, map_location='cpu') 75 | 76 | model = MMDataParallel(model, device_ids=[0]) 77 | 78 | model.eval() 79 | 80 | # the first several iterations may be very slow so skip them 81 | num_warmup = 5 82 | pure_inf_time = 0 83 | total_iters = 200 84 | 85 | # benchmark 
with 200 image and take the average 86 | for i, data in enumerate(data_loader): 87 | 88 | torch.cuda.synchronize() 89 | start_time = time.perf_counter() 90 | 91 | with torch.no_grad(): 92 | model(return_loss=False, rescale=True, **data) 93 | 94 | torch.cuda.synchronize() 95 | elapsed = time.perf_counter() - start_time 96 | 97 | if i >= num_warmup: 98 | pure_inf_time += elapsed 99 | if (i + 1) % args.log_interval == 0: 100 | fps = (i + 1 - num_warmup) / pure_inf_time 101 | print(f'Done image [{i + 1:<3}/ {total_iters}], ' 102 | f'fps: {fps:.2f} img / s') 103 | 104 | if (i + 1) == total_iters: 105 | fps = (i + 1 - num_warmup) / pure_inf_time 106 | print(f'Overall fps: {fps:.2f} img / s\n') 107 | benchmark_dict[f'overall_fps_{time_index + 1}'] = round(fps, 2) 108 | overall_fps_list.append(fps) 109 | break 110 | benchmark_dict['average_fps'] = round(np.mean(overall_fps_list), 2) 111 | benchmark_dict['fps_variance'] = round(np.var(overall_fps_list), 4) 112 | print(f'Average fps of {repeat_times} evaluations: ' 113 | f'{benchmark_dict["average_fps"]}') 114 | print(f'The variance of {repeat_times} evaluations: ' 115 | f'{benchmark_dict["fps_variance"]}') 116 | mmcv.dump(benchmark_dict, json_file, indent=4) 117 | 118 | 119 | if __name__ == '__main__': 120 | main() 121 | -------------------------------------------------------------------------------- /tools/browse_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import warnings 5 | from pathlib import Path 6 | 7 | import mmcv 8 | import numpy as np 9 | from mmcv import Config, DictAction 10 | 11 | from mmseg.datasets.builder import build_dataset 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='Browse a dataset') 16 | parser.add_argument('config', help='train config file path') 17 | parser.add_argument( 18 | '--show-origin', 19 | default=False, 20 | action='store_true', 21 | help='if True, omit all augmentation in pipeline,' 22 | ' show origin image and seg map') 23 | parser.add_argument( 24 | '--skip-type', 25 | type=str, 26 | nargs='+', 27 | default=['DefaultFormatBundle', 'Normalize', 'Collect'], 28 | help='skip some useless pipeline,if `show-origin` is true, ' 29 | 'all pipeline except `Load` will be skipped') 30 | parser.add_argument( 31 | '--output-dir', 32 | default='./output', 33 | type=str, 34 | help='If there is no display interface, you can save it') 35 | parser.add_argument('--show', default=False, action='store_true') 36 | parser.add_argument( 37 | '--show-interval', 38 | type=int, 39 | default=999, 40 | help='the interval of show (ms)') 41 | parser.add_argument( 42 | '--opacity', 43 | type=float, 44 | default=0.5, 45 | help='the opacity of semantic map') 46 | parser.add_argument( 47 | '--cfg-options', 48 | nargs='+', 49 | action=DictAction, 50 | help='override some settings in the used config, the key-value pair ' 51 | 'in xxx=yyy format will be merged into config file. If the value to ' 52 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 53 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 54 | 'Note that the quotation marks are necessary and that no white space ' 55 | 'is allowed.') 56 | args = parser.parse_args() 57 | return args 58 | 59 | 60 | def imshow_semantic(img, 61 | seg, 62 | class_names, 63 | palette=None, 64 | win_name='', 65 | show=False, 66 | wait_time=0, 67 | out_file=None, 68 | opacity=0.5): 69 | """Draw `result` over `img`. 70 | 71 | Args: 72 | img (str or Tensor): The image to be displayed. 73 | seg (Tensor): The semantic segmentation results to draw over 74 | `img`. 75 | class_names (list[str]): Names of each classes. 76 | palette (list[list[int]]] | np.ndarray | None): The palette of 77 | segmentation map. If None is given, random palette will be 78 | generated. Default: None 79 | win_name (str): The window name. 80 | wait_time (int): Value of waitKey param. 81 | Default: 0. 82 | show (bool): Whether to show the image. 83 | Default: False. 84 | out_file (str or None): The filename to write the image. 85 | Default: None. 86 | opacity(float): Opacity of painted segmentation map. 87 | Default 0.5. 88 | Must be in (0, 1] range. 89 | Returns: 90 | img (Tensor): Only if not `show` or `out_file` 91 | """ 92 | img = mmcv.imread(img) 93 | img = img.copy() 94 | if palette is None: 95 | palette = np.random.randint(0, 255, size=(len(class_names), 3)) 96 | palette = np.array(palette) 97 | assert palette.shape[0] == len(class_names) 98 | assert palette.shape[1] == 3 99 | assert len(palette.shape) == 2 100 | assert 0 < opacity <= 1.0 101 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 102 | for label, color in enumerate(palette): 103 | color_seg[seg == label, :] = color 104 | # convert to BGR 105 | color_seg = color_seg[..., ::-1] 106 | 107 | img = img * (1 - opacity) + color_seg * opacity 108 | img = img.astype(np.uint8) 109 | # if out_file specified, do not show image in window 110 | if out_file is not None: 111 | show = False 112 | 113 | if show: 114 | mmcv.imshow(img, win_name, wait_time) 115 | if out_file is not None: 116 | mmcv.imwrite(img, out_file) 117 | 118 | if not (show or out_file): 119 | warnings.warn('show==False and out_file is not specified, only ' 120 | 'result image will be returned') 121 | return img 122 | 123 | 124 | def _retrieve_data_cfg(_data_cfg, skip_type, show_origin): 125 | if show_origin is True: 126 | # only keep pipeline of Loading data and ann 127 | _data_cfg['pipeline'] = [ 128 | x for x in _data_cfg.pipeline if 'Load' in x['type'] 129 | ] 130 | else: 131 | _data_cfg['pipeline'] = [ 132 | x for x in _data_cfg.pipeline if x['type'] not in skip_type 133 | ] 134 | 135 | 136 | def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False): 137 | cfg = Config.fromfile(config_path) 138 | if cfg_options is not None: 139 | cfg.merge_from_dict(cfg_options) 140 | train_data_cfg = cfg.data.train 141 | if isinstance(train_data_cfg, list): 142 | for _data_cfg in train_data_cfg: 143 | while 'dataset' in _data_cfg and _data_cfg[ 144 | 'type'] != 'MultiImageMixDataset': 145 | _data_cfg = _data_cfg['dataset'] 146 | if 'pipeline' in _data_cfg: 147 | _retrieve_data_cfg(_data_cfg, skip_type, show_origin) 148 | else: 149 | raise ValueError 150 | else: 151 | while 'dataset' in train_data_cfg and train_data_cfg[ 152 | 'type'] != 'MultiImageMixDataset': 153 | train_data_cfg = train_data_cfg['dataset'] 154 | _retrieve_data_cfg(train_data_cfg, skip_type, show_origin) 155 | return cfg 156 | 157 | 158 | def main(): 159 | args = parse_args() 160 | cfg = retrieve_data_cfg(args.config, args.skip_type, 
args.cfg_options, 161 | args.show_origin) 162 | dataset = build_dataset(cfg.data.train) 163 | progress_bar = mmcv.ProgressBar(len(dataset)) 164 | for item in dataset: 165 | filename = os.path.join(args.output_dir, 166 | Path(item['filename']).name 167 | ) if args.output_dir is not None else None 168 | imshow_semantic( 169 | item['img'], 170 | item['gt_semantic_seg'], 171 | dataset.CLASSES, 172 | dataset.PALETTE, 173 | show=args.show, 174 | wait_time=args.show_interval, 175 | out_file=filename, 176 | opacity=args.opacity, 177 | ) 178 | progress_bar.update() 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /tools/confusion_matrix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import numpy as np 8 | from matplotlib.ticker import MultipleLocator 9 | from mmcv import Config, DictAction 10 | 11 | from mmseg.datasets import build_dataset 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Generate confusion matrix from segmentation results') 17 | parser.add_argument('config', help='test config file path') 18 | parser.add_argument( 19 | 'prediction_path', help='prediction path where test .pkl result') 20 | parser.add_argument( 21 | 'save_dir', help='directory where confusion matrix will be saved') 22 | parser.add_argument( 23 | '--show', action='store_true', help='show confusion matrix') 24 | parser.add_argument( 25 | '--color-theme', 26 | default='winter', 27 | help='theme of the matrix color map') 28 | parser.add_argument( 29 | '--title', 30 | default='Normalized Confusion Matrix', 31 | help='title of the matrix color map') 32 | parser.add_argument( 33 | '--cfg-options', 34 | nargs='+', 35 | action=DictAction, 36 | help='override some settings in the used config, the key-value pair ' 37 | 'in xxx=yyy format will be merged into config file. If the value to ' 38 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 39 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 40 | 'Note that the quotation marks are necessary and that no white space ' 41 | 'is allowed.') 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def calculate_confusion_matrix(dataset, results): 47 | """Calculate the confusion matrix. 48 | 49 | Args: 50 | dataset (Dataset): Test or val dataset. 51 | results (list[ndarray]): A list of segmentation results in each image. 52 | """ 53 | n = len(dataset.CLASSES) 54 | confusion_matrix = np.zeros(shape=[n, n]) 55 | assert len(dataset) == len(results) 56 | prog_bar = mmcv.ProgressBar(len(results)) 57 | for idx, per_img_res in enumerate(results): 58 | res_segm = per_img_res 59 | gt_segm = dataset.get_gt_seg_map_by_idx(idx) 60 | inds = n * gt_segm + res_segm 61 | inds = inds.flatten() 62 | mat = np.bincount(inds, minlength=n**2).reshape(n, n) 63 | confusion_matrix += mat 64 | prog_bar.update() 65 | return confusion_matrix 66 | 67 | 68 | def plot_confusion_matrix(confusion_matrix, 69 | labels, 70 | save_dir=None, 71 | show=True, 72 | title='Normalized Confusion Matrix', 73 | color_theme='winter'): 74 | """Draw confusion matrix with matplotlib. 75 | 76 | Args: 77 | confusion_matrix (ndarray): The confusion matrix. 78 | labels (list[str]): List of class names. 
79 | save_dir (str|optional): If set, save the confusion matrix plot to the 80 | given path. Default: None. 81 | show (bool): Whether to show the plot. Default: True. 82 | title (str): Title of the plot. Default: `Normalized Confusion Matrix`. 83 | color_theme (str): Theme of the matrix color map. Default: `winter`. 84 | """ 85 | # normalize the confusion matrix 86 | per_label_sums = confusion_matrix.sum(axis=1)[:, np.newaxis] 87 | confusion_matrix = \ 88 | confusion_matrix.astype(np.float32) / per_label_sums * 100 89 | 90 | num_classes = len(labels) 91 | fig, ax = plt.subplots( 92 | figsize=(2 * num_classes, 2 * num_classes * 0.8), dpi=180) 93 | cmap = plt.get_cmap(color_theme) 94 | im = ax.imshow(confusion_matrix, cmap=cmap) 95 | plt.colorbar(mappable=im, ax=ax) 96 | 97 | title_font = {'weight': 'bold', 'size': 12} 98 | ax.set_title(title, fontdict=title_font) 99 | label_font = {'size': 10} 100 | plt.ylabel('Ground Truth Label', fontdict=label_font) 101 | plt.xlabel('Prediction Label', fontdict=label_font) 102 | 103 | # draw locator 104 | xmajor_locator = MultipleLocator(1) 105 | xminor_locator = MultipleLocator(0.5) 106 | ax.xaxis.set_major_locator(xmajor_locator) 107 | ax.xaxis.set_minor_locator(xminor_locator) 108 | ymajor_locator = MultipleLocator(1) 109 | yminor_locator = MultipleLocator(0.5) 110 | ax.yaxis.set_major_locator(ymajor_locator) 111 | ax.yaxis.set_minor_locator(yminor_locator) 112 | 113 | # draw grid 114 | ax.grid(True, which='minor', linestyle='-') 115 | 116 | # draw label 117 | ax.set_xticks(np.arange(num_classes)) 118 | ax.set_yticks(np.arange(num_classes)) 119 | ax.set_xticklabels(labels) 120 | ax.set_yticklabels(labels) 121 | 122 | ax.tick_params( 123 | axis='x', bottom=False, top=True, labelbottom=False, labeltop=True) 124 | plt.setp( 125 | ax.get_xticklabels(), rotation=45, ha='left', rotation_mode='anchor') 126 | 127 | # draw confusion matrix value 128 | for i in range(num_classes): 129 | for j in range(num_classes): 130 | ax.text( 131 | j, 132 | i, 133 | '{}%'.format( 134 | round(confusion_matrix[i, j], 2 135 | ) if not np.isnan(confusion_matrix[i, j]) else -1), 136 | ha='center', 137 | va='center', 138 | color='w', 139 | size=7) 140 | 141 | ax.set_ylim(len(confusion_matrix) - 0.5, -0.5) # matplotlib>3.1.1 142 | 143 | fig.tight_layout() 144 | if save_dir is not None: 145 | plt.savefig( 146 | os.path.join(save_dir, 'confusion_matrix.png'), format='png') 147 | if show: 148 | plt.show() 149 | 150 | 151 | def main(): 152 | args = parse_args() 153 | 154 | cfg = Config.fromfile(args.config) 155 | if args.cfg_options is not None: 156 | cfg.merge_from_dict(args.cfg_options) 157 | 158 | results = mmcv.load(args.prediction_path) 159 | 160 | assert isinstance(results, list) 161 | if isinstance(results[0], np.ndarray): 162 | pass 163 | else: 164 | raise TypeError('invalid type of prediction results') 165 | 166 | if isinstance(cfg.data.test, dict): 167 | cfg.data.test.test_mode = True 168 | elif isinstance(cfg.data.test, list): 169 | for ds_cfg in cfg.data.test: 170 | ds_cfg.test_mode = True 171 | 172 | dataset = build_dataset(cfg.data.test) 173 | confusion_matrix = calculate_confusion_matrix(dataset, results) 174 | plot_confusion_matrix( 175 | confusion_matrix, 176 | dataset.CLASSES, 177 | save_dir=args.save_dir, 178 | show=args.show, 179 | title=args.title, 180 | color_theme=args.color_theme) 181 | 182 | 183 | if __name__ == '__main__': 184 | main() 185 | -------------------------------------------------------------------------------- /tools/convert_datasets/chase_db1.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | CHASE_DB1_LEN = 28 * 3 11 | TRAINING_LEN = 60 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert CHASE_DB1 dataset to mmsegmentation format') 17 | parser.add_argument('dataset_path', help='path of CHASEDB1.zip') 18 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 19 | parser.add_argument('-o', '--out_dir', help='output path') 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(): 25 | args = parse_args() 26 | dataset_path = args.dataset_path 27 | if args.out_dir is None: 28 | out_dir = osp.join('data', 'CHASE_DB1') 29 | else: 30 | out_dir = args.out_dir 31 | 32 | print('Making directories...') 33 | mmcv.mkdir_or_exist(out_dir) 34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 40 | 41 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 42 | print('Extracting CHASEDB1.zip...') 43 | zip_file = zipfile.ZipFile(dataset_path) 44 | zip_file.extractall(tmp_dir) 45 | 46 | print('Generating training dataset...') 47 | 48 | assert len(os.listdir(tmp_dir)) == CHASE_DB1_LEN, \ 49 | 'len(os.listdir(tmp_dir)) != {}'.format(CHASE_DB1_LEN) 50 | 51 | for img_name in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 52 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 53 | if osp.splitext(img_name)[1] == '.jpg': 54 | mmcv.imwrite( 55 | img, 56 | osp.join(out_dir, 'images', 'training', 57 | osp.splitext(img_name)[0] + '.png')) 58 | else: 59 | # The annotation img should be divided by 128, because some of 60 | # the annotation imgs are not standard. We should set a 61 | # threshold to convert the nonstandard annotation imgs. The 62 | # value divided by 128 is equivalent to '1 if value >= 128 63 | # else 0' 64 | mmcv.imwrite( 65 | img[:, :, 0] // 128, 66 | osp.join(out_dir, 'annotations', 'training', 67 | osp.splitext(img_name)[0] + '.png')) 68 | 69 | for img_name in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 70 | img = mmcv.imread(osp.join(tmp_dir, img_name)) 71 | if osp.splitext(img_name)[1] == '.jpg': 72 | mmcv.imwrite( 73 | img, 74 | osp.join(out_dir, 'images', 'validation', 75 | osp.splitext(img_name)[0] + '.png')) 76 | else: 77 | mmcv.imwrite( 78 | img[:, :, 0] // 128, 79 | osp.join(out_dir, 'annotations', 'validation', 80 | osp.splitext(img_name)[0] + '.png')) 81 | 82 | print('Removing the temporary files...') 83 | 84 | print('Done!') 85 | 86 | 87 | if __name__ == '__main__': 88 | main() 89 | -------------------------------------------------------------------------------- /tools/convert_datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
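# Hedged aside (illustrative sketch, not part of the original converters): the
# CHASE_DB1 script above, like the DRIVE/HRF/STARE scripts later in this
# directory, binarizes manual annotations with `img[:, :, 0] // 128`. For uint8
# pixels that integer division is exactly `1 if value >= 128 else 0`:
def _demo_threshold_by_128():
    import numpy as np
    ann = np.array([0, 1, 127, 128, 200, 255], dtype=np.uint8)
    binary = ann // 128                      # integer division keeps dtype uint8
    assert (binary == (ann >= 128).astype(np.uint8)).all()
    return binary                            # -> [0, 0, 0, 1, 1, 1]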
2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | from cityscapesscripts.preparation.json2labelImg import json2labelImg 7 | 8 | 9 | def convert_json_to_label(json_file): 10 | label_file = json_file.replace('_polygons.json', '_labelTrainIds.png') 11 | json2labelImg(json_file, label_file, 'trainIds') 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert Cityscapes annotations to TrainIds') 17 | parser.add_argument('cityscapes_path', help='cityscapes data path') 18 | parser.add_argument('--gt-dir', default='gtFine', type=str) 19 | parser.add_argument('-o', '--out-dir', help='output path') 20 | parser.add_argument( 21 | '--nproc', default=1, type=int, help='number of process') 22 | args = parser.parse_args() 23 | return args 24 | 25 | 26 | def main(): 27 | args = parse_args() 28 | cityscapes_path = args.cityscapes_path 29 | out_dir = args.out_dir if args.out_dir else cityscapes_path 30 | mmcv.mkdir_or_exist(out_dir) 31 | 32 | gt_dir = osp.join(cityscapes_path, args.gt_dir) 33 | 34 | poly_files = [] 35 | for poly in mmcv.scandir(gt_dir, '_polygons.json', recursive=True): 36 | poly_file = osp.join(gt_dir, poly) 37 | poly_files.append(poly_file) 38 | if args.nproc > 1: 39 | mmcv.track_parallel_progress(convert_json_to_label, poly_files, 40 | args.nproc) 41 | else: 42 | mmcv.track_progress(convert_json_to_label, poly_files) 43 | 44 | split_names = ['train', 'val', 'test'] 45 | 46 | for split in split_names: 47 | filenames = [] 48 | for poly in mmcv.scandir( 49 | osp.join(gt_dir, split), '_polygons.json', recursive=True): 50 | filenames.append(poly.replace('_gtFine_polygons.json', '')) 51 | with open(osp.join(out_dir, f'{split}.txt'), 'w') as f: 52 | f.writelines(f + '\n' for f in filenames) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /tools/convert_datasets/coco_stuff10k.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import argparse 3 | import os.path as osp 4 | import shutil 5 | from functools import partial 6 | 7 | import mmcv 8 | import numpy as np 9 | from PIL import Image 10 | from scipy.io import loadmat 11 | 12 | COCO_LEN = 10000 13 | 14 | clsID_to_trID = { 15 | 0: 0, 16 | 1: 1, 17 | 2: 2, 18 | 3: 3, 19 | 4: 4, 20 | 5: 5, 21 | 6: 6, 22 | 7: 7, 23 | 8: 8, 24 | 9: 9, 25 | 10: 10, 26 | 11: 11, 27 | 13: 12, 28 | 14: 13, 29 | 15: 14, 30 | 16: 15, 31 | 17: 16, 32 | 18: 17, 33 | 19: 18, 34 | 20: 19, 35 | 21: 20, 36 | 22: 21, 37 | 23: 22, 38 | 24: 23, 39 | 25: 24, 40 | 27: 25, 41 | 28: 26, 42 | 31: 27, 43 | 32: 28, 44 | 33: 29, 45 | 34: 30, 46 | 35: 31, 47 | 36: 32, 48 | 37: 33, 49 | 38: 34, 50 | 39: 35, 51 | 40: 36, 52 | 41: 37, 53 | 42: 38, 54 | 43: 39, 55 | 44: 40, 56 | 46: 41, 57 | 47: 42, 58 | 48: 43, 59 | 49: 44, 60 | 50: 45, 61 | 51: 46, 62 | 52: 47, 63 | 53: 48, 64 | 54: 49, 65 | 55: 50, 66 | 56: 51, 67 | 57: 52, 68 | 58: 53, 69 | 59: 54, 70 | 60: 55, 71 | 61: 56, 72 | 62: 57, 73 | 63: 58, 74 | 64: 59, 75 | 65: 60, 76 | 67: 61, 77 | 70: 62, 78 | 72: 63, 79 | 73: 64, 80 | 74: 65, 81 | 75: 66, 82 | 76: 67, 83 | 77: 68, 84 | 78: 69, 85 | 79: 70, 86 | 80: 71, 87 | 81: 72, 88 | 82: 73, 89 | 84: 74, 90 | 85: 75, 91 | 86: 76, 92 | 87: 77, 93 | 88: 78, 94 | 89: 79, 95 | 90: 80, 96 | 92: 81, 97 | 93: 82, 98 | 94: 83, 99 | 95: 84, 100 | 96: 85, 101 | 97: 86, 102 | 98: 87, 103 | 99: 88, 104 | 100: 89, 105 | 101: 90, 106 | 102: 91, 107 | 103: 92, 108 | 104: 93, 109 | 105: 94, 110 | 106: 95, 111 | 107: 96, 112 | 108: 97, 113 | 109: 98, 114 | 110: 99, 115 | 111: 100, 116 | 112: 101, 117 | 113: 102, 118 | 114: 103, 119 | 115: 104, 120 | 116: 105, 121 | 117: 106, 122 | 118: 107, 123 | 119: 108, 124 | 120: 109, 125 | 121: 110, 126 | 122: 111, 127 | 123: 112, 128 | 124: 113, 129 | 125: 114, 130 | 126: 115, 131 | 127: 116, 132 | 128: 117, 133 | 129: 118, 134 | 130: 119, 135 | 131: 120, 136 | 132: 121, 137 | 133: 122, 138 | 134: 123, 139 | 135: 124, 140 | 136: 125, 141 | 137: 126, 142 | 138: 127, 143 | 139: 128, 144 | 140: 129, 145 | 141: 130, 146 | 142: 131, 147 | 143: 132, 148 | 144: 133, 149 | 145: 134, 150 | 146: 135, 151 | 147: 136, 152 | 148: 137, 153 | 149: 138, 154 | 150: 139, 155 | 151: 140, 156 | 152: 141, 157 | 153: 142, 158 | 154: 143, 159 | 155: 144, 160 | 156: 145, 161 | 157: 146, 162 | 158: 147, 163 | 159: 148, 164 | 160: 149, 165 | 161: 150, 166 | 162: 151, 167 | 163: 152, 168 | 164: 153, 169 | 165: 154, 170 | 166: 155, 171 | 167: 156, 172 | 168: 157, 173 | 169: 158, 174 | 170: 159, 175 | 171: 160, 176 | 172: 161, 177 | 173: 162, 178 | 174: 163, 179 | 175: 164, 180 | 176: 165, 181 | 177: 166, 182 | 178: 167, 183 | 179: 168, 184 | 180: 169, 185 | 181: 170, 186 | 182: 171 187 | } 188 | 189 | 190 | def convert_to_trainID(tuple_path, in_img_dir, in_ann_dir, out_img_dir, 191 | out_mask_dir, is_train): 192 | imgpath, maskpath = tuple_path 193 | shutil.copyfile( 194 | osp.join(in_img_dir, imgpath), 195 | osp.join(out_img_dir, 'train2014', imgpath) if is_train else osp.join( 196 | out_img_dir, 'test2014', imgpath)) 197 | annotate = loadmat(osp.join(in_ann_dir, maskpath)) 198 | mask = annotate['S'].astype(np.uint8) 199 | mask_copy = mask.copy() 200 | for clsID, trID in clsID_to_trID.items(): 201 | mask_copy[mask == clsID] = trID 202 | seg_filename = osp.join(out_mask_dir, 'train2014', 203 | maskpath.split('.')[0] + 204 | '_labelTrainIds.png') if is_train else osp.join( 205 | out_mask_dir, 'test2014', 206 | maskpath.split('.')[0] + '_labelTrainIds.png') 207 | Image.fromarray(mask_copy).save(seg_filename, 
'PNG') 208 | 209 | 210 | def generate_coco_list(folder): 211 | train_list = osp.join(folder, 'imageLists', 'train.txt') 212 | test_list = osp.join(folder, 'imageLists', 'test.txt') 213 | train_paths = [] 214 | test_paths = [] 215 | 216 | with open(train_list) as f: 217 | for filename in f: 218 | basename = filename.strip() 219 | imgpath = basename + '.jpg' 220 | maskpath = basename + '.mat' 221 | train_paths.append((imgpath, maskpath)) 222 | 223 | with open(test_list) as f: 224 | for filename in f: 225 | basename = filename.strip() 226 | imgpath = basename + '.jpg' 227 | maskpath = basename + '.mat' 228 | test_paths.append((imgpath, maskpath)) 229 | 230 | return train_paths, test_paths 231 | 232 | 233 | def parse_args(): 234 | parser = argparse.ArgumentParser( 235 | description=\ 236 | 'Convert COCO Stuff 10k annotations to mmsegmentation format') # noqa 237 | parser.add_argument('coco_path', help='coco stuff path') 238 | parser.add_argument('-o', '--out_dir', help='output path') 239 | parser.add_argument( 240 | '--nproc', default=16, type=int, help='number of process') 241 | args = parser.parse_args() 242 | return args 243 | 244 | 245 | def main(): 246 | args = parse_args() 247 | coco_path = args.coco_path 248 | nproc = args.nproc 249 | 250 | out_dir = args.out_dir or coco_path 251 | out_img_dir = osp.join(out_dir, 'images') 252 | out_mask_dir = osp.join(out_dir, 'annotations') 253 | 254 | mmcv.mkdir_or_exist(osp.join(out_img_dir, 'train2014')) 255 | mmcv.mkdir_or_exist(osp.join(out_img_dir, 'test2014')) 256 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2014')) 257 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'test2014')) 258 | 259 | train_list, test_list = generate_coco_list(coco_path) 260 | assert (len(train_list) + 261 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 262 | len(train_list), len(test_list)) 263 | 264 | if args.nproc > 1: 265 | mmcv.track_parallel_progress( 266 | partial( 267 | convert_to_trainID, 268 | in_img_dir=osp.join(coco_path, 'images'), 269 | in_ann_dir=osp.join(coco_path, 'annotations'), 270 | out_img_dir=out_img_dir, 271 | out_mask_dir=out_mask_dir, 272 | is_train=True), 273 | train_list, 274 | nproc=nproc) 275 | mmcv.track_parallel_progress( 276 | partial( 277 | convert_to_trainID, 278 | in_img_dir=osp.join(coco_path, 'images'), 279 | in_ann_dir=osp.join(coco_path, 'annotations'), 280 | out_img_dir=out_img_dir, 281 | out_mask_dir=out_mask_dir, 282 | is_train=False), 283 | test_list, 284 | nproc=nproc) 285 | else: 286 | mmcv.track_progress( 287 | partial( 288 | convert_to_trainID, 289 | in_img_dir=osp.join(coco_path, 'images'), 290 | in_ann_dir=osp.join(coco_path, 'annotations'), 291 | out_img_dir=out_img_dir, 292 | out_mask_dir=out_mask_dir, 293 | is_train=True), train_list) 294 | mmcv.track_progress( 295 | partial( 296 | convert_to_trainID, 297 | in_img_dir=osp.join(coco_path, 'images'), 298 | in_ann_dir=osp.join(coco_path, 'annotations'), 299 | out_img_dir=out_img_dir, 300 | out_mask_dir=out_mask_dir, 301 | is_train=False), test_list) 302 | 303 | print('Done!') 304 | 305 | 306 | if __name__ == '__main__': 307 | main() 308 | -------------------------------------------------------------------------------- /tools/convert_datasets/coco_stuff164k.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
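# Hedged design note (alternative sketch, not how convert_to_trainID in these
# two COCO-Stuff converters actually works): every raw id fits in uint8, so the
# per-class `mask_copy[mask == clsID] = trID` loop could equivalently use one
# 256-entry lookup table, trading ~170 boolean-mask passes for a single
# fancy-indexing step.
def _demo_lut_remap(mask, cls_to_train, ignore_id=255):
    import numpy as np
    lut = np.full(256, ignore_id, dtype=np.uint8)   # unmapped raw ids -> ignore
    for cls_id, train_id in cls_to_train.items():
        lut[cls_id] = train_id
    return lut[mask]                                # vectorized per-pixel remap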
2 | import argparse 3 | import os.path as osp 4 | import shutil 5 | from functools import partial 6 | from glob import glob 7 | 8 | import mmcv 9 | import numpy as np 10 | from PIL import Image 11 | 12 | COCO_LEN = 123287 13 | 14 | clsID_to_trID = { 15 | 0: 0, 16 | 1: 1, 17 | 2: 2, 18 | 3: 3, 19 | 4: 4, 20 | 5: 5, 21 | 6: 6, 22 | 7: 7, 23 | 8: 8, 24 | 9: 9, 25 | 10: 10, 26 | 12: 11, 27 | 13: 12, 28 | 14: 13, 29 | 15: 14, 30 | 16: 15, 31 | 17: 16, 32 | 18: 17, 33 | 19: 18, 34 | 20: 19, 35 | 21: 20, 36 | 22: 21, 37 | 23: 22, 38 | 24: 23, 39 | 26: 24, 40 | 27: 25, 41 | 30: 26, 42 | 31: 27, 43 | 32: 28, 44 | 33: 29, 45 | 34: 30, 46 | 35: 31, 47 | 36: 32, 48 | 37: 33, 49 | 38: 34, 50 | 39: 35, 51 | 40: 36, 52 | 41: 37, 53 | 42: 38, 54 | 43: 39, 55 | 45: 40, 56 | 46: 41, 57 | 47: 42, 58 | 48: 43, 59 | 49: 44, 60 | 50: 45, 61 | 51: 46, 62 | 52: 47, 63 | 53: 48, 64 | 54: 49, 65 | 55: 50, 66 | 56: 51, 67 | 57: 52, 68 | 58: 53, 69 | 59: 54, 70 | 60: 55, 71 | 61: 56, 72 | 62: 57, 73 | 63: 58, 74 | 64: 59, 75 | 66: 60, 76 | 69: 61, 77 | 71: 62, 78 | 72: 63, 79 | 73: 64, 80 | 74: 65, 81 | 75: 66, 82 | 76: 67, 83 | 77: 68, 84 | 78: 69, 85 | 79: 70, 86 | 80: 71, 87 | 81: 72, 88 | 83: 73, 89 | 84: 74, 90 | 85: 75, 91 | 86: 76, 92 | 87: 77, 93 | 88: 78, 94 | 89: 79, 95 | 91: 80, 96 | 92: 81, 97 | 93: 82, 98 | 94: 83, 99 | 95: 84, 100 | 96: 85, 101 | 97: 86, 102 | 98: 87, 103 | 99: 88, 104 | 100: 89, 105 | 101: 90, 106 | 102: 91, 107 | 103: 92, 108 | 104: 93, 109 | 105: 94, 110 | 106: 95, 111 | 107: 96, 112 | 108: 97, 113 | 109: 98, 114 | 110: 99, 115 | 111: 100, 116 | 112: 101, 117 | 113: 102, 118 | 114: 103, 119 | 115: 104, 120 | 116: 105, 121 | 117: 106, 122 | 118: 107, 123 | 119: 108, 124 | 120: 109, 125 | 121: 110, 126 | 122: 111, 127 | 123: 112, 128 | 124: 113, 129 | 125: 114, 130 | 126: 115, 131 | 127: 116, 132 | 128: 117, 133 | 129: 118, 134 | 130: 119, 135 | 131: 120, 136 | 132: 121, 137 | 133: 122, 138 | 134: 123, 139 | 135: 124, 140 | 136: 125, 141 | 137: 126, 142 | 138: 127, 143 | 139: 128, 144 | 140: 129, 145 | 141: 130, 146 | 142: 131, 147 | 143: 132, 148 | 144: 133, 149 | 145: 134, 150 | 146: 135, 151 | 147: 136, 152 | 148: 137, 153 | 149: 138, 154 | 150: 139, 155 | 151: 140, 156 | 152: 141, 157 | 153: 142, 158 | 154: 143, 159 | 155: 144, 160 | 156: 145, 161 | 157: 146, 162 | 158: 147, 163 | 159: 148, 164 | 160: 149, 165 | 161: 150, 166 | 162: 151, 167 | 163: 152, 168 | 164: 153, 169 | 165: 154, 170 | 166: 155, 171 | 167: 156, 172 | 168: 157, 173 | 169: 158, 174 | 170: 159, 175 | 171: 160, 176 | 172: 161, 177 | 173: 162, 178 | 174: 163, 179 | 175: 164, 180 | 176: 165, 181 | 177: 166, 182 | 178: 167, 183 | 179: 168, 184 | 180: 169, 185 | 181: 170, 186 | 255: 255 187 | } 188 | 189 | 190 | def convert_to_trainID(maskpath, out_mask_dir, is_train): 191 | mask = np.array(Image.open(maskpath)) 192 | mask_copy = mask.copy() 193 | for clsID, trID in clsID_to_trID.items(): 194 | mask_copy[mask == clsID] = trID 195 | seg_filename = osp.join( 196 | out_mask_dir, 'train2017', 197 | osp.basename(maskpath).split('.')[0] + 198 | '_labelTrainIds.png') if is_train else osp.join( 199 | out_mask_dir, 'val2017', 200 | osp.basename(maskpath).split('.')[0] + '_labelTrainIds.png') 201 | Image.fromarray(mask_copy).save(seg_filename, 'PNG') 202 | 203 | 204 | def parse_args(): 205 | parser = argparse.ArgumentParser( 206 | description=\ 207 | 'Convert COCO Stuff 164k annotations to mmsegmentation format') # noqa 208 | parser.add_argument('coco_path', help='coco stuff path') 209 | parser.add_argument('-o', '--out_dir', 
help='output path') 210 | parser.add_argument( 211 | '--nproc', default=16, type=int, help='number of process') 212 | args = parser.parse_args() 213 | return args 214 | 215 | 216 | def main(): 217 | args = parse_args() 218 | coco_path = args.coco_path 219 | nproc = args.nproc 220 | 221 | out_dir = args.out_dir or coco_path 222 | out_img_dir = osp.join(out_dir, 'images') 223 | out_mask_dir = osp.join(out_dir, 'annotations') 224 | 225 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'train2017')) 226 | mmcv.mkdir_or_exist(osp.join(out_mask_dir, 'val2017')) 227 | 228 | if out_dir != coco_path: 229 | shutil.copytree(osp.join(coco_path, 'images'), out_img_dir) 230 | 231 | train_list = glob(osp.join(coco_path, 'annotations', 'train2017', '*.png')) 232 | train_list = [file for file in train_list if '_labelTrainIds' not in file] 233 | test_list = glob(osp.join(coco_path, 'annotations', 'val2017', '*.png')) 234 | test_list = [file for file in test_list if '_labelTrainIds' not in file] 235 | assert (len(train_list) + 236 | len(test_list)) == COCO_LEN, 'Wrong length of list {} & {}'.format( 237 | len(train_list), len(test_list)) 238 | 239 | if args.nproc > 1: 240 | mmcv.track_parallel_progress( 241 | partial( 242 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 243 | train_list, 244 | nproc=nproc) 245 | mmcv.track_parallel_progress( 246 | partial( 247 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 248 | test_list, 249 | nproc=nproc) 250 | else: 251 | mmcv.track_progress( 252 | partial( 253 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=True), 254 | train_list) 255 | mmcv.track_progress( 256 | partial( 257 | convert_to_trainID, out_mask_dir=out_mask_dir, is_train=False), 258 | test_list) 259 | 260 | print('Done!') 261 | 262 | 263 | if __name__ == '__main__': 264 | main() 265 | -------------------------------------------------------------------------------- /tools/convert_datasets/drive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
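# Hedged note on the DRIVE converter below (an assumption, not stated in this
# repo): the 1st_manual/2nd_manual vessel masks are distributed as .gif files,
# which mmcv.imread (cv2-backed) does not decode, hence the cv2.VideoCapture
# calls that grab the first frame. Minimal sketch of that pattern;
# 'vessel_mask.gif' is a hypothetical path.
def _demo_first_frame(path='vessel_mask.gif'):
    import cv2
    cap = cv2.VideoCapture(path)
    ret, frame = cap.read()      # BGR ndarray of shape (H, W, 3) when ret is True
    cap.release()
    return frame if ret else None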
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import cv2 9 | import mmcv 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Convert DRIVE dataset to mmsegmentation format') 15 | parser.add_argument( 16 | 'training_path', help='the training part of DRIVE dataset') 17 | parser.add_argument( 18 | 'testing_path', help='the testing part of DRIVE dataset') 19 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 20 | parser.add_argument('-o', '--out_dir', help='output path') 21 | args = parser.parse_args() 22 | return args 23 | 24 | 25 | def main(): 26 | args = parse_args() 27 | training_path = args.training_path 28 | testing_path = args.testing_path 29 | if args.out_dir is None: 30 | out_dir = osp.join('data', 'DRIVE') 31 | else: 32 | out_dir = args.out_dir 33 | 34 | print('Making directories...') 35 | mmcv.mkdir_or_exist(out_dir) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 39 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 40 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 41 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 42 | 43 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 44 | print('Extracting training.zip...') 45 | zip_file = zipfile.ZipFile(training_path) 46 | zip_file.extractall(tmp_dir) 47 | 48 | print('Generating training dataset...') 49 | now_dir = osp.join(tmp_dir, 'training', 'images') 50 | for img_name in os.listdir(now_dir): 51 | img = mmcv.imread(osp.join(now_dir, img_name)) 52 | mmcv.imwrite( 53 | img, 54 | osp.join( 55 | out_dir, 'images', 'training', 56 | osp.splitext(img_name)[0].replace('_training', '') + 57 | '.png')) 58 | 59 | now_dir = osp.join(tmp_dir, 'training', '1st_manual') 60 | for img_name in os.listdir(now_dir): 61 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 62 | ret, img = cap.read() 63 | mmcv.imwrite( 64 | img[:, :, 0] // 128, 65 | osp.join(out_dir, 'annotations', 'training', 66 | osp.splitext(img_name)[0] + '.png')) 67 | 68 | print('Extracting test.zip...') 69 | zip_file = zipfile.ZipFile(testing_path) 70 | zip_file.extractall(tmp_dir) 71 | 72 | print('Generating validation dataset...') 73 | now_dir = osp.join(tmp_dir, 'test', 'images') 74 | for img_name in os.listdir(now_dir): 75 | img = mmcv.imread(osp.join(now_dir, img_name)) 76 | mmcv.imwrite( 77 | img, 78 | osp.join( 79 | out_dir, 'images', 'validation', 80 | osp.splitext(img_name)[0].replace('_test', '') + '.png')) 81 | 82 | now_dir = osp.join(tmp_dir, 'test', '1st_manual') 83 | if osp.exists(now_dir): 84 | for img_name in os.listdir(now_dir): 85 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 86 | ret, img = cap.read() 87 | # The annotation img should be divided by 128, because some of 88 | # the annotation imgs are not standard. We should set a 89 | # threshold to convert the nonstandard annotation imgs. 
The 90 | # value divided by 128 is equivalent to '1 if value >= 128 91 | # else 0' 92 | mmcv.imwrite( 93 | img[:, :, 0] // 128, 94 | osp.join(out_dir, 'annotations', 'validation', 95 | osp.splitext(img_name)[0] + '.png')) 96 | 97 | now_dir = osp.join(tmp_dir, 'test', '2nd_manual') 98 | if osp.exists(now_dir): 99 | for img_name in os.listdir(now_dir): 100 | cap = cv2.VideoCapture(osp.join(now_dir, img_name)) 101 | ret, img = cap.read() 102 | mmcv.imwrite( 103 | img[:, :, 0] // 128, 104 | osp.join(out_dir, 'annotations', 'validation', 105 | osp.splitext(img_name)[0] + '.png')) 106 | 107 | print('Removing the temporary files...') 108 | 109 | print('Done!') 110 | 111 | 112 | if __name__ == '__main__': 113 | main() 114 | -------------------------------------------------------------------------------- /tools/convert_datasets/hrf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import tempfile 6 | import zipfile 7 | 8 | import mmcv 9 | 10 | HRF_LEN = 15 11 | TRAINING_LEN = 5 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert HRF dataset to mmsegmentation format') 17 | parser.add_argument('healthy_path', help='the path of healthy.zip') 18 | parser.add_argument( 19 | 'healthy_manualsegm_path', help='the path of healthy_manualsegm.zip') 20 | parser.add_argument('glaucoma_path', help='the path of glaucoma.zip') 21 | parser.add_argument( 22 | 'glaucoma_manualsegm_path', help='the path of glaucoma_manualsegm.zip') 23 | parser.add_argument( 24 | 'diabetic_retinopathy_path', 25 | help='the path of diabetic_retinopathy.zip') 26 | parser.add_argument( 27 | 'diabetic_retinopathy_manualsegm_path', 28 | help='the path of diabetic_retinopathy_manualsegm.zip') 29 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | def main(): 36 | args = parse_args() 37 | images_path = [ 38 | args.healthy_path, args.glaucoma_path, args.diabetic_retinopathy_path 39 | ] 40 | annotations_path = [ 41 | args.healthy_manualsegm_path, args.glaucoma_manualsegm_path, 42 | args.diabetic_retinopathy_manualsegm_path 43 | ] 44 | if args.out_dir is None: 45 | out_dir = osp.join('data', 'HRF') 46 | else: 47 | out_dir = args.out_dir 48 | 49 | print('Making directories...') 50 | mmcv.mkdir_or_exist(out_dir) 51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 52 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 53 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 54 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 55 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 56 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 57 | 58 | print('Generating images...') 59 | for now_path in images_path: 60 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 61 | zip_file = zipfile.ZipFile(now_path) 62 | zip_file.extractall(tmp_dir) 63 | 64 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \ 65 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) 66 | 67 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 68 | img = mmcv.imread(osp.join(tmp_dir, filename)) 69 | mmcv.imwrite( 70 | img, 71 | osp.join(out_dir, 'images', 'training', 72 | osp.splitext(filename)[0] + '.png')) 73 | for filename in 
sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 74 | img = mmcv.imread(osp.join(tmp_dir, filename)) 75 | mmcv.imwrite( 76 | img, 77 | osp.join(out_dir, 'images', 'validation', 78 | osp.splitext(filename)[0] + '.png')) 79 | 80 | print('Generating annotations...') 81 | for now_path in annotations_path: 82 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 83 | zip_file = zipfile.ZipFile(now_path) 84 | zip_file.extractall(tmp_dir) 85 | 86 | assert len(os.listdir(tmp_dir)) == HRF_LEN, \ 87 | 'len(os.listdir(tmp_dir)) != {}'.format(HRF_LEN) 88 | 89 | for filename in sorted(os.listdir(tmp_dir))[:TRAINING_LEN]: 90 | img = mmcv.imread(osp.join(tmp_dir, filename)) 91 | # The annotation img should be divided by 128, because some of 92 | # the annotation imgs are not standard. We should set a 93 | # threshold to convert the nonstandard annotation imgs. The 94 | # value divided by 128 is equivalent to '1 if value >= 128 95 | # else 0' 96 | mmcv.imwrite( 97 | img[:, :, 0] // 128, 98 | osp.join(out_dir, 'annotations', 'training', 99 | osp.splitext(filename)[0] + '.png')) 100 | for filename in sorted(os.listdir(tmp_dir))[TRAINING_LEN:]: 101 | img = mmcv.imread(osp.join(tmp_dir, filename)) 102 | mmcv.imwrite( 103 | img[:, :, 0] // 128, 104 | osp.join(out_dir, 'annotations', 'validation', 105 | osp.splitext(filename)[0] + '.png')) 106 | 107 | print('Done!') 108 | 109 | 110 | if __name__ == '__main__': 111 | main() 112 | -------------------------------------------------------------------------------- /tools/convert_datasets/isaid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import glob 4 | import os 5 | import os.path as osp 6 | import shutil 7 | import tempfile 8 | import zipfile 9 | 10 | import mmcv 11 | import numpy as np 12 | from PIL import Image 13 | 14 | iSAID_palette = \ 15 | { 16 | 0: (0, 0, 0), 17 | 1: (0, 0, 63), 18 | 2: (0, 63, 63), 19 | 3: (0, 63, 0), 20 | 4: (0, 63, 127), 21 | 5: (0, 63, 191), 22 | 6: (0, 63, 255), 23 | 7: (0, 127, 63), 24 | 8: (0, 127, 127), 25 | 9: (0, 0, 127), 26 | 10: (0, 0, 191), 27 | 11: (0, 0, 255), 28 | 12: (0, 191, 127), 29 | 13: (0, 127, 191), 30 | 14: (0, 127, 255), 31 | 15: (0, 100, 155) 32 | } 33 | 34 | iSAID_invert_palette = {v: k for k, v in iSAID_palette.items()} 35 | 36 | 37 | def iSAID_convert_from_color(arr_3d, palette=iSAID_invert_palette): 38 | """RGB-color encoding to grayscale labels.""" 39 | arr_2d = np.zeros((arr_3d.shape[0], arr_3d.shape[1]), dtype=np.uint8) 40 | 41 | for c, i in palette.items(): 42 | m = np.all(arr_3d == np.array(c).reshape(1, 1, 3), axis=2) 43 | arr_2d[m] = i 44 | 45 | return arr_2d 46 | 47 | 48 | def slide_crop_image(src_path, out_dir, mode, patch_H, patch_W, overlap): 49 | img = np.asarray(Image.open(src_path).convert('RGB')) 50 | 51 | img_H, img_W, _ = img.shape 52 | 53 | if img_H < patch_H and img_W > patch_W: 54 | 55 | img = mmcv.impad(img, shape=(patch_H, img_W), pad_val=0) 56 | 57 | img_H, img_W, _ = img.shape 58 | 59 | elif img_H > patch_H and img_W < patch_W: 60 | 61 | img = mmcv.impad(img, shape=(img_H, patch_W), pad_val=0) 62 | 63 | img_H, img_W, _ = img.shape 64 | 65 | elif img_H < patch_H and img_W < patch_W: 66 | 67 | img = mmcv.impad(img, shape=(patch_H, patch_W), pad_val=0) 68 | 69 | img_H, img_W, _ = img.shape 70 | 71 | for x in range(0, img_W, patch_W - overlap): 72 | for y in range(0, img_H, patch_H - overlap): 73 | x_str = x 74 | x_end = x + patch_W 75 | if x_end > img_W: 76 | diff_x = 
x_end - img_W 77 | x_str -= diff_x 78 | x_end = img_W 79 | y_str = y 80 | y_end = y + patch_H 81 | if y_end > img_H: 82 | diff_y = y_end - img_H 83 | y_str -= diff_y 84 | y_end = img_H 85 | 86 | img_patch = img[y_str:y_end, x_str:x_end, :] 87 | img_patch = Image.fromarray(img_patch.astype(np.uint8)) 88 | image = osp.basename(src_path).split('.')[0] + '_' + str( 89 | y_str) + '_' + str(y_end) + '_' + str(x_str) + '_' + str( 90 | x_end) + '.png' 91 | # print(image) 92 | save_path_image = osp.join(out_dir, 'img_dir', mode, str(image)) 93 | img_patch.save(save_path_image) 94 | 95 | 96 | def slide_crop_label(src_path, out_dir, mode, patch_H, patch_W, overlap): 97 | label = mmcv.imread(src_path, channel_order='rgb') 98 | label = iSAID_convert_from_color(label) 99 | img_H, img_W = label.shape 100 | 101 | if img_H < patch_H and img_W > patch_W: 102 | 103 | label = mmcv.impad(label, shape=(patch_H, img_W), pad_val=255) 104 | 105 | img_H = patch_H 106 | 107 | elif img_H > patch_H and img_W < patch_W: 108 | 109 | label = mmcv.impad(label, shape=(img_H, patch_W), pad_val=255) 110 | 111 | img_W = patch_W 112 | 113 | elif img_H < patch_H and img_W < patch_W: 114 | 115 | label = mmcv.impad(label, shape=(patch_H, patch_W), pad_val=255) 116 | 117 | img_H = patch_H 118 | img_W = patch_W 119 | 120 | for x in range(0, img_W, patch_W - overlap): 121 | for y in range(0, img_H, patch_H - overlap): 122 | x_str = x 123 | x_end = x + patch_W 124 | if x_end > img_W: 125 | diff_x = x_end - img_W 126 | x_str -= diff_x 127 | x_end = img_W 128 | y_str = y 129 | y_end = y + patch_H 130 | if y_end > img_H: 131 | diff_y = y_end - img_H 132 | y_str -= diff_y 133 | y_end = img_H 134 | 135 | lab_patch = label[y_str:y_end, x_str:x_end] 136 | lab_patch = Image.fromarray(lab_patch.astype(np.uint8), mode='P') 137 | 138 | image = osp.basename(src_path).split('.')[0].split( 139 | '_')[0] + '_' + str(y_str) + '_' + str(y_end) + '_' + str( 140 | x_str) + '_' + str(x_end) + '_instance_color_RGB' + '.png' 141 | lab_patch.save(osp.join(out_dir, 'ann_dir', mode, str(image))) 142 | 143 | 144 | def parse_args(): 145 | parser = argparse.ArgumentParser( 146 | description='Convert iSAID dataset to mmsegmentation format') 147 | parser.add_argument('dataset_path', help='iSAID folder path') 148 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 149 | parser.add_argument('-o', '--out_dir', help='output path') 150 | 151 | parser.add_argument( 152 | '--patch_width', 153 | default=896, 154 | type=int, 155 | help='Width of the cropped image patch') 156 | parser.add_argument( 157 | '--patch_height', 158 | default=896, 159 | type=int, 160 | help='Height of the cropped image patch') 161 | parser.add_argument( 162 | '--overlap_area', default=384, type=int, help='Overlap area') 163 | args = parser.parse_args() 164 | return args 165 | 166 | 167 | def main(): 168 | args = parse_args() 169 | dataset_path = args.dataset_path 170 | # image patch width and height 171 | patch_H, patch_W = args.patch_width, args.patch_height 172 | 173 | overlap = args.overlap_area # overlap area 174 | 175 | if args.out_dir is None: 176 | out_dir = osp.join('data', 'iSAID') 177 | else: 178 | out_dir = args.out_dir 179 | 180 | print('Making directories...') 181 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) 182 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) 183 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) 184 | 185 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) 186 | mmcv.mkdir_or_exist(osp.join(out_dir, 
'ann_dir', 'val')) 187 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'test')) 188 | 189 | assert os.path.exists(os.path.join(dataset_path, 'train')), \ 190 | 'train is not in {}'.format(dataset_path) 191 | assert os.path.exists(os.path.join(dataset_path, 'val')), \ 192 | 'val is not in {}'.format(dataset_path) 193 | assert os.path.exists(os.path.join(dataset_path, 'test')), \ 194 | 'test is not in {}'.format(dataset_path) 195 | 196 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 197 | for dataset_mode in ['train', 'val', 'test']: 198 | 199 | # for dataset_mode in [ 'test']: 200 | print('Extracting {}ing.zip...'.format(dataset_mode)) 201 | img_zipp_list = glob.glob( 202 | os.path.join(dataset_path, dataset_mode, 'images', '*.zip')) 203 | print('Find the data', img_zipp_list) 204 | for img_zipp in img_zipp_list: 205 | zip_file = zipfile.ZipFile(img_zipp) 206 | zip_file.extractall(os.path.join(tmp_dir, dataset_mode, 'img')) 207 | src_path_list = glob.glob( 208 | os.path.join(tmp_dir, dataset_mode, 'img', 'images', '*.png')) 209 | 210 | src_prog_bar = mmcv.ProgressBar(len(src_path_list)) 211 | for i, img_path in enumerate(src_path_list): 212 | if dataset_mode != 'test': 213 | slide_crop_image(img_path, out_dir, dataset_mode, patch_H, 214 | patch_W, overlap) 215 | 216 | else: 217 | shutil.move(img_path, 218 | os.path.join(out_dir, 'img_dir', dataset_mode)) 219 | src_prog_bar.update() 220 | 221 | if dataset_mode != 'test': 222 | label_zipp_list = glob.glob( 223 | os.path.join(dataset_path, dataset_mode, 'Semantic_masks', 224 | '*.zip')) 225 | for label_zipp in label_zipp_list: 226 | zip_file = zipfile.ZipFile(label_zipp) 227 | zip_file.extractall( 228 | os.path.join(tmp_dir, dataset_mode, 'lab')) 229 | 230 | lab_path_list = glob.glob( 231 | os.path.join(tmp_dir, dataset_mode, 'lab', 'images', 232 | '*.png')) 233 | lab_prog_bar = mmcv.ProgressBar(len(lab_path_list)) 234 | for i, lab_path in enumerate(lab_path_list): 235 | slide_crop_label(lab_path, out_dir, dataset_mode, patch_H, 236 | patch_W, overlap) 237 | lab_prog_bar.update() 238 | 239 | print('Removing the temporary files...') 240 | 241 | print('Done!') 242 | 243 | 244 | if __name__ == '__main__': 245 | main() 246 | -------------------------------------------------------------------------------- /tools/convert_datasets/loveda.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
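# Hedged sketch of the border handling in the iSAID slide_crop helpers above
# (toy image width; 896 and 384 are the script's default patch size and
# overlap): when a window would run past the image edge, its start is shifted
# back so every saved patch keeps the full patch size instead of being shrunk.
def _demo_clamped_windows(img_w=1000, patch_w=896, overlap=384):
    windows = []
    for x in range(0, img_w, patch_w - overlap):
        x_end = min(x + patch_w, img_w)
        x_str = x_end - patch_w              # shift start back, keep full width
        windows.append((x_str, x_end))
    return windows                           # -> [(0, 896), (104, 1000)]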
2 | import argparse 3 | import os 4 | import os.path as osp 5 | import shutil 6 | import tempfile 7 | import zipfile 8 | 9 | import mmcv 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Convert LoveDA dataset to mmsegmentation format') 15 | parser.add_argument('dataset_path', help='LoveDA folder path') 16 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 17 | parser.add_argument('-o', '--out_dir', help='output path') 18 | args = parser.parse_args() 19 | return args 20 | 21 | 22 | def main(): 23 | args = parse_args() 24 | dataset_path = args.dataset_path 25 | if args.out_dir is None: 26 | out_dir = osp.join('data', 'loveDA') 27 | else: 28 | out_dir = args.out_dir 29 | 30 | print('Making directories...') 31 | mmcv.mkdir_or_exist(out_dir) 32 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir')) 33 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) 34 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) 35 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'test')) 36 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir')) 37 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) 38 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) 39 | 40 | assert 'Train.zip' in os.listdir(dataset_path), \ 41 | 'Train.zip is not in {}'.format(dataset_path) 42 | assert 'Val.zip' in os.listdir(dataset_path), \ 43 | 'Val.zip is not in {}'.format(dataset_path) 44 | assert 'Test.zip' in os.listdir(dataset_path), \ 45 | 'Test.zip is not in {}'.format(dataset_path) 46 | 47 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 48 | for dataset in ['Train', 'Val', 'Test']: 49 | zip_file = zipfile.ZipFile( 50 | os.path.join(dataset_path, dataset + '.zip')) 51 | zip_file.extractall(tmp_dir) 52 | data_type = dataset.lower() 53 | for location in ['Rural', 'Urban']: 54 | for image_type in ['images_png', 'masks_png']: 55 | if image_type == 'images_png': 56 | dst = osp.join(out_dir, 'img_dir', data_type) 57 | else: 58 | dst = osp.join(out_dir, 'ann_dir', data_type) 59 | if dataset == 'Test' and image_type == 'masks_png': 60 | continue 61 | else: 62 | src_dir = osp.join(tmp_dir, dataset, location, 63 | image_type) 64 | src_lst = os.listdir(src_dir) 65 | for file in src_lst: 66 | shutil.move(osp.join(src_dir, file), dst) 67 | print('Removing the temporary files...') 68 | 69 | print('Done!') 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | -------------------------------------------------------------------------------- /tools/convert_datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
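# Hedged sketch (toy ids, not the real mapping table) of the lookup performed
# by _class_to_index in the PASCAL-Context converter below: each raw Detail id
# is located in the sorted `_mapping` array with np.digitize and replaced by
# its position, giving contiguous train ids.
def _demo_digitize_lookup():
    import numpy as np
    mapping = np.sort(np.array([0, 2, 9, 18]))       # toy raw class ids
    key = np.arange(len(mapping)).astype('uint8')    # contiguous train ids 0..3
    mask = np.array([[2, 18], [0, 9]])
    index = np.digitize(mask.ravel(), mapping, right=True)
    return key[index].reshape(mask.shape)            # -> [[1, 3], [0, 2]]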
2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from detail import Detail 9 | from PIL import Image 10 | 11 | _mapping = np.sort( 12 | np.array([ 13 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22, 23, 397, 25, 284, 14 | 158, 159, 416, 33, 162, 420, 454, 295, 296, 427, 44, 45, 46, 308, 59, 15 | 440, 445, 31, 232, 65, 354, 424, 68, 326, 72, 458, 34, 207, 80, 355, 16 | 85, 347, 220, 349, 360, 98, 187, 104, 105, 366, 189, 368, 113, 115 17 | ])) 18 | _key = np.array(range(len(_mapping))).astype('uint8') 19 | 20 | 21 | def generate_labels(img_id, detail, out_dir): 22 | 23 | def _class_to_index(mask, _mapping, _key): 24 | # assert the values 25 | values = np.unique(mask) 26 | for i in range(len(values)): 27 | assert (values[i] in _mapping) 28 | index = np.digitize(mask.ravel(), _mapping, right=True) 29 | return _key[index].reshape(mask.shape) 30 | 31 | mask = Image.fromarray( 32 | _class_to_index(detail.getMask(img_id), _mapping=_mapping, _key=_key)) 33 | filename = img_id['file_name'] 34 | mask.save(osp.join(out_dir, filename.replace('jpg', 'png'))) 35 | return osp.splitext(osp.basename(filename))[0] 36 | 37 | 38 | def parse_args(): 39 | parser = argparse.ArgumentParser( 40 | description='Convert PASCAL VOC annotations to mmsegmentation format') 41 | parser.add_argument('devkit_path', help='pascal voc devkit path') 42 | parser.add_argument('json_path', help='annoation json filepath') 43 | parser.add_argument('-o', '--out_dir', help='output path') 44 | args = parser.parse_args() 45 | return args 46 | 47 | 48 | def main(): 49 | args = parse_args() 50 | devkit_path = args.devkit_path 51 | if args.out_dir is None: 52 | out_dir = osp.join(devkit_path, 'VOC2010', 'SegmentationClassContext') 53 | else: 54 | out_dir = args.out_dir 55 | json_path = args.json_path 56 | mmcv.mkdir_or_exist(out_dir) 57 | img_dir = osp.join(devkit_path, 'VOC2010', 'JPEGImages') 58 | 59 | train_detail = Detail(json_path, img_dir, 'train') 60 | train_ids = train_detail.getImgs() 61 | 62 | val_detail = Detail(json_path, img_dir, 'val') 63 | val_ids = val_detail.getImgs() 64 | 65 | mmcv.mkdir_or_exist( 66 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext')) 67 | 68 | train_list = mmcv.track_progress( 69 | partial(generate_labels, detail=train_detail, out_dir=out_dir), 70 | train_ids) 71 | with open( 72 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 73 | 'train.txt'), 'w') as f: 74 | f.writelines(line + '\n' for line in sorted(train_list)) 75 | 76 | val_list = mmcv.track_progress( 77 | partial(generate_labels, detail=val_detail, out_dir=out_dir), val_ids) 78 | with open( 79 | osp.join(devkit_path, 'VOC2010/ImageSets/SegmentationContext', 80 | 'val.txt'), 'w') as f: 81 | f.writelines(line + '\n' for line in sorted(val_list)) 82 | 83 | print('Done!') 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /tools/convert_datasets/potsdam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
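# Hedged check (illustrative only) of the color-to-label trick used by
# clip_big_image in the Potsdam converter below (and again in vaihingen.py):
# dotting each palette triple with [2, 3, 4] yields a distinct scalar per
# class, so pixels can be labeled with one scalar comparison per class
# instead of a three-channel match.
def _demo_color_hash():
    import numpy as np
    color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0],
                          [255, 255, 0], [0, 255, 0], [0, 255, 255],
                          [0, 0, 255]])
    hashes = color_map @ np.array([2, 3, 4])
    assert len(set(hashes.tolist())) == len(color_map)   # all seven are distinct
    return hashes             # -> [0, 2295, 510, 1275, 765, 1785, 1020]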
2 | import argparse 3 | import glob 4 | import math 5 | import os 6 | import os.path as osp 7 | import tempfile 8 | import zipfile 9 | 10 | import mmcv 11 | import numpy as np 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert potsdam dataset to mmsegmentation format') 17 | parser.add_argument('dataset_path', help='potsdam folder path') 18 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 19 | parser.add_argument('-o', '--out_dir', help='output path') 20 | parser.add_argument( 21 | '--clip_size', 22 | type=int, 23 | help='clipped size of image after preparation', 24 | default=512) 25 | parser.add_argument( 26 | '--stride_size', 27 | type=int, 28 | help='stride of clipping original images', 29 | default=256) 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def clip_big_image(image_path, clip_save_dir, args, to_label=False): 35 | # Original image of Potsdam dataset is very large, thus pre-processing 36 | # of them is adopted. Given fixed clip size and stride size to generate 37 | # clipped image, the intersection of width and height is determined. 38 | # For example, given one 5120 x 5120 original image, the clip size is 39 | # 512 and stride size is 256, thus it would generate 20x20 = 400 images 40 | # whose size are all 512x512. 41 | image = mmcv.imread(image_path) 42 | 43 | h, w, c = image.shape 44 | clip_size = args.clip_size 45 | stride_size = args.stride_size 46 | 47 | num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil( 48 | (h - clip_size) / 49 | stride_size) * stride_size + clip_size >= h else math.ceil( 50 | (h - clip_size) / stride_size) + 1 51 | num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil( 52 | (w - clip_size) / 53 | stride_size) * stride_size + clip_size >= w else math.ceil( 54 | (w - clip_size) / stride_size) + 1 55 | 56 | x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) 57 | xmin = x * clip_size 58 | ymin = y * clip_size 59 | 60 | xmin = xmin.ravel() 61 | ymin = ymin.ravel() 62 | xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size, 63 | np.zeros_like(xmin)) 64 | ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size, 65 | np.zeros_like(ymin)) 66 | boxes = np.stack([ 67 | xmin + xmin_offset, ymin + ymin_offset, 68 | np.minimum(xmin + clip_size, w), 69 | np.minimum(ymin + clip_size, h) 70 | ], 71 | axis=1) 72 | 73 | if to_label: 74 | color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], 75 | [255, 255, 0], [0, 255, 0], [0, 255, 255], 76 | [0, 0, 255]]) 77 | flatten_v = np.matmul( 78 | image.reshape(-1, c), 79 | np.array([2, 3, 4]).reshape(3, 1)) 80 | out = np.zeros_like(flatten_v) 81 | for idx, class_color in enumerate(color_map): 82 | value_idx = np.matmul(class_color, 83 | np.array([2, 3, 4]).reshape(3, 1)) 84 | out[flatten_v == value_idx] = idx 85 | image = out.reshape(h, w) 86 | 87 | for box in boxes: 88 | start_x, start_y, end_x, end_y = box 89 | clipped_image = image[start_y:end_y, 90 | start_x:end_x] if to_label else image[ 91 | start_y:end_y, start_x:end_x, :] 92 | idx_i, idx_j = osp.basename(image_path).split('_')[2:4] 93 | mmcv.imwrite( 94 | clipped_image.astype(np.uint8), 95 | osp.join( 96 | clip_save_dir, 97 | f'{idx_i}_{idx_j}_{start_x}_{start_y}_{end_x}_{end_y}.png')) 98 | 99 | 100 | def main(): 101 | args = parse_args() 102 | splits = { 103 | 'train': [ 104 | '2_10', '2_11', '2_12', '3_10', '3_11', '3_12', '4_10', '4_11', 105 | '4_12', '5_10', '5_11', '5_12', '6_10', '6_11', '6_12', '6_7', 
106 | '6_8', '6_9', '7_10', '7_11', '7_12', '7_7', '7_8', '7_9' 107 | ], 108 | 'val': [ 109 | '5_15', '6_15', '6_13', '3_13', '4_14', '6_14', '5_14', '2_13', 110 | '4_15', '2_14', '5_13', '4_13', '3_14', '7_13' 111 | ] 112 | } 113 | 114 | dataset_path = args.dataset_path 115 | if args.out_dir is None: 116 | out_dir = osp.join('data', 'potsdam') 117 | else: 118 | out_dir = args.out_dir 119 | 120 | print('Making directories...') 121 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) 122 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) 123 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) 124 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) 125 | 126 | zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) 127 | print('Find the data', zipp_list) 128 | 129 | for zipp in zipp_list: 130 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 131 | zip_file = zipfile.ZipFile(zipp) 132 | zip_file.extractall(tmp_dir) 133 | src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) 134 | if not len(src_path_list): 135 | sub_tmp_dir = os.path.join(tmp_dir, os.listdir(tmp_dir)[0]) 136 | src_path_list = glob.glob(os.path.join(sub_tmp_dir, '*.tif')) 137 | 138 | prog_bar = mmcv.ProgressBar(len(src_path_list)) 139 | for i, src_path in enumerate(src_path_list): 140 | idx_i, idx_j = osp.basename(src_path).split('_')[2:4] 141 | data_type = 'train' if f'{idx_i}_{idx_j}' in splits[ 142 | 'train'] else 'val' 143 | if 'label' in src_path: 144 | dst_dir = osp.join(out_dir, 'ann_dir', data_type) 145 | clip_big_image(src_path, dst_dir, args, to_label=True) 146 | else: 147 | dst_dir = osp.join(out_dir, 'img_dir', data_type) 148 | clip_big_image(src_path, dst_dir, args, to_label=False) 149 | prog_bar.update() 150 | 151 | print('Removing the temporary files...') 152 | 153 | print('Done!') 154 | 155 | 156 | if __name__ == '__main__': 157 | main() 158 | -------------------------------------------------------------------------------- /tools/convert_datasets/stare.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
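# Hedged aside on the Potsdam converter above: the train/val routing keys come
# from `osp.basename(src_path).split('_')[2:4]`, which assumes ISPRS-style
# names such as 'top_potsdam_2_10_RGB.tif' (an example name, not taken from
# this repo).
def _demo_parse_potsdam_name(name='top_potsdam_2_10_RGB.tif'):
    idx_i, idx_j = name.split('_')[2:4]
    return f'{idx_i}_{idx_j}'     # -> '2_10', matched against splits['train'] / splits['val']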
2 | import argparse 3 | import gzip 4 | import os 5 | import os.path as osp 6 | import tarfile 7 | import tempfile 8 | 9 | import mmcv 10 | 11 | STARE_LEN = 20 12 | TRAINING_LEN = 10 13 | 14 | 15 | def un_gz(src, dst): 16 | g_file = gzip.GzipFile(src) 17 | with open(dst, 'wb+') as f: 18 | f.write(g_file.read()) 19 | g_file.close() 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser( 24 | description='Convert STARE dataset to mmsegmentation format') 25 | parser.add_argument('image_path', help='the path of stare-images.tar') 26 | parser.add_argument('labels_ah', help='the path of labels-ah.tar') 27 | parser.add_argument('labels_vk', help='the path of labels-vk.tar') 28 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 29 | parser.add_argument('-o', '--out_dir', help='output path') 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def main(): 35 | args = parse_args() 36 | image_path = args.image_path 37 | labels_ah = args.labels_ah 38 | labels_vk = args.labels_vk 39 | if args.out_dir is None: 40 | out_dir = osp.join('data', 'STARE') 41 | else: 42 | out_dir = args.out_dir 43 | 44 | print('Making directories...') 45 | mmcv.mkdir_or_exist(out_dir) 46 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images')) 47 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'training')) 48 | mmcv.mkdir_or_exist(osp.join(out_dir, 'images', 'validation')) 49 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations')) 50 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'training')) 51 | mmcv.mkdir_or_exist(osp.join(out_dir, 'annotations', 'validation')) 52 | 53 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 54 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) 55 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) 56 | 57 | print('Extracting stare-images.tar...') 58 | with tarfile.open(image_path) as f: 59 | f.extractall(osp.join(tmp_dir, 'gz')) 60 | 61 | for filename in os.listdir(osp.join(tmp_dir, 'gz')): 62 | un_gz( 63 | osp.join(tmp_dir, 'gz', filename), 64 | osp.join(tmp_dir, 'files', 65 | osp.splitext(filename)[0])) 66 | 67 | now_dir = osp.join(tmp_dir, 'files') 68 | 69 | assert len(os.listdir(now_dir)) == STARE_LEN, \ 70 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) 71 | 72 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: 73 | img = mmcv.imread(osp.join(now_dir, filename)) 74 | mmcv.imwrite( 75 | img, 76 | osp.join(out_dir, 'images', 'training', 77 | osp.splitext(filename)[0] + '.png')) 78 | 79 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: 80 | img = mmcv.imread(osp.join(now_dir, filename)) 81 | mmcv.imwrite( 82 | img, 83 | osp.join(out_dir, 'images', 'validation', 84 | osp.splitext(filename)[0] + '.png')) 85 | 86 | print('Removing the temporary files...') 87 | 88 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 89 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) 90 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) 91 | 92 | print('Extracting labels-ah.tar...') 93 | with tarfile.open(labels_ah) as f: 94 | f.extractall(osp.join(tmp_dir, 'gz')) 95 | 96 | for filename in os.listdir(osp.join(tmp_dir, 'gz')): 97 | un_gz( 98 | osp.join(tmp_dir, 'gz', filename), 99 | osp.join(tmp_dir, 'files', 100 | osp.splitext(filename)[0])) 101 | 102 | now_dir = osp.join(tmp_dir, 'files') 103 | 104 | assert len(os.listdir(now_dir)) == STARE_LEN, \ 105 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) 106 | 107 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: 108 | img = 
mmcv.imread(osp.join(now_dir, filename)) 109 | # The annotation img should be divided by 128, because some of 110 | # the annotation imgs are not standard. We should set a threshold 111 | # to convert the nonstandard annotation imgs. The value divided by 112 | # 128 equivalent to '1 if value >= 128 else 0' 113 | mmcv.imwrite( 114 | img[:, :, 0] // 128, 115 | osp.join(out_dir, 'annotations', 'training', 116 | osp.splitext(filename)[0] + '.png')) 117 | 118 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: 119 | img = mmcv.imread(osp.join(now_dir, filename)) 120 | mmcv.imwrite( 121 | img[:, :, 0] // 128, 122 | osp.join(out_dir, 'annotations', 'validation', 123 | osp.splitext(filename)[0] + '.png')) 124 | 125 | print('Removing the temporary files...') 126 | 127 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 128 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'gz')) 129 | mmcv.mkdir_or_exist(osp.join(tmp_dir, 'files')) 130 | 131 | print('Extracting labels-vk.tar...') 132 | with tarfile.open(labels_vk) as f: 133 | f.extractall(osp.join(tmp_dir, 'gz')) 134 | 135 | for filename in os.listdir(osp.join(tmp_dir, 'gz')): 136 | un_gz( 137 | osp.join(tmp_dir, 'gz', filename), 138 | osp.join(tmp_dir, 'files', 139 | osp.splitext(filename)[0])) 140 | 141 | now_dir = osp.join(tmp_dir, 'files') 142 | 143 | assert len(os.listdir(now_dir)) == STARE_LEN, \ 144 | 'len(os.listdir(now_dir)) != {}'.format(STARE_LEN) 145 | 146 | for filename in sorted(os.listdir(now_dir))[:TRAINING_LEN]: 147 | img = mmcv.imread(osp.join(now_dir, filename)) 148 | mmcv.imwrite( 149 | img[:, :, 0] // 128, 150 | osp.join(out_dir, 'annotations', 'training', 151 | osp.splitext(filename)[0] + '.png')) 152 | 153 | for filename in sorted(os.listdir(now_dir))[TRAINING_LEN:]: 154 | img = mmcv.imread(osp.join(now_dir, filename)) 155 | mmcv.imwrite( 156 | img[:, :, 0] // 128, 157 | osp.join(out_dir, 'annotations', 'validation', 158 | osp.splitext(filename)[0] + '.png')) 159 | 160 | print('Removing the temporary files...') 161 | 162 | print('Done!') 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /tools/convert_datasets/vaihingen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import glob 4 | import math 5 | import os 6 | import os.path as osp 7 | import tempfile 8 | import zipfile 9 | 10 | import mmcv 11 | import numpy as np 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser( 16 | description='Convert vaihingen dataset to mmsegmentation format') 17 | parser.add_argument('dataset_path', help='vaihingen folder path') 18 | parser.add_argument('--tmp_dir', help='path of the temporary directory') 19 | parser.add_argument('-o', '--out_dir', help='output path') 20 | parser.add_argument( 21 | '--clip_size', 22 | type=int, 23 | help='clipped size of image after preparation', 24 | default=512) 25 | parser.add_argument( 26 | '--stride_size', 27 | type=int, 28 | help='stride of clipping original images', 29 | default=256) 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def clip_big_image(image_path, clip_save_dir, to_label=False): 35 | # Original image of Vaihingen dataset is very large, thus pre-processing 36 | # of them is adopted. Given fixed clip size and stride size to generate 37 | # clipped image, the intersection of width and height is determined. 
38 | # For example, given one 5120 x 5120 original image, the clip size is 39 | # 512 and stride size is 256, thus it would generate 20x20 = 400 images 40 | # whose size are all 512x512. 41 | image = mmcv.imread(image_path) 42 | 43 | h, w, c = image.shape 44 | cs = args.clip_size 45 | ss = args.stride_size 46 | 47 | num_rows = math.ceil((h - cs) / ss) if math.ceil( 48 | (h - cs) / ss) * ss + cs >= h else math.ceil((h - cs) / ss) + 1 49 | num_cols = math.ceil((w - cs) / ss) if math.ceil( 50 | (w - cs) / ss) * ss + cs >= w else math.ceil((w - cs) / ss) + 1 51 | 52 | x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1)) 53 | xmin = x * cs 54 | ymin = y * cs 55 | 56 | xmin = xmin.ravel() 57 | ymin = ymin.ravel() 58 | xmin_offset = np.where(xmin + cs > w, w - xmin - cs, np.zeros_like(xmin)) 59 | ymin_offset = np.where(ymin + cs > h, h - ymin - cs, np.zeros_like(ymin)) 60 | boxes = np.stack([ 61 | xmin + xmin_offset, ymin + ymin_offset, 62 | np.minimum(xmin + cs, w), 63 | np.minimum(ymin + cs, h) 64 | ], 65 | axis=1) 66 | 67 | if to_label: 68 | color_map = np.array([[0, 0, 0], [255, 255, 255], [255, 0, 0], 69 | [255, 255, 0], [0, 255, 0], [0, 255, 255], 70 | [0, 0, 255]]) 71 | flatten_v = np.matmul( 72 | image.reshape(-1, c), 73 | np.array([2, 3, 4]).reshape(3, 1)) 74 | out = np.zeros_like(flatten_v) 75 | for idx, class_color in enumerate(color_map): 76 | value_idx = np.matmul(class_color, 77 | np.array([2, 3, 4]).reshape(3, 1)) 78 | out[flatten_v == value_idx] = idx 79 | image = out.reshape(h, w) 80 | 81 | for box in boxes: 82 | start_x, start_y, end_x, end_y = box 83 | clipped_image = image[start_y:end_y, 84 | start_x:end_x] if to_label else image[ 85 | start_y:end_y, start_x:end_x, :] 86 | area_idx = osp.basename(image_path).split('_')[3].strip('.tif') 87 | mmcv.imwrite( 88 | clipped_image.astype(np.uint8), 89 | osp.join(clip_save_dir, 90 | f'{area_idx}_{start_x}_{start_y}_{end_x}_{end_y}.png')) 91 | 92 | 93 | def main(): 94 | splits = { 95 | 'train': [ 96 | 'area1', 'area11', 'area13', 'area15', 'area17', 'area21', 97 | 'area23', 'area26', 'area28', 'area3', 'area30', 'area32', 98 | 'area34', 'area37', 'area5', 'area7' 99 | ], 100 | 'val': [ 101 | 'area6', 'area24', 'area35', 'area16', 'area14', 'area22', 102 | 'area10', 'area4', 'area2', 'area20', 'area8', 'area31', 'area33', 103 | 'area27', 'area38', 'area12', 'area29' 104 | ], 105 | } 106 | 107 | dataset_path = args.dataset_path 108 | if args.out_dir is None: 109 | out_dir = osp.join('data', 'vaihingen') 110 | else: 111 | out_dir = args.out_dir 112 | 113 | print('Making directories...') 114 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'train')) 115 | mmcv.mkdir_or_exist(osp.join(out_dir, 'img_dir', 'val')) 116 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'train')) 117 | mmcv.mkdir_or_exist(osp.join(out_dir, 'ann_dir', 'val')) 118 | 119 | zipp_list = glob.glob(os.path.join(dataset_path, '*.zip')) 120 | print('Find the data', zipp_list) 121 | 122 | with tempfile.TemporaryDirectory(dir=args.tmp_dir) as tmp_dir: 123 | for zipp in zipp_list: 124 | zip_file = zipfile.ZipFile(zipp) 125 | zip_file.extractall(tmp_dir) 126 | src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) 127 | if 'ISPRS_semantic_labeling_Vaihingen' in zipp: 128 | src_path_list = glob.glob( 129 | os.path.join(os.path.join(tmp_dir, 'top'), '*.tif')) 130 | if 'ISPRS_semantic_labeling_Vaihingen_ground_truth_eroded_COMPLETE' in zipp: # noqa 131 | src_path_list = glob.glob(os.path.join(tmp_dir, '*.tif')) 132 | # delete unused area9 ground truth 133 | 
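# NOTE: removing entries from a list while iterating over it can skip
# elements; a safer equivalent filter would be:
#   src_path_list = [p for p in src_path_list if 'area9' not in p]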
for area_ann in src_path_list: 134 | if 'area9' in area_ann: 135 | src_path_list.remove(area_ann) 136 | prog_bar = mmcv.ProgressBar(len(src_path_list)) 137 | for i, src_path in enumerate(src_path_list): 138 | area_idx = osp.basename(src_path).split('_')[3].strip('.tif') 139 | data_type = 'train' if area_idx in splits['train'] else 'val' 140 | if 'noBoundary' in src_path: 141 | dst_dir = osp.join(out_dir, 'ann_dir', data_type) 142 | clip_big_image(src_path, dst_dir, to_label=True) 143 | else: 144 | dst_dir = osp.join(out_dir, 'img_dir', data_type) 145 | clip_big_image(src_path, dst_dir, to_label=False) 146 | prog_bar.update() 147 | 148 | print('Removing the temporary files...') 149 | 150 | print('Done!') 151 | 152 | 153 | if __name__ == '__main__': 154 | args = parse_args() 155 | main() 156 | -------------------------------------------------------------------------------- /tools/convert_datasets/voc_aug.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from functools import partial 5 | 6 | import mmcv 7 | import numpy as np 8 | from PIL import Image 9 | from scipy.io import loadmat 10 | 11 | AUG_LEN = 10582 12 | 13 | 14 | def convert_mat(mat_file, in_dir, out_dir): 15 | data = loadmat(osp.join(in_dir, mat_file)) 16 | mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8) 17 | seg_filename = osp.join(out_dir, mat_file.replace('.mat', '.png')) 18 | Image.fromarray(mask).save(seg_filename, 'PNG') 19 | 20 | 21 | def generate_aug_list(merged_list, excluded_list): 22 | return list(set(merged_list) - set(excluded_list)) 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser( 27 | description='Convert PASCAL VOC annotations to mmsegmentation format') 28 | parser.add_argument('devkit_path', help='pascal voc devkit path') 29 | parser.add_argument('aug_path', help='pascal voc aug path') 30 | parser.add_argument('-o', '--out_dir', help='output path') 31 | parser.add_argument( 32 | '--nproc', default=1, type=int, help='number of process') 33 | args = parser.parse_args() 34 | return args 35 | 36 | 37 | def main(): 38 | args = parse_args() 39 | devkit_path = args.devkit_path 40 | aug_path = args.aug_path 41 | nproc = args.nproc 42 | if args.out_dir is None: 43 | out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug') 44 | else: 45 | out_dir = args.out_dir 46 | mmcv.mkdir_or_exist(out_dir) 47 | in_dir = osp.join(aug_path, 'dataset', 'cls') 48 | 49 | mmcv.track_parallel_progress( 50 | partial(convert_mat, in_dir=in_dir, out_dir=out_dir), 51 | list(mmcv.scandir(in_dir, suffix='.mat')), 52 | nproc=nproc) 53 | 54 | full_aug_list = [] 55 | with open(osp.join(aug_path, 'dataset', 'train.txt')) as f: 56 | full_aug_list += [line.strip() for line in f] 57 | with open(osp.join(aug_path, 'dataset', 'val.txt')) as f: 58 | full_aug_list += [line.strip() for line in f] 59 | 60 | with open( 61 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 62 | 'train.txt')) as f: 63 | ori_train_list = [line.strip() for line in f] 64 | with open( 65 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 66 | 'val.txt')) as f: 67 | val_list = [line.strip() for line in f] 68 | 69 | aug_train_list = generate_aug_list(ori_train_list + full_aug_list, 70 | val_list) 71 | assert len(aug_train_list) == AUG_LEN, 'len(aug_train_list) != {}'.format( 72 | AUG_LEN) 73 | 74 | with open( 75 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 76 | 'trainaug.txt'), 
'w') as f: 77 | f.writelines(line + '\n' for line in aug_train_list) 78 | 79 | aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list) 80 | assert len(aug_list) == AUG_LEN - len( 81 | ori_train_list), 'len(aug_list) != {}'.format(AUG_LEN - 82 | len(ori_train_list)) 83 | with open( 84 | osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation', 'aug.txt'), 85 | 'w') as f: 86 | f.writelines(line + '\n' for line in aug_list) 87 | 88 | print('Done!') 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /tools/flops.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | 5 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 6 | python $(dirname "$0")/get_flops.py $CONFIG ${@:2} 7 | -------------------------------------------------------------------------------- /tools/get_flops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | from mmcv import Config 5 | from mmcv.cnn import get_model_complexity_info 6 | 7 | from mmseg.models import build_segmentor 8 | import van 9 | 10 | import torch 11 | from torchprofile import profile_macs 12 | 13 | 14 | def parse_args(): 15 | parser = argparse.ArgumentParser(description='Get FLOPs of a segmentor') 16 | parser.add_argument('config', help='train config file path') 17 | parser.add_argument( 18 | '--shape', 19 | type=int, 20 | nargs='+', 21 | default=[2048, 1024], 22 | help='input image size') 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def main(): 28 | 29 | args = parse_args() 30 | 31 | if len(args.shape) == 1: 32 | input_shape = (3, args.shape[0], args.shape[0]) 33 | elif len(args.shape) == 2: 34 | input_shape = (3, ) + tuple(args.shape) 35 | else: 36 | raise ValueError('invalid input shape') 37 | 38 | cfg = Config.fromfile(args.config) 39 | cfg.model.pretrained = None 40 | model = build_segmentor( 41 | cfg.model, 42 | train_cfg=cfg.get('train_cfg'), 43 | test_cfg=cfg.get('test_cfg')).cuda() 44 | model.eval() 45 | 46 | if hasattr(model, 'forward_dummy'): 47 | model.forward = model.forward_dummy 48 | else: 49 | raise NotImplementedError( 50 | 'FLOPs counter is currently not supported with {}'. 51 | format(model.__class__.__name__)) 52 | 53 | flops, params = get_model_complexity_info(model, input_shape) 54 | split_line = '=' * 30 55 | print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format( 56 | split_line, input_shape, flops, params)) 57 | 58 | inputs = torch.randn(1, *input_shape).cuda() 59 | macs = profile_macs(model, inputs) / 1e9 60 | print(f'GFLOPs {macs}.') 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /tools/model_converters/beit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
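# This converter renames the keys of an official BEiT checkpoint so that it
# can be loaded by the MMSegmentation BEiT backbone: 'blocks' -> 'layers',
# 'norm' -> 'ln', 'mlp.fc1'/'mlp.fc2' -> 'ffn.layers.0.0'/'ffn.layers.1',
# and 'patch_embed.proj' -> 'patch_embed.projection'.
# Illustrative invocation (paths are placeholders, not files in this repo):
#   python tools/model_converters/beit2mmseg.py <official_beit.pth> pretrain/beit_mmseg.pth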
2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_beit(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | for k, v in ckpt.items(): 15 | if k.startswith('blocks'): 16 | new_key = k.replace('blocks', 'layers') 17 | if 'norm' in new_key: 18 | new_key = new_key.replace('norm', 'ln') 19 | elif 'mlp.fc1' in new_key: 20 | new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0') 21 | elif 'mlp.fc2' in new_key: 22 | new_key = new_key.replace('mlp.fc2', 'ffn.layers.1') 23 | new_ckpt[new_key] = v 24 | elif k.startswith('patch_embed'): 25 | new_key = k.replace('patch_embed.proj', 'patch_embed.projection') 26 | new_ckpt[new_key] = v 27 | else: 28 | new_key = k 29 | new_ckpt[new_key] = v 30 | 31 | return new_ckpt 32 | 33 | 34 | def main(): 35 | parser = argparse.ArgumentParser( 36 | description='Convert keys in official pretrained beit models to' 37 | 'MMSegmentation style.') 38 | parser.add_argument('src', help='src model path or url') 39 | # The dst path must be a full path of the new checkpoint. 40 | parser.add_argument('dst', help='save path') 41 | args = parser.parse_args() 42 | 43 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 44 | if 'state_dict' in checkpoint: 45 | state_dict = checkpoint['state_dict'] 46 | elif 'model' in checkpoint: 47 | state_dict = checkpoint['model'] 48 | else: 49 | state_dict = checkpoint 50 | weight = convert_beit(state_dict) 51 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 52 | torch.save(weight, args.dst) 53 | 54 | 55 | if __name__ == '__main__': 56 | main() 57 | -------------------------------------------------------------------------------- /tools/model_converters/mit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_mit(ckpt): 12 | new_ckpt = OrderedDict() 13 | # Process the concat between q linear weights and kv linear weights 14 | for k, v in ckpt.items(): 15 | if k.startswith('head'): 16 | continue 17 | # patch embedding conversion 18 | elif k.startswith('patch_embed'): 19 | stage_i = int(k.split('.')[0].replace('patch_embed', '')) 20 | new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0') 21 | new_v = v 22 | if 'proj.' in new_k: 23 | new_k = new_k.replace('proj.', 'projection.') 24 | # transformer encoder layer conversion 25 | elif k.startswith('block'): 26 | stage_i = int(k.split('.')[0].replace('block', '')) 27 | new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1') 28 | new_v = v 29 | if 'attn.q.' in new_k: 30 | sub_item_k = k.replace('q.', 'kv.') 31 | new_k = new_k.replace('q.', 'attn.in_proj_') 32 | new_v = torch.cat([v, ckpt[sub_item_k]], dim=0) 33 | elif 'attn.kv.' in new_k: 34 | continue 35 | elif 'attn.proj.' in new_k: 36 | new_k = new_k.replace('proj.', 'attn.out_proj.') 37 | elif 'attn.sr.' in new_k: 38 | new_k = new_k.replace('sr.', 'sr.') 39 | elif 'mlp.' 
in new_k: 40 | string = f'{new_k}-' 41 | new_k = new_k.replace('mlp.', 'ffn.layers.') 42 | if 'fc1.weight' in new_k or 'fc2.weight' in new_k: 43 | new_v = v.reshape((*v.shape, 1, 1)) 44 | new_k = new_k.replace('fc1.', '0.') 45 | new_k = new_k.replace('dwconv.dwconv.', '1.') 46 | new_k = new_k.replace('fc2.', '4.') 47 | string += f'{new_k} {v.shape}-{new_v.shape}' 48 | # norm layer conversion 49 | elif k.startswith('norm'): 50 | stage_i = int(k.split('.')[0].replace('norm', '')) 51 | new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2') 52 | new_v = v 53 | else: 54 | new_k = k 55 | new_v = v 56 | new_ckpt[new_k] = new_v 57 | return new_ckpt 58 | 59 | 60 | def main(): 61 | parser = argparse.ArgumentParser( 62 | description='Convert keys in official pretrained segformer to ' 63 | 'MMSegmentation style.') 64 | parser.add_argument('src', help='src model path or url') 65 | # The dst path must be a full path of the new checkpoint. 66 | parser.add_argument('dst', help='save path') 67 | args = parser.parse_args() 68 | 69 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 70 | if 'state_dict' in checkpoint: 71 | state_dict = checkpoint['state_dict'] 72 | elif 'model' in checkpoint: 73 | state_dict = checkpoint['model'] 74 | else: 75 | state_dict = checkpoint 76 | weight = convert_mit(state_dict) 77 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 78 | torch.save(weight, args.dst) 79 | 80 | 81 | if __name__ == '__main__': 82 | main() 83 | -------------------------------------------------------------------------------- /tools/model_converters/stdc2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | import torch 7 | from mmcv.runner import CheckpointLoader 8 | 9 | 10 | def convert_stdc(ckpt, stdc_type): 11 | new_state_dict = {} 12 | if stdc_type == 'STDC1': 13 | stage_lst = ['0', '1', '2.0', '2.1', '3.0', '3.1', '4.0', '4.1'] 14 | else: 15 | stage_lst = [ 16 | '0', '1', '2.0', '2.1', '2.2', '2.3', '3.0', '3.1', '3.2', '3.3', 17 | '3.4', '4.0', '4.1', '4.2' 18 | ] 19 | for k, v in ckpt.items(): 20 | ori_k = k 21 | flag = False 22 | if 'cp.' in k: 23 | k = k.replace('cp.', '') 24 | if 'features.' in k: 25 | num_layer = int(k.split('.')[1]) 26 | feature_key_lst = 'features.' + str(num_layer) + '.' 27 | stages_key_lst = 'stages.' + stage_lst[num_layer] + '.' 28 | k = k.replace(feature_key_lst, stages_key_lst) 29 | flag = True 30 | if 'conv_list' in k: 31 | k = k.replace('conv_list', 'layers') 32 | flag = True 33 | if 'avd_layer.' in k: 34 | if 'avd_layer.0' in k: 35 | k = k.replace('avd_layer.0', 'downsample.conv') 36 | elif 'avd_layer.1' in k: 37 | k = k.replace('avd_layer.1', 'downsample.bn') 38 | flag = True 39 | if flag: 40 | new_state_dict[k] = ckpt[ori_k] 41 | 42 | return new_state_dict 43 | 44 | 45 | def main(): 46 | parser = argparse.ArgumentParser( 47 | description='Convert keys in official pretrained STDC1/2 to ' 48 | 'MMSegmentation style.') 49 | parser.add_argument('src', help='src model path') 50 | # The dst path must be a full path of the new checkpoint. 
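# Illustrative invocation (file names are placeholders); the positional
# 'type' argument selects the stage layout used by convert_stdc above:
#   python tools/model_converters/stdc2mmseg.py <official_stdc1.pth> pretrain/stdc1_mmseg.pth STDC1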
51 | parser.add_argument('dst', help='save path') 52 | parser.add_argument('type', help='model type: STDC1 or STDC2') 53 | args = parser.parse_args() 54 | 55 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 56 | if 'state_dict' in checkpoint: 57 | state_dict = checkpoint['state_dict'] 58 | elif 'model' in checkpoint: 59 | state_dict = checkpoint['model'] 60 | else: 61 | state_dict = checkpoint 62 | 63 | assert args.type in ['STDC1', 64 | 'STDC2'], 'STD type should be STDC1 or STDC2!' 65 | weight = convert_stdc(state_dict, args.type) 66 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 67 | torch.save(weight, args.dst) 68 | 69 | 70 | if __name__ == '__main__': 71 | main() 72 | -------------------------------------------------------------------------------- /tools/model_converters/swin2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_swin(ckpt): 12 | new_ckpt = OrderedDict() 13 | 14 | def correct_unfold_reduction_order(x): 15 | out_channel, in_channel = x.shape 16 | x = x.reshape(out_channel, 4, in_channel // 4) 17 | x = x[:, [0, 2, 1, 3], :].transpose(1, 18 | 2).reshape(out_channel, in_channel) 19 | return x 20 | 21 | def correct_unfold_norm_order(x): 22 | in_channel = x.shape[0] 23 | x = x.reshape(4, in_channel // 4) 24 | x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel) 25 | return x 26 | 27 | for k, v in ckpt.items(): 28 | if k.startswith('head'): 29 | continue 30 | elif k.startswith('layers'): 31 | new_v = v 32 | if 'attn.' in k: 33 | new_k = k.replace('attn.', 'attn.w_msa.') 34 | elif 'mlp.' in k: 35 | if 'mlp.fc1.' in k: 36 | new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.') 37 | elif 'mlp.fc2.' in k: 38 | new_k = k.replace('mlp.fc2.', 'ffn.layers.1.') 39 | else: 40 | new_k = k.replace('mlp.', 'ffn.') 41 | elif 'downsample' in k: 42 | new_k = k 43 | if 'reduction.' in k: 44 | new_v = correct_unfold_reduction_order(v) 45 | elif 'norm.' in k: 46 | new_v = correct_unfold_norm_order(v) 47 | else: 48 | new_k = k 49 | new_k = new_k.replace('layers', 'stages', 1) 50 | elif k.startswith('patch_embed'): 51 | new_v = v 52 | if 'proj' in k: 53 | new_k = k.replace('proj', 'projection') 54 | else: 55 | new_k = k 56 | else: 57 | new_v = v 58 | new_k = k 59 | 60 | new_ckpt[new_k] = new_v 61 | 62 | return new_ckpt 63 | 64 | 65 | def main(): 66 | parser = argparse.ArgumentParser( 67 | description='Convert keys in official pretrained swin models to' 68 | 'MMSegmentation style.') 69 | parser.add_argument('src', help='src model path or url') 70 | # The dst path must be a full path of the new checkpoint. 
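# Note: besides renaming keys, convert_swin above also reorders the
# downsample (patch merging) 'reduction' and 'norm' weights, because the
# MMSegmentation PatchMerging gathers the 2x2 neighbourhood in a different
# channel order than the official Swin implementation.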
71 | parser.add_argument('dst', help='save path') 72 | args = parser.parse_args() 73 | 74 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 75 | if 'state_dict' in checkpoint: 76 | state_dict = checkpoint['state_dict'] 77 | elif 'model' in checkpoint: 78 | state_dict = checkpoint['model'] 79 | else: 80 | state_dict = checkpoint 81 | weight = convert_swin(state_dict) 82 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 83 | torch.save(weight, args.dst) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /tools/model_converters/twins2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_twins(args, ckpt): 12 | 13 | new_ckpt = OrderedDict() 14 | 15 | for k, v in list(ckpt.items()): 16 | new_v = v 17 | if k.startswith('head'): 18 | continue 19 | elif k.startswith('patch_embeds'): 20 | if 'proj.' in k: 21 | new_k = k.replace('proj.', 'projection.') 22 | else: 23 | new_k = k 24 | elif k.startswith('blocks'): 25 | # Union 26 | if 'attn.q.' in k: 27 | new_k = k.replace('q.', 'attn.in_proj_') 28 | new_v = torch.cat([v, ckpt[k.replace('attn.q.', 'attn.kv.')]], 29 | dim=0) 30 | elif 'mlp.fc1' in k: 31 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 32 | elif 'mlp.fc2' in k: 33 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 34 | # Only pcpvt 35 | elif args.model == 'pcpvt': 36 | if 'attn.proj.' in k: 37 | new_k = k.replace('proj.', 'attn.out_proj.') 38 | else: 39 | new_k = k 40 | 41 | # Only svt 42 | else: 43 | if 'attn.proj.' in k: 44 | k_lst = k.split('.') 45 | if int(k_lst[2]) % 2 == 1: 46 | new_k = k.replace('proj.', 'attn.out_proj.') 47 | else: 48 | new_k = k 49 | else: 50 | new_k = k 51 | new_k = new_k.replace('blocks.', 'layers.') 52 | elif k.startswith('pos_block'): 53 | new_k = k.replace('pos_block', 'position_encodings') 54 | if 'proj.0.' in new_k: 55 | new_k = new_k.replace('proj.0.', 'proj.') 56 | else: 57 | new_k = k 58 | if 'attn.kv.' not in k: 59 | new_ckpt[new_k] = new_v 60 | return new_ckpt 61 | 62 | 63 | def main(): 64 | parser = argparse.ArgumentParser( 65 | description='Convert keys in timm pretrained vit models to ' 66 | 'MMSegmentation style.') 67 | parser.add_argument('src', help='src model path or url') 68 | # The dst path must be a full path of the new checkpoint. 69 | parser.add_argument('dst', help='save path') 70 | parser.add_argument('model', help='model: pcpvt or svt') 71 | args = parser.parse_args() 72 | 73 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 74 | 75 | if 'state_dict' in checkpoint: 76 | # timm checkpoint 77 | state_dict = checkpoint['state_dict'] 78 | else: 79 | state_dict = checkpoint 80 | 81 | weight = convert_twins(args, state_dict) 82 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 83 | torch.save(weight, args.dst) 84 | 85 | 86 | if __name__ == '__main__': 87 | main() 88 | -------------------------------------------------------------------------------- /tools/model_converters/vit2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import argparse 3 | import os.path as osp 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import torch 8 | from mmcv.runner import CheckpointLoader 9 | 10 | 11 | def convert_vit(ckpt): 12 | 13 | new_ckpt = OrderedDict() 14 | 15 | for k, v in ckpt.items(): 16 | if k.startswith('head'): 17 | continue 18 | if k.startswith('norm'): 19 | new_k = k.replace('norm.', 'ln1.') 20 | elif k.startswith('patch_embed'): 21 | if 'proj' in k: 22 | new_k = k.replace('proj', 'projection') 23 | else: 24 | new_k = k 25 | elif k.startswith('blocks'): 26 | if 'norm' in k: 27 | new_k = k.replace('norm', 'ln') 28 | elif 'mlp.fc1' in k: 29 | new_k = k.replace('mlp.fc1', 'ffn.layers.0.0') 30 | elif 'mlp.fc2' in k: 31 | new_k = k.replace('mlp.fc2', 'ffn.layers.1') 32 | elif 'attn.qkv' in k: 33 | new_k = k.replace('attn.qkv.', 'attn.attn.in_proj_') 34 | elif 'attn.proj' in k: 35 | new_k = k.replace('attn.proj', 'attn.attn.out_proj') 36 | else: 37 | new_k = k 38 | new_k = new_k.replace('blocks.', 'layers.') 39 | else: 40 | new_k = k 41 | new_ckpt[new_k] = v 42 | 43 | return new_ckpt 44 | 45 | 46 | def main(): 47 | parser = argparse.ArgumentParser( 48 | description='Convert keys in timm pretrained vit models to ' 49 | 'MMSegmentation style.') 50 | parser.add_argument('src', help='src model path or url') 51 | # The dst path must be a full path of the new checkpoint. 52 | parser.add_argument('dst', help='save path') 53 | args = parser.parse_args() 54 | 55 | checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu') 56 | if 'state_dict' in checkpoint: 57 | # timm checkpoint 58 | state_dict = checkpoint['state_dict'] 59 | elif 'model' in checkpoint: 60 | # deit checkpoint 61 | state_dict = checkpoint['model'] 62 | else: 63 | state_dict = checkpoint 64 | weight = convert_vit(state_dict) 65 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 66 | torch.save(weight, args.dst) 67 | 68 | 69 | if __name__ == '__main__': 70 | main() 71 | -------------------------------------------------------------------------------- /tools/model_converters/vitjax2mmseg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
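# This converter reads an official JAX/Flax ViT checkpoint (an .npz file) and
# rewrites it for the MMSegmentation ViT backbone: conv kernels stored as
# (H, W, in, out) are permuted to PyTorch's (out, in, H, W), dense kernels
# stored as (in, out) are transposed to (out, in), and the per-layer
# query/key/value projections are stacked into a single in_proj weight.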
2 | import argparse 3 | import os.path as osp 4 | 5 | import mmcv 6 | import numpy as np 7 | import torch 8 | 9 | 10 | def vit_jax_to_torch(jax_weights, num_layer=12): 11 | torch_weights = dict() 12 | 13 | # patch embedding 14 | conv_filters = jax_weights['embedding/kernel'] 15 | conv_filters = conv_filters.permute(3, 2, 0, 1) 16 | torch_weights['patch_embed.projection.weight'] = conv_filters 17 | torch_weights['patch_embed.projection.bias'] = jax_weights[ 18 | 'embedding/bias'] 19 | 20 | # pos embedding 21 | torch_weights['pos_embed'] = jax_weights[ 22 | 'Transformer/posembed_input/pos_embedding'] 23 | 24 | # cls token 25 | torch_weights['cls_token'] = jax_weights['cls'] 26 | 27 | # head 28 | torch_weights['ln1.weight'] = jax_weights['Transformer/encoder_norm/scale'] 29 | torch_weights['ln1.bias'] = jax_weights['Transformer/encoder_norm/bias'] 30 | 31 | # transformer blocks 32 | for i in range(num_layer): 33 | jax_block = f'Transformer/encoderblock_{i}' 34 | torch_block = f'layers.{i}' 35 | 36 | # attention norm 37 | torch_weights[f'{torch_block}.ln1.weight'] = jax_weights[ 38 | f'{jax_block}/LayerNorm_0/scale'] 39 | torch_weights[f'{torch_block}.ln1.bias'] = jax_weights[ 40 | f'{jax_block}/LayerNorm_0/bias'] 41 | 42 | # attention 43 | query_weight = jax_weights[ 44 | f'{jax_block}/MultiHeadDotProductAttention_1/query/kernel'] 45 | query_bias = jax_weights[ 46 | f'{jax_block}/MultiHeadDotProductAttention_1/query/bias'] 47 | key_weight = jax_weights[ 48 | f'{jax_block}/MultiHeadDotProductAttention_1/key/kernel'] 49 | key_bias = jax_weights[ 50 | f'{jax_block}/MultiHeadDotProductAttention_1/key/bias'] 51 | value_weight = jax_weights[ 52 | f'{jax_block}/MultiHeadDotProductAttention_1/value/kernel'] 53 | value_bias = jax_weights[ 54 | f'{jax_block}/MultiHeadDotProductAttention_1/value/bias'] 55 | 56 | qkv_weight = torch.from_numpy( 57 | np.stack((query_weight, key_weight, value_weight), 1)) 58 | qkv_weight = torch.flatten(qkv_weight, start_dim=1) 59 | qkv_bias = torch.from_numpy( 60 | np.stack((query_bias, key_bias, value_bias), 0)) 61 | qkv_bias = torch.flatten(qkv_bias, start_dim=0) 62 | 63 | torch_weights[f'{torch_block}.attn.attn.in_proj_weight'] = qkv_weight 64 | torch_weights[f'{torch_block}.attn.attn.in_proj_bias'] = qkv_bias 65 | to_out_weight = jax_weights[ 66 | f'{jax_block}/MultiHeadDotProductAttention_1/out/kernel'] 67 | to_out_weight = torch.flatten(to_out_weight, start_dim=0, end_dim=1) 68 | torch_weights[ 69 | f'{torch_block}.attn.attn.out_proj.weight'] = to_out_weight 70 | torch_weights[f'{torch_block}.attn.attn.out_proj.bias'] = jax_weights[ 71 | f'{jax_block}/MultiHeadDotProductAttention_1/out/bias'] 72 | 73 | # mlp norm 74 | torch_weights[f'{torch_block}.ln2.weight'] = jax_weights[ 75 | f'{jax_block}/LayerNorm_2/scale'] 76 | torch_weights[f'{torch_block}.ln2.bias'] = jax_weights[ 77 | f'{jax_block}/LayerNorm_2/bias'] 78 | 79 | # mlp 80 | torch_weights[f'{torch_block}.ffn.layers.0.0.weight'] = jax_weights[ 81 | f'{jax_block}/MlpBlock_3/Dense_0/kernel'] 82 | torch_weights[f'{torch_block}.ffn.layers.0.0.bias'] = jax_weights[ 83 | f'{jax_block}/MlpBlock_3/Dense_0/bias'] 84 | torch_weights[f'{torch_block}.ffn.layers.1.weight'] = jax_weights[ 85 | f'{jax_block}/MlpBlock_3/Dense_1/kernel'] 86 | torch_weights[f'{torch_block}.ffn.layers.1.bias'] = jax_weights[ 87 | f'{jax_block}/MlpBlock_3/Dense_1/bias'] 88 | 89 | # transpose weights 90 | for k, v in torch_weights.items(): 91 | if 'weight' in k and 'patch_embed' not in k and 'ln' not in k: 92 | v = v.permute(1, 0) 93 | 
torch_weights[k] = v 94 | 95 | return torch_weights 96 | 97 | 98 | def main(): 99 | # stole refactoring code from Robin Strudel, thanks 100 | parser = argparse.ArgumentParser( 101 | description='Convert keys from jax official pretrained vit models to ' 102 | 'MMSegmentation style.') 103 | parser.add_argument('src', help='src model path or url') 104 | # The dst path must be a full path of the new checkpoint. 105 | parser.add_argument('dst', help='save path') 106 | args = parser.parse_args() 107 | 108 | jax_weights = np.load(args.src) 109 | jax_weights_tensor = {} 110 | for key in jax_weights.files: 111 | value = torch.from_numpy(jax_weights[key]) 112 | jax_weights_tensor[key] = value 113 | if 'L_16-i21k' in args.src: 114 | num_layer = 24 115 | else: 116 | num_layer = 12 117 | torch_weights = vit_jax_to_torch(jax_weights_tensor, num_layer) 118 | mmcv.mkdir_or_exist(osp.dirname(args.dst)) 119 | torch.save(torch_weights, args.dst) 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | -------------------------------------------------------------------------------- /tools/onnx2tensorrt.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import os 4 | import os.path as osp 5 | import warnings 6 | from typing import Iterable, Optional, Union 7 | 8 | import matplotlib.pyplot as plt 9 | import mmcv 10 | import numpy as np 11 | import onnxruntime as ort 12 | import torch 13 | from mmcv.ops import get_onnxruntime_op_path 14 | from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, 15 | save_trt_engine) 16 | 17 | from mmseg.apis.inference import LoadImage 18 | from mmseg.datasets import DATASETS 19 | from mmseg.datasets.pipelines import Compose 20 | 21 | 22 | def get_GiB(x: int): 23 | """return x GiB.""" 24 | return x * (1 << 30) 25 | 26 | 27 | def _prepare_input_img(img_path: str, 28 | test_pipeline: Iterable[dict], 29 | shape: Optional[Iterable] = None, 30 | rescale_shape: Optional[Iterable] = None) -> dict: 31 | # build the data pipeline 32 | if shape is not None: 33 | test_pipeline[1]['img_scale'] = (shape[1], shape[0]) 34 | test_pipeline[1]['transforms'][0]['keep_ratio'] = False 35 | test_pipeline = [LoadImage()] + test_pipeline[1:] 36 | test_pipeline = Compose(test_pipeline) 37 | # prepare data 38 | data = dict(img=img_path) 39 | data = test_pipeline(data) 40 | imgs = data['img'] 41 | img_metas = [i.data for i in data['img_metas']] 42 | 43 | if rescale_shape is not None: 44 | for img_meta in img_metas: 45 | img_meta['ori_shape'] = tuple(rescale_shape) + (3, ) 46 | 47 | mm_inputs = {'imgs': imgs, 'img_metas': img_metas} 48 | 49 | return mm_inputs 50 | 51 | 52 | def _update_input_img(img_list: Iterable, img_meta_list: Iterable): 53 | # update img and its meta list 54 | N = img_list[0].size(0) 55 | img_meta = img_meta_list[0][0] 56 | img_shape = img_meta['img_shape'] 57 | ori_shape = img_meta['ori_shape'] 58 | pad_shape = img_meta['pad_shape'] 59 | new_img_meta_list = [[{ 60 | 'img_shape': 61 | img_shape, 62 | 'ori_shape': 63 | ori_shape, 64 | 'pad_shape': 65 | pad_shape, 66 | 'filename': 67 | img_meta['filename'], 68 | 'scale_factor': 69 | (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2, 70 | 'flip': 71 | False, 72 | } for _ in range(N)]] 73 | 74 | return img_list, new_img_meta_list 75 | 76 | 77 | def show_result_pyplot(img: Union[str, np.ndarray], 78 | result: np.ndarray, 79 | palette: Optional[Iterable] = None, 80 | fig_size: Iterable[int] = (15, 
10), 81 | opacity: float = 0.5, 82 | title: str = '', 83 | block: bool = True): 84 | img = mmcv.imread(img) 85 | img = img.copy() 86 | seg = result[0] 87 | seg = mmcv.imresize(seg, img.shape[:2][::-1]) 88 | palette = np.array(palette) 89 | assert palette.shape[1] == 3 90 | assert len(palette.shape) == 2 91 | assert 0 < opacity <= 1.0 92 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 93 | for label, color in enumerate(palette): 94 | color_seg[seg == label, :] = color 95 | # convert to BGR 96 | color_seg = color_seg[..., ::-1] 97 | 98 | img = img * (1 - opacity) + color_seg * opacity 99 | img = img.astype(np.uint8) 100 | 101 | plt.figure(figsize=fig_size) 102 | plt.imshow(mmcv.bgr2rgb(img)) 103 | plt.title(title) 104 | plt.tight_layout() 105 | plt.show(block=block) 106 | 107 | 108 | def onnx2tensorrt(onnx_file: str, 109 | trt_file: str, 110 | config: dict, 111 | input_config: dict, 112 | fp16: bool = False, 113 | verify: bool = False, 114 | show: bool = False, 115 | dataset: str = 'CityscapesDataset', 116 | workspace_size: int = 1, 117 | verbose: bool = False): 118 | import tensorrt as trt 119 | min_shape = input_config['min_shape'] 120 | max_shape = input_config['max_shape'] 121 | # create trt engine and wrapper 122 | opt_shape_dict = {'input': [min_shape, min_shape, max_shape]} 123 | max_workspace_size = get_GiB(workspace_size) 124 | trt_engine = onnx2trt( 125 | onnx_file, 126 | opt_shape_dict, 127 | log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR, 128 | fp16_mode=fp16, 129 | max_workspace_size=max_workspace_size) 130 | save_dir, _ = osp.split(trt_file) 131 | if save_dir: 132 | os.makedirs(save_dir, exist_ok=True) 133 | save_trt_engine(trt_engine, trt_file) 134 | print(f'Successfully created TensorRT engine: {trt_file}') 135 | 136 | if verify: 137 | inputs = _prepare_input_img( 138 | input_config['input_path'], 139 | config.data.test.pipeline, 140 | shape=min_shape[2:]) 141 | 142 | imgs = inputs['imgs'] 143 | img_metas = inputs['img_metas'] 144 | img_list = [img[None, :] for img in imgs] 145 | img_meta_list = [[img_meta] for img_meta in img_metas] 146 | # update img_meta 147 | img_list, img_meta_list = _update_input_img(img_list, img_meta_list) 148 | 149 | if max_shape[0] > 1: 150 | # concate flip image for batch test 151 | flip_img_list = [_.flip(-1) for _ in img_list] 152 | img_list = [ 153 | torch.cat((ori_img, flip_img), 0) 154 | for ori_img, flip_img in zip(img_list, flip_img_list) 155 | ] 156 | 157 | # Get results from ONNXRuntime 158 | ort_custom_op_path = get_onnxruntime_op_path() 159 | session_options = ort.SessionOptions() 160 | if osp.exists(ort_custom_op_path): 161 | session_options.register_custom_ops_library(ort_custom_op_path) 162 | sess = ort.InferenceSession(onnx_file, session_options) 163 | sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode 164 | onnx_output = sess.run(['output'], 165 | {'input': img_list[0].detach().numpy()})[0][0] 166 | 167 | # Get results from TensorRT 168 | trt_model = TRTWraper(trt_file, ['input'], ['output']) 169 | with torch.no_grad(): 170 | trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()}) 171 | trt_output = trt_outputs['output'][0].cpu().detach().numpy() 172 | 173 | if show: 174 | dataset = DATASETS.get(dataset) 175 | assert dataset is not None 176 | palette = dataset.PALETTE 177 | 178 | show_result_pyplot( 179 | input_config['input_path'], 180 | (onnx_output[0].astype(np.uint8), ), 181 | palette=palette, 182 | title='ONNXRuntime', 183 | block=False) 184 | show_result_pyplot( 
185 | input_config['input_path'], (trt_output[0].astype(np.uint8), ), 186 | palette=palette, 187 | title='TensorRT') 188 | 189 | np.testing.assert_allclose( 190 | onnx_output, trt_output, rtol=1e-03, atol=1e-05) 191 | print('TensorRT and ONNXRuntime outputs are all close.') 192 | 193 | 194 | def parse_args(): 195 | parser = argparse.ArgumentParser( 196 | description='Convert MMSegmentation models from ONNX to TensorRT') 197 | parser.add_argument('config', help='Config file of the model') 198 | parser.add_argument('model', help='Path to the input ONNX model') 199 | parser.add_argument( 200 | '--trt-file', type=str, help='Path to the output TensorRT engine') 201 | parser.add_argument( 202 | '--max-shape', 203 | type=int, 204 | nargs=4, 205 | default=[1, 3, 400, 600], 206 | help='Maximum shape of model input.') 207 | parser.add_argument( 208 | '--min-shape', 209 | type=int, 210 | nargs=4, 211 | default=[1, 3, 400, 600], 212 | help='Minimum shape of model input.') 213 | parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode') 214 | parser.add_argument( 215 | '--workspace-size', 216 | type=int, 217 | default=1, 218 | help='Max workspace size in GiB') 219 | parser.add_argument( 220 | '--input-img', type=str, default='', help='Image for test') 221 | parser.add_argument( 222 | '--show', action='store_true', help='Whether to show output results') 223 | parser.add_argument( 224 | '--dataset', 225 | type=str, 226 | default='CityscapesDataset', 227 | help='Dataset name') 228 | parser.add_argument( 229 | '--verify', 230 | action='store_true', 231 | help='Verify the outputs of ONNXRuntime and TensorRT') 232 | parser.add_argument( 233 | '--verbose', 234 | action='store_true', 235 | help='Whether to print verbose logging messages while creating \ 236 | the TensorRT engine.') 237 | args = parser.parse_args() 238 | return args 239 | 240 | 241 | if __name__ == '__main__': 242 | 243 | assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.' 244 | args = parse_args() 245 | 246 | if not args.input_img: 247 | args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png') 248 | 249 | # check arguments 250 | assert osp.exists(args.config), 'Config {} not found.'.format(args.config) 251 | assert osp.exists(args.model), \ 252 | 'ONNX model {} not found.'.format(args.model) 253 | assert args.workspace_size >= 0, 'Workspace size should not be less than 0.' 254 | assert DATASETS.get(args.dataset) is not None, \ 255 | 'Dataset {} not found.'.format(args.dataset) 256 | for max_value, min_value in zip(args.max_shape, args.min_shape): 257 | assert max_value >= min_value, \ 258 | 'max_shape should not be smaller than min_shape' 259 | 260 | input_config = { 261 | 'min_shape': args.min_shape, 262 | 'max_shape': args.max_shape, 263 | 'input_path': args.input_img 264 | } 265 | 266 | cfg = mmcv.Config.fromfile(args.config) 267 | onnx2tensorrt( 268 | args.model, 269 | args.trt_file, 270 | cfg, 271 | input_config, 272 | fp16=args.fp16, 273 | verify=args.verify, 274 | show=args.show, 275 | dataset=args.dataset, 276 | workspace_size=args.workspace_size, 277 | verbose=args.verbose) 278 | 279 | # The following text style strings are from the colorama package 280 | bright_style, reset_style = '\x1b[1m', '\x1b[0m' 281 | red_text, blue_text = '\x1b[31m', '\x1b[34m' 282 | white_background = '\x1b[107m' 283 | 284 | msg = white_background + bright_style + red_text 285 | msg += 'DeprecationWarning: This tool will be deprecated in future.
' 286 | msg += blue_text + 'Welcome to use the unified model deployment toolbox ' 287 | msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy' 288 | msg += reset_style 289 | warnings.warn(msg) 290 | -------------------------------------------------------------------------------- /tools/print_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import warnings 4 | 5 | from mmcv import Config, DictAction 6 | 7 | from mmseg.apis import init_segmentor 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser(description='Print the whole config') 12 | parser.add_argument('config', help='config file path') 13 | parser.add_argument( 14 | '--graph', action='store_true', help='print the models graph') 15 | parser.add_argument( 16 | '--options', 17 | nargs='+', 18 | action=DictAction, 19 | help="--options is deprecated in favor of --cfg_options' and it will " 20 | 'not be supported in version v0.22.0. Override some settings in the ' 21 | 'used config, the key-value pair in xxx=yyy format will be merged ' 22 | 'into config file. If the value to be overwritten is a list, it ' 23 | 'should be like key="[a,b]" or key=a,b It also allows nested ' 24 | 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' 25 | 'marks are necessary and that no white space is allowed.') 26 | parser.add_argument( 27 | '--cfg-options', 28 | nargs='+', 29 | action=DictAction, 30 | help='override some settings in the used config, the key-value pair ' 31 | 'in xxx=yyy format will be merged into config file. If the value to ' 32 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 33 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 34 | 'Note that the quotation marks are necessary and that no white space ' 35 | 'is allowed.') 36 | args = parser.parse_args() 37 | 38 | if args.options and args.cfg_options: 39 | raise ValueError( 40 | '--options and --cfg-options cannot be both ' 41 | 'specified, --options is deprecated in favor of --cfg-options. ' 42 | '--options will not be supported in version v0.22.0.') 43 | if args.options: 44 | warnings.warn('--options is deprecated in favor of --cfg-options, ' 45 | '--options will not be supported in version v0.22.0.') 46 | args.cfg_options = args.options 47 | 48 | return args 49 | 50 | 51 | def main(): 52 | args = parse_args() 53 | 54 | cfg = Config.fromfile(args.config) 55 | if args.cfg_options is not None: 56 | cfg.merge_from_dict(args.cfg_options) 57 | print(f'Config:\n{cfg.pretty_text}') 58 | # dump config 59 | cfg.dump('example.py') 60 | # dump models graph 61 | if args.graph: 62 | model = init_segmentor(args.config, device='cpu') 63 | print(f'Model graph:\n{str(model)}') 64 | with open('example-graph.txt', 'w') as f: 65 | f.writelines(str(model)) 66 | 67 | 68 | if __name__ == '__main__': 69 | main() 70 | -------------------------------------------------------------------------------- /tools/publish_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
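# Prepares a training checkpoint for release: the optimizer state is dropped
# to shrink the file, and the published file is renamed with the first 8
# characters of its sha256 hash appended. Illustrative invocation (paths are
# placeholders):
#   python tools/publish_model.py work_dirs/<exp>/latest.pth <model_name>.pth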
2 | import argparse 3 | import subprocess 4 | 5 | import torch 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser( 10 | description='Process a checkpoint to be published') 11 | parser.add_argument('in_file', help='input checkpoint filename') 12 | parser.add_argument('out_file', help='output checkpoint filename') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def process_checkpoint(in_file, out_file): 18 | checkpoint = torch.load(in_file, map_location='cpu') 19 | # remove optimizer for smaller file size 20 | if 'optimizer' in checkpoint: 21 | del checkpoint['optimizer'] 22 | # if it is necessary to remove some sensitive data in checkpoint['meta'], 23 | # add the code here. 24 | torch.save(checkpoint, out_file) 25 | sha = subprocess.check_output(['sha256sum', out_file]).decode() 26 | final_file = (out_file[:-4] if out_file.endswith('.pth') else out_file) + '-{}.pth'.format(sha[:8]) 27 | subprocess.Popen(['mv', out_file, final_file]) 28 | 29 | 30 | def main(): 31 | args = parse_args() 32 | process_checkpoint(args.in_file, args.out_file) 33 | 34 | 35 | if __name__ == '__main__': 36 | main() 37 | -------------------------------------------------------------------------------- /tools/pytorch2torchscript.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | import torch._C 8 | import torch.serialization 9 | from mmcv.runner import load_checkpoint 10 | from torch import nn 11 | 12 | from mmseg.models import build_segmentor 13 | 14 | torch.manual_seed(3) 15 | 16 | 17 | def digit_version(version_str): 18 | digit_version = [] 19 | for x in version_str.split('.'): 20 | if x.isdigit(): 21 | digit_version.append(int(x)) 22 | elif x.find('rc') != -1: 23 | patch_version = x.split('rc') 24 | digit_version.append(int(patch_version[0]) - 1) 25 | digit_version.append(int(patch_version[1])) 26 | return digit_version 27 | 28 | 29 | def check_torch_version(): 30 | torch_minimum_version = '1.8.0' 31 | torch_version = digit_version(torch.__version__) 32 | 33 | assert (torch_version >= digit_version(torch_minimum_version)), \ 34 | f'Torch=={torch.__version__} is not supported for converting to ' \ 35 | f'torchscript. Please install pytorch>={torch_minimum_version}.' 36 | 37 | 38 | def _convert_batchnorm(module): 39 | module_output = module 40 | if isinstance(module, torch.nn.SyncBatchNorm): 41 | module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, 42 | module.momentum, module.affine, 43 | module.track_running_stats) 44 | if module.affine: 45 | module_output.weight.data = module.weight.data.clone().detach() 46 | module_output.bias.data = module.bias.data.clone().detach() 47 | # keep requires_grad unchanged 48 | module_output.weight.requires_grad = module.weight.requires_grad 49 | module_output.bias.requires_grad = module.bias.requires_grad 50 | module_output.running_mean = module.running_mean 51 | module_output.running_var = module.running_var 52 | module_output.num_batches_tracked = module.num_batches_tracked 53 | for name, child in module.named_children(): 54 | module_output.add_module(name, _convert_batchnorm(child)) 55 | del module 56 | return module_output 57 | 58 | 59 | def _demo_mm_inputs(input_shape, num_classes): 60 | """Create a superset of inputs needed to run test or train batches.
61 | 62 | Args: 63 | input_shape (tuple): 64 | input batch dimensions 65 | num_classes (int): 66 | number of semantic classes 67 | """ 68 | (N, C, H, W) = input_shape 69 | rng = np.random.RandomState(0) 70 | imgs = rng.rand(*input_shape) 71 | segs = rng.randint( 72 | low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8) 73 | img_metas = [{ 74 | 'img_shape': (H, W, C), 75 | 'ori_shape': (H, W, C), 76 | 'pad_shape': (H, W, C), 77 | 'filename': '.png', 78 | 'scale_factor': 1.0, 79 | 'flip': False, 80 | } for _ in range(N)] 81 | mm_inputs = { 82 | 'imgs': torch.FloatTensor(imgs).requires_grad_(True), 83 | 'img_metas': img_metas, 84 | 'gt_semantic_seg': torch.LongTensor(segs) 85 | } 86 | return mm_inputs 87 | 88 | 89 | def pytorch2libtorch(model, 90 | input_shape, 91 | show=False, 92 | output_file='tmp.pt', 93 | verify=False): 94 | """Export a PyTorch model to a TorchScript model and verify that the 95 | outputs are the same between PyTorch and TorchScript. 96 | 97 | Args: 98 | model (nn.Module): PyTorch model we want to export. 99 | input_shape (tuple): Use this input shape to construct 100 | the corresponding dummy input and execute the model. 101 | show (bool): Whether to print the computation graph. Default: False. 102 | output_file (string): The path to where we store the 103 | output TorchScript model. Default: `tmp.pt`. 104 | verify (bool): Whether to compare the outputs between 105 | PyTorch and TorchScript. Default: False. 106 | """ 107 | if isinstance(model.decode_head, nn.ModuleList): 108 | num_classes = model.decode_head[-1].num_classes 109 | else: 110 | num_classes = model.decode_head.num_classes 111 | 112 | mm_inputs = _demo_mm_inputs(input_shape, num_classes) 113 | 114 | imgs = mm_inputs.pop('imgs') 115 | 116 | # replace the original forward with forward_dummy 117 | model.forward = model.forward_dummy 118 | model.eval() 119 | traced_model = torch.jit.trace( 120 | model, 121 | example_inputs=imgs, 122 | check_trace=verify, 123 | ) 124 | 125 | if show: 126 | print(traced_model.graph) 127 | 128 | traced_model.save(output_file) 129 | print('Successfully exported TorchScript model: {}'.format(output_file)) 130 | 131 | 132 | def parse_args(): 133 | parser = argparse.ArgumentParser( 134 | description='Convert MMSeg to TorchScript') 135 | parser.add_argument('config', help='test config file path') 136 | parser.add_argument('--checkpoint', help='checkpoint file', default=None) 137 | parser.add_argument( 138 | '--show', action='store_true', help='show TorchScript graph') 139 | parser.add_argument( 140 | '--verify', action='store_true', help='verify the TorchScript model') 141 | parser.add_argument('--output-file', type=str, default='tmp.pt') 142 | parser.add_argument( 143 | '--shape', 144 | type=int, 145 | nargs='+', 146 | default=[512, 512], 147 | help='input image size (height, width)') 148 | args = parser.parse_args() 149 | return args 150 | 151 | 152 | if __name__ == '__main__': 153 | args = parse_args() 154 | check_torch_version() 155 | 156 | if len(args.shape) == 1: 157 | input_shape = (1, 3, args.shape[0], args.shape[0]) 158 | elif len(args.shape) == 2: 159 | input_shape = ( 160 | 1, 161 | 3, 162 | ) + tuple(args.shape) 163 | else: 164 | raise ValueError('invalid input shape') 165 | 166 | cfg = mmcv.Config.fromfile(args.config) 167 | cfg.model.pretrained = None 168 | 169 | # build the model and load checkpoint 170 | cfg.model.train_cfg = None 171 | segmentor = build_segmentor( 172 | cfg.model, train_cfg=None, test_cfg=cfg.get('test_cfg')) 173 | # convert SyncBN to BN 174 | segmentor =
_convert_batchnorm(segmentor) 175 | 176 | if args.checkpoint: 177 | load_checkpoint(segmentor, args.checkpoint, map_location='cpu') 178 | 179 | # convert the PyTorch model to LibTorch model 180 | pytorch2libtorch( 181 | segmentor, 182 | input_shape, 183 | show=args.show, 184 | output_file=args.output_file, 185 | verify=args.verify) 186 | -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-4} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 16 | srun -p ${PARTITION} \ 17 | --job-name=${JOB_NAME} \ 18 | --gres=gpu:${GPUS_PER_NODE} \ 19 | --ntasks=${GPUS} \ 20 | --ntasks-per-node=${GPUS_PER_NODE} \ 21 | --cpus-per-task=${CPUS_PER_TASK} \ 22 | --kill-on-bad-exit=1 \ 23 | ${SRUN_ARGS} \ 24 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" ${PY_ARGS} 25 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | GPUS=${GPUS:-4} 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-4} 10 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 11 | SRUN_ARGS=${SRUN_ARGS:-""} 12 | PY_ARGS=${@:4} 13 | 14 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | ${SRUN_ARGS} \ 23 | python -u tools/train.py ${CONFIG} --launcher="slurm" ${PY_ARGS} 24 | -------------------------------------------------------------------------------- /tools/torchserve/mmseg2torchserve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from argparse import ArgumentParser, Namespace 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | 6 | import mmcv 7 | 8 | try: 9 | from model_archiver.model_packaging import package_model 10 | from model_archiver.model_packaging_utils import ModelExportUtils 11 | except ImportError: 12 | package_model = None 13 | 14 | 15 | def mmseg2torchserve( 16 | config_file: str, 17 | checkpoint_file: str, 18 | output_folder: str, 19 | model_name: str, 20 | model_version: str = '1.0', 21 | force: bool = False, 22 | ): 23 | """Converts mmsegmentation model (config + checkpoint) to TorchServe 24 | `.mar`. 25 | 26 | Args: 27 | config_file: 28 | In MMSegmentation config format. 29 | The contents vary for each task repository. 30 | checkpoint_file: 31 | In MMSegmentation checkpoint format. 32 | The contents vary for each task repository. 33 | output_folder: 34 | Folder where `{model_name}.mar` will be created. 35 | The file created will be in TorchServe archive format. 36 | model_name: 37 | If not None, used for naming the `{model_name}.mar` file 38 | that will be created under `output_folder`. 39 | If None, `{Path(checkpoint_file).stem}` will be used. 40 | model_version: 41 | Model's version. 
42 | force: 43 | If True, if there is an existing `{model_name}.mar` 44 | file under `output_folder` it will be overwritten. 45 | """ 46 | mmcv.mkdir_or_exist(output_folder) 47 | 48 | config = mmcv.Config.fromfile(config_file) 49 | 50 | with TemporaryDirectory() as tmpdir: 51 | config.dump(f'{tmpdir}/config.py') 52 | 53 | args = Namespace( 54 | **{ 55 | 'model_file': f'{tmpdir}/config.py', 56 | 'serialized_file': checkpoint_file, 57 | 'handler': f'{Path(__file__).parent}/mmseg_handler.py', 58 | 'model_name': model_name or Path(checkpoint_file).stem, 59 | 'version': model_version, 60 | 'export_path': output_folder, 61 | 'force': force, 62 | 'requirements_file': None, 63 | 'extra_files': None, 64 | 'runtime': 'python', 65 | 'archive_format': 'default' 66 | }) 67 | manifest = ModelExportUtils.generate_manifest_json(args) 68 | package_model(args, manifest) 69 | 70 | 71 | def parse_args(): 72 | parser = ArgumentParser( 73 | description='Convert mmseg models to TorchServe `.mar` format.') 74 | parser.add_argument('config', type=str, help='config file path') 75 | parser.add_argument('checkpoint', type=str, help='checkpoint file path') 76 | parser.add_argument( 77 | '--output-folder', 78 | type=str, 79 | required=True, 80 | help='Folder where `{model_name}.mar` will be created.') 81 | parser.add_argument( 82 | '--model-name', 83 | type=str, 84 | default=None, 85 | help='If not None, used for naming the `{model_name}.mar`' 86 | 'file that will be created under `output_folder`.' 87 | 'If None, `{Path(checkpoint_file).stem}` will be used.') 88 | parser.add_argument( 89 | '--model-version', 90 | type=str, 91 | default='1.0', 92 | help='Number used for versioning.') 93 | parser.add_argument( 94 | '-f', 95 | '--force', 96 | action='store_true', 97 | help='overwrite the existing `{model_name}.mar`') 98 | args = parser.parse_args() 99 | 100 | return args 101 | 102 | 103 | if __name__ == '__main__': 104 | args = parse_args() 105 | 106 | if package_model is None: 107 | raise ImportError('`torch-model-archiver` is required.' 108 | 'Try: pip install torch-model-archiver') 109 | 110 | mmseg2torchserve(args.config, args.checkpoint, args.output_folder, 111 | args.model_name, args.model_version, args.force) 112 | -------------------------------------------------------------------------------- /tools/torchserve/mmseg_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import base64 3 | import os 4 | 5 | import cv2 6 | import mmcv 7 | import torch 8 | from mmcv.cnn.utils.sync_bn import revert_sync_batchnorm 9 | from ts.torch_handler.base_handler import BaseHandler 10 | 11 | from mmseg.apis import inference_segmentor, init_segmentor 12 | 13 | 14 | class MMsegHandler(BaseHandler): 15 | 16 | def initialize(self, context): 17 | properties = context.system_properties 18 | self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu' 19 | self.device = torch.device(self.map_location + ':' + 20 | str(properties.get('gpu_id')) if torch.cuda. 
21 | is_available() else self.map_location) 22 | self.manifest = context.manifest 23 | 24 | model_dir = properties.get('model_dir') 25 | serialized_file = self.manifest['model']['serializedFile'] 26 | checkpoint = os.path.join(model_dir, serialized_file) 27 | self.config_file = os.path.join(model_dir, 'config.py') 28 | 29 | self.model = init_segmentor(self.config_file, checkpoint, self.device) 30 | self.model = revert_sync_batchnorm(self.model) 31 | self.initialized = True 32 | 33 | def preprocess(self, data): 34 | images = [] 35 | 36 | for row in data: 37 | image = row.get('data') or row.get('body') 38 | if isinstance(image, str): 39 | image = base64.b64decode(image) 40 | image = mmcv.imfrombytes(image) 41 | images.append(image) 42 | 43 | return images 44 | 45 | def inference(self, data, *args, **kwargs): 46 | results = [inference_segmentor(self.model, img) for img in data] 47 | return results 48 | 49 | def postprocess(self, data): 50 | output = [] 51 | 52 | for image_result in data: 53 | _, buffer = cv2.imencode('.png', image_result[0].astype('uint8')) 54 | content = buffer.tobytes() 55 | output.append(content) 56 | return output 57 | -------------------------------------------------------------------------------- /tools/torchserve/test_torchserve.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from argparse import ArgumentParser 3 | from io import BytesIO 4 | 5 | import matplotlib.pyplot as plt 6 | import mmcv 7 | import requests 8 | 9 | from mmseg.apis import inference_segmentor, init_segmentor 10 | 11 | 12 | def parse_args(): 13 | parser = ArgumentParser( 14 | description='Compare result of torchserve and pytorch,' 15 | 'and visualize them.') 16 | parser.add_argument('img', help='Image file') 17 | parser.add_argument('config', help='Config file') 18 | parser.add_argument('checkpoint', help='Checkpoint file') 19 | parser.add_argument('model_name', help='The model name in the server') 20 | parser.add_argument( 21 | '--inference-addr', 22 | default='127.0.0.1:8080', 23 | help='Address and port of the inference server') 24 | parser.add_argument( 25 | '--result-image', 26 | type=str, 27 | default=None, 28 | help='save server output in result-image') 29 | parser.add_argument( 30 | '--device', default='cuda:0', help='Device used for inference') 31 | 32 | args = parser.parse_args() 33 | return args 34 | 35 | 36 | def main(args): 37 | url = 'http://' + args.inference_addr + '/predictions/' + args.model_name 38 | with open(args.img, 'rb') as image: 39 | tmp_res = requests.post(url, image) 40 | content = tmp_res.content 41 | if args.result_image: 42 | with open(args.result_image, 'wb') as out_image: 43 | out_image.write(content) 44 | plt.imshow(mmcv.imread(args.result_image, 'grayscale')) 45 | plt.show() 46 | else: 47 | plt.imshow(plt.imread(BytesIO(content))) 48 | plt.show() 49 | model = init_segmentor(args.config, args.checkpoint, args.device) 50 | image = mmcv.imread(args.img) 51 | result = inference_segmentor(model, image) 52 | plt.imshow(result[0]) 53 | plt.show() 54 | 55 | 56 | if __name__ == '__main__': 57 | args = parse_args() 58 | main(args) 59 | -------------------------------------------------------------------------------- /tools/van.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from timm.models.layers import DropPath 5 | from mmcv.cnn.utils.weight_init import (constant_init, normal_init, 6 | 
trunc_normal_init) 7 | from torch.nn.modules.utils import _pair as to_2tuple 8 | from mmseg.models.builder import BACKBONES 9 | 10 | from mmcv.cnn import build_norm_layer 11 | from mmcv.runner import BaseModule 12 | import math 13 | import warnings 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 22 | self.dwconv = DWConv(hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 25 | self.drop = nn.Dropout(drop) 26 | self.linear = linear 27 | if self.linear: 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | def forward(self, x): 31 | x = self.fc1(x) 32 | if self.linear: 33 | x = self.relu(x) 34 | x = self.dwconv(x) 35 | x = self.act(x) 36 | x = self.drop(x) 37 | x = self.fc2(x) 38 | x = self.drop(x) 39 | return x 40 | 41 | 42 | class AttentionModule(nn.Module): 43 | def __init__(self, dim): 44 | super().__init__() 45 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 46 | self.conv_spatial = nn.Conv2d( 47 | dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) 48 | self.conv1 = nn.Conv2d(dim, dim, 1) 49 | 50 | def forward(self, x): 51 | u = x.clone() 52 | attn = self.conv0(x) 53 | attn = self.conv_spatial(attn) 54 | attn = self.conv1(attn) 55 | return u * attn 56 | 57 | 58 | class SpatialAttention(nn.Module): 59 | def __init__(self, d_model): 60 | super().__init__() 61 | self.d_model = d_model 62 | self.proj_1 = nn.Conv2d(d_model, d_model, 1) 63 | self.activation = nn.GELU() 64 | self.spatial_gating_unit = AttentionModule(d_model) 65 | self.proj_2 = nn.Conv2d(d_model, d_model, 1) 66 | 67 | def forward(self, x): 68 | shorcut = x.clone() 69 | x = self.proj_1(x) 70 | x = self.activation(x) 71 | x = self.spatial_gating_unit(x) 72 | x = self.proj_2(x) 73 | x = x + shorcut 74 | return x 75 | 76 | 77 | class Block(nn.Module): 78 | 79 | def __init__(self, 80 | dim, 81 | mlp_ratio=4., 82 | drop=0., 83 | drop_path=0., 84 | act_layer=nn.GELU, 85 | linear=False, 86 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 87 | super().__init__() 88 | self.norm1 = build_norm_layer(norm_cfg, dim)[1] 89 | self.attn = SpatialAttention(dim) 90 | self.drop_path = DropPath( 91 | drop_path) if drop_path > 0. 
else nn.Identity() 92 | 93 | self.norm2 = build_norm_layer(norm_cfg, dim)[1] 94 | mlp_hidden_dim = int(dim * mlp_ratio) 95 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 96 | act_layer=act_layer, drop=drop, linear=linear) 97 | layer_scale_init_value = 1e-2 98 | self.layer_scale_1 = nn.Parameter( 99 | layer_scale_init_value * torch.ones((dim)), requires_grad=True) 100 | self.layer_scale_2 = nn.Parameter( 101 | layer_scale_init_value * torch.ones((dim)), requires_grad=True) 102 | 103 | def forward(self, x, H, W): 104 | B, N, C = x.shape 105 | x = x.permute(0, 2, 1).view(B, C, H, W) 106 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) 107 | * self.attn(self.norm1(x))) 108 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) 109 | * self.mlp(self.norm2(x))) 110 | x = x.view(B, C, N).permute(0, 2, 1) 111 | return x 112 | 113 | 114 | class OverlapPatchEmbed(nn.Module): 115 | """ Image to Patch Embedding 116 | """ 117 | 118 | def __init__(self, 119 | patch_size=7, 120 | stride=4, 121 | in_chans=3, 122 | embed_dim=768, 123 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 124 | super().__init__() 125 | patch_size = to_2tuple(patch_size) 126 | 127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 128 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 129 | self.norm = build_norm_layer(norm_cfg, embed_dim)[1] 130 | 131 | def forward(self, x): 132 | x = self.proj(x) 133 | _, _, H, W = x.shape 134 | x = self.norm(x) 135 | 136 | x = x.flatten(2).transpose(1, 2) 137 | 138 | return x, H, W 139 | 140 | 141 | @BACKBONES.register_module() 142 | class VAN(BaseModule): 143 | def __init__(self, 144 | in_chans=3, 145 | embed_dims=[64, 128, 256, 512], 146 | mlp_ratios=[8, 8, 4, 4], 147 | drop_rate=0., 148 | drop_path_rate=0., 149 | depths=[3, 4, 6, 3], 150 | num_stages=4, 151 | linear=False, 152 | pretrained=None, 153 | init_cfg=None, 154 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 155 | super(VAN, self).__init__(init_cfg=init_cfg) 156 | 157 | assert not (init_cfg and pretrained), \ 158 | 'init_cfg and pretrained cannot be set at the same time' 159 | if isinstance(pretrained, str): 160 | warnings.warn('DeprecationWarning: pretrained is deprecated, ' 161 | 'please use "init_cfg" instead') 162 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 163 | elif pretrained is not None: 164 | raise TypeError('pretrained must be a str or None') 165 | 166 | self.depths = depths 167 | self.num_stages = num_stages 168 | self.linear = linear 169 | 170 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 171 | sum(depths))] # stochastic depth decay rule 172 | cur = 0 173 | 174 | for i in range(num_stages): 175 | patch_embed = OverlapPatchEmbed(patch_size=7 if i == 0 else 3, 176 | stride=4 if i == 0 else 2, 177 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 178 | embed_dim=embed_dims[i]) 179 | 180 | block = nn.ModuleList([Block(dim=embed_dims[i], 181 | mlp_ratio=mlp_ratios[i], 182 | drop=drop_rate, 183 | drop_path=dpr[cur + j], 184 | linear=linear, 185 | norm_cfg=norm_cfg) 186 | for j in range(depths[i])]) 187 | norm = nn.LayerNorm(embed_dims[i]) 188 | cur += depths[i] 189 | 190 | setattr(self, f"patch_embed{i + 1}", patch_embed) 191 | setattr(self, f"block{i + 1}", block) 192 | setattr(self, f"norm{i + 1}", norm) 193 | 194 | def init_weights(self): 195 | print('init cfg', self.init_cfg) 196 | if self.init_cfg is None: 197 | for m in self.modules(): 198 | if isinstance(m, nn.Linear): 199 | 
trunc_normal_init(m, std=.02, bias=0.) 200 | elif isinstance(m, nn.LayerNorm): 201 | constant_init(m, val=1.0, bias=0.) 202 | elif isinstance(m, nn.Conv2d): 203 | fan_out = m.kernel_size[0] * m.kernel_size[ 204 | 1] * m.out_channels 205 | fan_out //= m.groups 206 | normal_init( 207 | m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) 208 | else: 209 | super(VAN, self).init_weights() 210 | 211 | def forward(self, x): 212 | B = x.shape[0] 213 | outs = [] 214 | 215 | for i in range(self.num_stages): 216 | patch_embed = getattr(self, f"patch_embed{i + 1}") 217 | block = getattr(self, f"block{i + 1}") 218 | norm = getattr(self, f"norm{i + 1}") 219 | x, H, W = patch_embed(x) 220 | for blk in block: 221 | x = blk(x, H, W) 222 | x = norm(x) 223 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 224 | outs.append(x) 225 | 226 | return outs 227 | 228 | 229 | class DWConv(nn.Module): 230 | def __init__(self, dim=768): 231 | super(DWConv, self).__init__() 232 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 233 | 234 | def forward(self, x): 235 | x = self.dwconv(x) 236 | return x 237 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import copy 4 | import os 5 | import os.path as osp 6 | import time 7 | import warnings 8 | 9 | import mmcv 10 | import torch 11 | import torch.distributed as dist 12 | from mmcv.cnn.utils import revert_sync_batchnorm 13 | from mmcv.runner import get_dist_info, init_dist 14 | from mmcv.utils import Config, DictAction, get_git_hash 15 | 16 | from mmseg import __version__ 17 | from mmseg.apis import init_random_seed, set_random_seed, train_segmentor 18 | from mmseg.datasets import build_dataset 19 | from mmseg.models import build_segmentor 20 | from mmseg.utils import (collect_env, get_device, get_root_logger, 21 | setup_multi_processes) 22 | import van 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description='Train a segmentor') 27 | parser.add_argument('config', help='train config file path') 28 | parser.add_argument('--work-dir', help='the dir to save logs and models') 29 | parser.add_argument( 30 | '--load-from', help='the checkpoint file to load weights from') 31 | parser.add_argument( 32 | '--resume-from', help='the checkpoint file to resume from') 33 | parser.add_argument( 34 | '--no-validate', 35 | action='store_true', 36 | help='whether not to evaluate the checkpoint during training') 37 | group_gpus = parser.add_mutually_exclusive_group() 38 | group_gpus.add_argument( 39 | '--gpus', 40 | type=int, 41 | help='(Deprecated, please use --gpu-id) number of gpus to use ' 42 | '(only applicable to non-distributed training)') 43 | group_gpus.add_argument( 44 | '--gpu-ids', 45 | type=int, 46 | nargs='+', 47 | help='(Deprecated, please use --gpu-id) ids of gpus to use ' 48 | '(only applicable to non-distributed training)') 49 | group_gpus.add_argument( 50 | '--gpu-id', 51 | type=int, 52 | default=0, 53 | help='id of gpu to use ' 54 | '(only applicable to non-distributed training)') 55 | parser.add_argument('--seed', type=int, default=None, help='random seed') 56 | parser.add_argument( 57 | '--diff_seed', 58 | action='store_true', 59 | help='Whether or not set different seeds for different ranks') 60 | parser.add_argument( 61 | '--deterministic', 62 | action='store_true', 63 | help='whether to set deterministic options for 
CUDNN backend.') 64 | parser.add_argument( 65 | '--options', 66 | nargs='+', 67 | action=DictAction, 68 | help="--options is deprecated in favor of --cfg_options' and it will " 69 | 'not be supported in version v0.22.0. Override some settings in the ' 70 | 'used config, the key-value pair in xxx=yyy format will be merged ' 71 | 'into config file. If the value to be overwritten is a list, it ' 72 | 'should be like key="[a,b]" or key=a,b It also allows nested ' 73 | 'list/tuple values, e.g. key="[(a,b),(c,d)]" Note that the quotation ' 74 | 'marks are necessary and that no white space is allowed.') 75 | parser.add_argument( 76 | '--cfg-options', 77 | nargs='+', 78 | action=DictAction, 79 | help='override some settings in the used config, the key-value pair ' 80 | 'in xxx=yyy format will be merged into config file. If the value to ' 81 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 82 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 83 | 'Note that the quotation marks are necessary and that no white space ' 84 | 'is allowed.') 85 | parser.add_argument( 86 | '--launcher', 87 | choices=['none', 'pytorch', 'slurm', 'mpi'], 88 | default='none', 89 | help='job launcher') 90 | parser.add_argument('--local_rank', type=int, default=0) 91 | parser.add_argument( 92 | '--auto-resume', 93 | action='store_true', 94 | help='resume from the latest checkpoint automatically.') 95 | args = parser.parse_args() 96 | if 'LOCAL_RANK' not in os.environ: 97 | os.environ['LOCAL_RANK'] = str(args.local_rank) 98 | 99 | if args.options and args.cfg_options: 100 | raise ValueError( 101 | '--options and --cfg-options cannot be both ' 102 | 'specified, --options is deprecated in favor of --cfg-options. ' 103 | '--options will not be supported in version v0.22.0.') 104 | if args.options: 105 | warnings.warn('--options is deprecated in favor of --cfg-options. ' 106 | '--options will not be supported in version v0.22.0.') 107 | args.cfg_options = args.options 108 | 109 | return args 110 | 111 | 112 | def main(): 113 | args = parse_args() 114 | 115 | cfg = Config.fromfile(args.config) 116 | if args.cfg_options is not None: 117 | cfg.merge_from_dict(args.cfg_options) 118 | 119 | # set cudnn_benchmark 120 | if cfg.get('cudnn_benchmark', False): 121 | torch.backends.cudnn.benchmark = True 122 | 123 | # work_dir is determined in this priority: CLI > segment in file > filename 124 | if args.work_dir is not None: 125 | # update configs according to CLI args if args.work_dir is not None 126 | cfg.work_dir = args.work_dir 127 | elif cfg.get('work_dir', None) is None: 128 | # use config filename as default work_dir if cfg.work_dir is None 129 | cfg.work_dir = osp.join('./work_dirs', 130 | osp.splitext(osp.basename(args.config))[0]) 131 | if args.load_from is not None: 132 | cfg.load_from = args.load_from 133 | if args.resume_from is not None: 134 | cfg.resume_from = args.resume_from 135 | if args.gpus is not None: 136 | cfg.gpu_ids = range(1) 137 | warnings.warn('`--gpus` is deprecated because we only support ' 138 | 'single GPU mode in non-distributed training. ' 139 | 'Use `gpus=1` now.') 140 | if args.gpu_ids is not None: 141 | cfg.gpu_ids = args.gpu_ids[0:1] 142 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' 143 | 'Because we only support single GPU mode in ' 144 | 'non-distributed training. 
Use the first GPU ' 145 | 'in `gpu_ids` now.') 146 | if args.gpus is None and args.gpu_ids is None: 147 | cfg.gpu_ids = [args.gpu_id] 148 | 149 | cfg.auto_resume = args.auto_resume 150 | 151 | # init distributed env first, since logger depends on the dist info. 152 | if args.launcher == 'none': 153 | distributed = False 154 | else: 155 | distributed = True 156 | init_dist(args.launcher, **cfg.dist_params) 157 | # gpu_ids is used to calculate iter when resuming checkpoint 158 | _, world_size = get_dist_info() 159 | cfg.gpu_ids = range(world_size) 160 | 161 | # create work_dir 162 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 163 | # dump config 164 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 165 | # init the logger before other steps 166 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 167 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 168 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 169 | 170 | # set multi-process settings 171 | setup_multi_processes(cfg) 172 | 173 | # init the meta dict to record some important information such as 174 | # environment info and seed, which will be logged 175 | meta = dict() 176 | # log env info 177 | env_info_dict = collect_env() 178 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 179 | dash_line = '-' * 60 + '\n' 180 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 181 | dash_line) 182 | meta['env_info'] = env_info 183 | 184 | # log some basic info 185 | logger.info(f'Distributed training: {distributed}') 186 | logger.info(f'Config:\n{cfg.pretty_text}') 187 | 188 | # set random seeds 189 | cfg.device = get_device() 190 | seed = init_random_seed(args.seed, device=cfg.device) 191 | seed = seed + dist.get_rank() if args.diff_seed else seed 192 | logger.info(f'Set random seed to {seed}, ' 193 | f'deterministic: {args.deterministic}') 194 | set_random_seed(seed, deterministic=args.deterministic) 195 | cfg.seed = seed 196 | meta['seed'] = seed 197 | meta['exp_name'] = osp.basename(args.config) 198 | 199 | model = build_segmentor( 200 | cfg.model, 201 | train_cfg=cfg.get('train_cfg'), 202 | test_cfg=cfg.get('test_cfg')) 203 | model.init_weights() 204 | 205 | # SyncBN is not support for DP 206 | if not distributed: 207 | warnings.warn( 208 | 'SyncBN is only supported with DDP. To be compatible with DP, ' 209 | 'we convert SyncBN to BN. 
Please use dist_train.sh which can ' 210 | 'avoid this error.') 211 | model = revert_sync_batchnorm(model) 212 | 213 | logger.info(model) 214 | 215 | datasets = [build_dataset(cfg.data.train)] 216 | if len(cfg.workflow) == 2: 217 | val_dataset = copy.deepcopy(cfg.data.val) 218 | val_dataset.pipeline = cfg.data.train.pipeline 219 | datasets.append(build_dataset(val_dataset)) 220 | if cfg.checkpoint_config is not None: 221 | # save mmseg version, config file content and class names in 222 | # checkpoints as meta data 223 | cfg.checkpoint_config.meta = dict( 224 | mmseg_version=f'{__version__}+{get_git_hash()[:7]}', 225 | config=cfg.pretty_text, 226 | CLASSES=datasets[0].CLASSES, 227 | PALETTE=datasets[0].PALETTE) 228 | # add an attribute for visualization convenience 229 | model.CLASSES = datasets[0].CLASSES 230 | # passing checkpoint meta for saving best checkpoint 231 | meta.update(cfg.checkpoint_config.meta) 232 | train_segmentor( 233 | model, 234 | datasets, 235 | cfg, 236 | distributed=distributed, 237 | validate=(not args.no_validate), 238 | timestamp=timestamp, 239 | meta=meta) 240 | 241 | 242 | if __name__ == '__main__': 243 | main() 244 | -------------------------------------------------------------------------------- /van.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from timm.models.layers import DropPath 5 | from mmcv.cnn.utils.weight_init import (constant_init, normal_init, 6 | trunc_normal_init) 7 | from torch.nn.modules.utils import _pair as to_2tuple 8 | from mmseg.models.builder import BACKBONES 9 | 10 | from mmcv.cnn import build_norm_layer 11 | from mmcv.runner import BaseModule 12 | import math 13 | import warnings 14 | 15 | 16 | class Mlp(nn.Module): 17 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., linear=False): 18 | super().__init__() 19 | out_features = out_features or in_features 20 | hidden_features = hidden_features or in_features 21 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1) 22 | self.dwconv = DWConv(hidden_features) 23 | self.act = act_layer() 24 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1) 25 | self.drop = nn.Dropout(drop) 26 | self.linear = linear 27 | if self.linear: 28 | self.relu = nn.ReLU(inplace=True) 29 | 30 | def forward(self, x): 31 | x = self.fc1(x) 32 | if self.linear: 33 | x = self.relu(x) 34 | x = self.dwconv(x) 35 | x = self.act(x) 36 | x = self.drop(x) 37 | x = self.fc2(x) 38 | x = self.drop(x) 39 | return x 40 | 41 | 42 | class AttentionModule(nn.Module): 43 | def __init__(self, dim): 44 | super().__init__() 45 | self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 46 | self.conv_spatial = nn.Conv2d( 47 | dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3) 48 | self.conv1 = nn.Conv2d(dim, dim, 1) 49 | 50 | def forward(self, x): 51 | u = x.clone() 52 | attn = self.conv0(x) 53 | attn = self.conv_spatial(attn) 54 | attn = self.conv1(attn) 55 | return u * attn 56 | 57 | 58 | class SpatialAttention(nn.Module): 59 | def __init__(self, d_model): 60 | super().__init__() 61 | self.d_model = d_model 62 | self.proj_1 = nn.Conv2d(d_model, d_model, 1) 63 | self.activation = nn.GELU() 64 | self.spatial_gating_unit = AttentionModule(d_model) 65 | self.proj_2 = nn.Conv2d(d_model, d_model, 1) 66 | 67 | def forward(self, x): 68 | shorcut = x.clone() 69 | x = self.proj_1(x) 70 | x = self.activation(x) 71 | x = self.spatial_gating_unit(x) 72 | x = self.proj_2(x) 73 | x = x + shorcut 74 
| return x 75 | 76 | 77 | class Block(nn.Module): 78 | 79 | def __init__(self, 80 | dim, 81 | mlp_ratio=4., 82 | drop=0., 83 | drop_path=0., 84 | act_layer=nn.GELU, 85 | linear=False, 86 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 87 | super().__init__() 88 | self.norm1 = build_norm_layer(norm_cfg, dim)[1] 89 | self.attn = SpatialAttention(dim) 90 | self.drop_path = DropPath( 91 | drop_path) if drop_path > 0. else nn.Identity() 92 | 93 | self.norm2 = build_norm_layer(norm_cfg, dim)[1] 94 | mlp_hidden_dim = int(dim * mlp_ratio) 95 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, 96 | act_layer=act_layer, drop=drop, linear=linear) 97 | layer_scale_init_value = 1e-2 98 | self.layer_scale_1 = nn.Parameter( 99 | layer_scale_init_value * torch.ones((dim)), requires_grad=True) 100 | self.layer_scale_2 = nn.Parameter( 101 | layer_scale_init_value * torch.ones((dim)), requires_grad=True) 102 | 103 | def forward(self, x, H, W): 104 | B, N, C = x.shape 105 | x = x.permute(0, 2, 1).view(B, C, H, W) 106 | x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) 107 | * self.attn(self.norm1(x))) 108 | x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) 109 | * self.mlp(self.norm2(x))) 110 | x = x.view(B, C, N).permute(0, 2, 1) 111 | return x 112 | 113 | 114 | class OverlapPatchEmbed(nn.Module): 115 | """ Image to Patch Embedding 116 | """ 117 | 118 | def __init__(self, 119 | patch_size=7, 120 | stride=4, 121 | in_chans=3, 122 | embed_dim=768, 123 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 124 | super().__init__() 125 | patch_size = to_2tuple(patch_size) 126 | 127 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, 128 | padding=(patch_size[0] // 2, patch_size[1] // 2)) 129 | self.norm = build_norm_layer(norm_cfg, embed_dim)[1] 130 | 131 | def forward(self, x): 132 | x = self.proj(x) 133 | _, _, H, W = x.shape 134 | x = self.norm(x) 135 | 136 | x = x.flatten(2).transpose(1, 2) 137 | 138 | return x, H, W 139 | 140 | 141 | @BACKBONES.register_module() 142 | class VAN(BaseModule): 143 | def __init__(self, 144 | in_chans=3, 145 | embed_dims=[64, 128, 256, 512], 146 | mlp_ratios=[8, 8, 4, 4], 147 | drop_rate=0., 148 | drop_path_rate=0., 149 | depths=[3, 4, 6, 3], 150 | num_stages=4, 151 | linear=False, 152 | pretrained=None, 153 | init_cfg=None, 154 | norm_cfg=dict(type='SyncBN', requires_grad=True)): 155 | super(VAN, self).__init__(init_cfg=init_cfg) 156 | 157 | assert not (init_cfg and pretrained), \ 158 | 'init_cfg and pretrained cannot be set at the same time' 159 | if isinstance(pretrained, str): 160 | warnings.warn('DeprecationWarning: pretrained is deprecated, ' 161 | 'please use "init_cfg" instead') 162 | self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) 163 | elif pretrained is not None: 164 | raise TypeError('pretrained must be a str or None') 165 | 166 | self.depths = depths 167 | self.num_stages = num_stages 168 | self.linear = linear 169 | 170 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 171 | sum(depths))] # stochastic depth decay rule 172 | cur = 0 173 | 174 | for i in range(num_stages): 175 | patch_embed = OverlapPatchEmbed(patch_size=7 if i == 0 else 3, 176 | stride=4 if i == 0 else 2, 177 | in_chans=in_chans if i == 0 else embed_dims[i - 1], 178 | embed_dim=embed_dims[i]) 179 | 180 | block = nn.ModuleList([Block(dim=embed_dims[i], 181 | mlp_ratio=mlp_ratios[i], 182 | drop=drop_rate, 183 | drop_path=dpr[cur + j], 184 | linear=linear, 185 | norm_cfg=norm_cfg) 186 | 
for j in range(depths[i])]) 187 | norm = nn.LayerNorm(embed_dims[i]) 188 | cur += depths[i] 189 | 190 | setattr(self, f"patch_embed{i + 1}", patch_embed) 191 | setattr(self, f"block{i + 1}", block) 192 | setattr(self, f"norm{i + 1}", norm) 193 | 194 | def init_weights(self): 195 | print('init cfg', self.init_cfg) 196 | if self.init_cfg is None: 197 | for m in self.modules(): 198 | if isinstance(m, nn.Linear): 199 | trunc_normal_init(m, std=.02, bias=0.) 200 | elif isinstance(m, nn.LayerNorm): 201 | constant_init(m, val=1.0, bias=0.) 202 | elif isinstance(m, nn.Conv2d): 203 | fan_out = m.kernel_size[0] * m.kernel_size[ 204 | 1] * m.out_channels 205 | fan_out //= m.groups 206 | normal_init( 207 | m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) 208 | else: 209 | super(VAN, self).init_weights() 210 | 211 | def forward(self, x): 212 | B = x.shape[0] 213 | outs = [] 214 | 215 | for i in range(self.num_stages): 216 | patch_embed = getattr(self, f"patch_embed{i + 1}") 217 | block = getattr(self, f"block{i + 1}") 218 | norm = getattr(self, f"norm{i + 1}") 219 | x, H, W = patch_embed(x) 220 | for blk in block: 221 | x = blk(x, H, W) 222 | x = norm(x) 223 | x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() 224 | outs.append(x) 225 | 226 | return outs 227 | 228 | 229 | class DWConv(nn.Module): 230 | def __init__(self, dim=768): 231 | super(DWConv, self).__init__() 232 | self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) 233 | 234 | def forward(self, x): 235 | x = self.dwconv(x) 236 | return x 237 | --------------------------------------------------------------------------------
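
A quick way to sanity-check the VAN backbone defined in van.py is to instantiate it directly and inspect the four output feature maps. This is a minimal sketch: the embed_dims/depths below are assumed small-variant settings (the authoritative per-variant values live in the config files), it assumes it is run from the repository root so that `from van import VAN` resolves, and it reuses revert_sync_batchnorm, as train.py does, so the default SyncBN layers run without a distributed process group.

import torch
from mmcv.cnn.utils import revert_sync_batchnorm

from van import VAN  # assumes the repository root is the working directory

# Minimal shape check; widths/depths here are assumptions, not an official variant.
backbone = VAN(embed_dims=[32, 64, 160, 256], depths=[3, 3, 5, 2])
backbone.init_weights()
# As in train.py: convert SyncBN to BN so the model runs without DDP.
backbone = revert_sync_batchnorm(backbone)
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 512, 512))

for i, feat in enumerate(feats):
    print(f'stage {i + 1}: {tuple(feat.shape)}')
# Strides 4/8/16/32 give, e.g., (1, 32, 128, 128), (1, 64, 64, 64),
# (1, 160, 32, 32), (1, 256, 16, 16).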
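tools/torchserve/mmseg2torchserve.py wraps torch-model-archiver: it dumps the config into a temporary directory, points the archiver at mmseg_handler.py, and writes `{model_name}.mar` into the output folder. Besides the CLI, the same function can be called from Python. A sketch with placeholder paths and names, assuming torch-model-archiver is installed and that adding tools/torchserve to sys.path makes the script importable:

import sys
sys.path.insert(0, 'tools/torchserve')  # assumption: run from the repository root

from mmseg2torchserve import mmseg2torchserve

# Placeholder paths and names; substitute a real config and trained checkpoint.
mmseg2torchserve(
    'path/to/config.py',        # config_file
    'path/to/checkpoint.pth',   # checkpoint_file
    'model_store',              # output_folder (created if missing)
    'van_segmentor',            # model_name -> model_store/van_segmentor.mar
    '1.0',                      # model_version
    True)                       # force: overwrite an existing archive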
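Once the resulting `.mar` archive is served by TorchServe, MMsegHandler.postprocess returns the predicted label map as PNG-encoded bytes, which tools/torchserve/test_torchserve.py compares against a local PyTorch run. A client-side sketch that only posts an image and decodes the response; the server address, model name, and image path are placeholders:

import cv2
import numpy as np
import requests

# Placeholders: adjust the address, model name, and image path for your setup.
url = 'http://127.0.0.1:8080/predictions/van_segmentor'
with open('path/to/image.png', 'rb') as image:
    response = requests.post(url, image)
response.raise_for_status()

# The handler returns one PNG-encoded label map per request item.
seg = cv2.imdecode(np.frombuffer(response.content, dtype=np.uint8),
                   cv2.IMREAD_GRAYSCALE)
print('label map shape:', seg.shape, '| classes present:', np.unique(seg))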