├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── builder.py ├── configs ├── deit_unept_ade20k.py ├── deit_unept_pcontext.py └── res50_unept_ade20k.py ├── models ├── UN_EPT.py ├── __init__.py ├── base.py ├── context_branch.py ├── deformable_attn.py ├── ops │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ └── src │ │ ├── cuda │ │ ├── ms_deform_attn_cuda.cu │ │ ├── ms_deform_attn_cuda.h │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── ms_deform_attn.h │ │ └── vision.cpp ├── spatial_branch.py └── vision_transformer.py ├── modified_mmseg ├── __init__.py ├── apis │ ├── __init__.py │ └── train.py └── datasets │ ├── __init__.py │ ├── ade.py │ ├── builder.py │ ├── cityscapes.py │ ├── custom.py │ ├── dataset_wrappers.py │ ├── pascal_context.py │ └── pipelines │ ├── __init__.py │ ├── compose.py │ ├── formating.py │ ├── loading.py │ ├── test_time_aug.py │ └── transforms.py ├── requirements.txt ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | modified_mmseg/**/__pycache__ 2 | 3 | work_dirs/ -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. 
Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unified-EPT 2 | 3 | Code for the ICCV 2021 Workshop paper: [A Unified Efficient Pyramid Transformer for Semantic Segmentation](https://openaccess.thecvf.com/content/ICCV2021W/VSPW/papers/Zhu_A_Unified_Efficient_Pyramid_Transformer_for_Semantic_Segmentation_ICCVW_2021_paper.pdf). 4 | 5 | ## Installation 6 | 7 | * Linux, CUDA>=10.0, GCC>=5.4 8 | * Python>=3.7 9 | * Create a conda environment: 10 | 11 | ```bash 12 | conda create -n unept python=3.7 pip 13 | ``` 14 | 15 | Then, activate the environment: 16 | ```bash 17 | conda activate unept 18 | ``` 19 | * PyTorch>=1.5.1, torchvision>=0.6.1 (following instructions [here](https://pytorch.org/)) 20 | 21 | For example: 22 | ``` 23 | conda install pytorch==1.5.1 torchvision==0.6.1 cudatoolkit=10.2 -c pytorch 24 | ``` 25 | 26 | * Install [MMCV](https://mmcv.readthedocs.io/en/latest/), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/install.md), [timm](https://pypi.org/project/timm/) 27 | 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | * Install [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR) and compile the CUDA operators 33 | (the instructions can be found [here](https://github.com/fundamentalvision/Deformable-DETR#installation)). 
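For reference, the upstream Deformable-DETR build step boils down to running its `make.sh` (which essentially invokes `python setup.py build install`) inside the ops directory, on a machine whose CUDA toolkit matches the installed PyTorch build. The CUDA sources are vendored here under `models/ops/src`, but the build scripts themselves are not shown in this tree, so the commands below are only an indicative sketch that assumes you reuse the upstream scripts:

```bash
# Indicative only: make.sh / setup.py come from the upstream Deformable-DETR
# checkout and are not part of this repository's models/ops directory.
cd models/ops
sh ./make.sh   # roughly equivalent to: python setup.py build install
```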
34 | 35 | 36 | 37 | ## Data Preparation 38 | Please follow the code from [openseg](https://github.com/openseg-group/openseg.pytorch) to generate ground truth for boundary refinement. 39 | 40 | The data should be organized as follows. 41 | 42 | ### ADE20K 43 | You can download the processed ```dt_offset``` file [here](https://drive.google.com/drive/folders/1UKIXzc6hHQUfNqynZtcgSjSnGpQJ0GLs?usp=sharing). 44 | 45 | ``` 46 | path/to/ADEChallengeData2016/ 47 | images/ 48 | training/ 49 | validation/ 50 | annotations/ 51 | training/ 52 | validation/ 53 | dt_offset/ 54 | training/ 55 | validation/ 56 | ``` 57 | ### PASCAL-Context 58 | You can download the processed dataset [here](https://drive.google.com/file/d/18-3ySBQEZcBfr0Rs3_mWJWo2jNzyS6VO/view?usp=sharing). 59 | 60 | ``` 61 | path/to/PASCAL-Context/ 62 | train/ 63 | image/ 64 | label/ 65 | dt_offset/ 66 | val/ 67 | image/ 68 | label/ 69 | dt_offset/ 70 | ``` 71 | 72 | ## Usage 73 | ### Training 74 | **The default is multi-GPU, DistributedDataParallel training.** 75 | 76 | ``` 77 | python -m torch.distributed.launch --nproc_per_node=8 \ # specify gpu number 78 | --master_port=29500 \ 79 | train.py --launcher pytorch \ 80 | --config /path/to/config_file 81 | ``` 82 | 83 | - specify the ```data_root``` in the config file; 84 | - log dir will be created in ```./work_dirs```; 85 | - download the [DeiT pretrained model](https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth) and specify the ```pretrained``` path in the config file. 86 | 87 | 88 | ### Evaluation 89 | 90 | ``` 91 | # single-gpu testing 92 | python test.py --checkpoint /path/to/checkpoint \ 93 | --config /path/to/config_file \ 94 | --eval mIoU \ 95 | [--out ${RESULT_FILE}] [--show] \ 96 | --aug-test \ # for multi-scale flip aug 97 | 98 | # multi-gpu testing (4 gpus, 1 sample per gpu) 99 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=29500 \ 100 | test.py --launcher pytorch --eval mIoU \ 101 | --config_file /path/to/config_file \ 102 | --checkpoint /path/to/checkpoint \ 103 | --aug-test \ # for multi-scale flip aug 104 | ``` 105 | 106 | ## Results 107 | We report results on validation sets. 108 | 109 | | Backbone | Crop Size | Batch Size | Dataset | Lr schd | Mem(GB) | mIoU(ms+flip) | config | 110 | | :------: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | 111 | | Res-50 | 480x480 | 16 | ADE20K | 160K | 7.0G | 46.1 | [config](configs/res50_unept_ade20k.py) | 112 | | DeiT | 480x480 | 16 | ADE20K | 160K | 8.5G | 50.5 | [config](configs/deit_unept_ade20k.py) | 113 | | DeiT | 480x480 | 16 | PASCAL-Context | 160K | 8.5G | 55.2 | [config](configs/deit_unept_pcontext.py) | 114 | 115 | ## Security 116 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 117 | 118 | 119 | ## License 120 | 121 | This project is licensed under the Apache-2.0 License. 
122 | 123 | ## Citation 124 | 125 | If you use this code and models for your research, please consider citing: 126 | 127 | ``` 128 | @article{zhu2021unified, 129 | title={A Unified Efficient Pyramid Transformer for Semantic Segmentation}, 130 | author={Zhu, Fangrui and Zhu, Yi and Zhang, Li and Wu, Chongruo and Fu, Yanwei and Li, Mu}, 131 | journal={arXiv preprint arXiv:2107.14209}, 132 | year={2021} 133 | } 134 | ``` 135 | 136 | ## Acknowledgment 137 | 138 | We thank the authors and contributors of [MMCV](https://mmcv.readthedocs.io/en/latest/), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation/blob/master/docs/install.md), [timm](https://pypi.org/project/timm/) and [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR). 139 | -------------------------------------------------------------------------------- /builder.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from mmsegmentation (https://github.com/open-mmlab/mmsegmentation) 3 | # Apache-2.0 License 4 | # Copyright (c) Open-MMLab. 5 | # ------------------------------------------------------------------------------------------------ 6 | 7 | from mmcv.utils import Registry, build_from_cfg 8 | from torch import nn 9 | 10 | BACKBONES = Registry('backbone') 11 | NECKS = Registry('neck') 12 | HEADS = Registry('head') 13 | LOSSES = Registry('loss') 14 | SEGMENTORS = Registry('segmentor') 15 | 16 | 17 | def build(cfg, registry, default_args=None): 18 | """Build a module. 19 | 20 | Args: 21 | cfg (dict, list[dict]): The config of modules, is is either a dict 22 | or a list of configs. 23 | registry (:obj:`Registry`): A registry the module belongs to. 24 | default_args (dict, optional): Default arguments to build the module. 25 | Defaults to None. 26 | 27 | Returns: 28 | nn.Module: A built nn module. 
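Example (illustrative; mirrors how build_segmentor below calls this function):
            >>> model = build(cfg, SEGMENTORS, dict(train_cfg=None, test_cfg=None))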
29 | """ 30 | 31 | if isinstance(cfg, list): 32 | import ipdb; ipdb.set_trace() 33 | modules = [ 34 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 35 | ] 36 | return nn.Sequential(*modules) 37 | else: 38 | return build_from_cfg(cfg, registry, default_args) 39 | 40 | 41 | def build_backbone(cfg): 42 | """Build backbone.""" 43 | return build(cfg, BACKBONES) 44 | 45 | 46 | def build_neck(cfg): 47 | """Build neck.""" 48 | return build(cfg, NECKS) 49 | 50 | 51 | def build_head(cfg): 52 | """Build head.""" 53 | return build(cfg, HEADS) 54 | 55 | 56 | def build_loss(cfg): 57 | """Build loss.""" 58 | return build(cfg, LOSSES) 59 | 60 | 61 | def build_segmentor(cfg, train_cfg=None, test_cfg=None): 62 | """Build segmentor.""" 63 | return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 64 | -------------------------------------------------------------------------------- /configs/deit_unept_ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = '/home/ubuntu/dataset/ADE20K/ADEChallengeData2016' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (480, 480) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg', 'distance_map', 'angle_map']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 512), 24 | img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=True, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=2, 36 | workers_per_gpu=4, 37 | train=dict( 38 | type=dataset_type, 39 | data_root=data_root, 40 | img_dir='images/training', 41 | ann_dir='annotations/training', 42 | dt_dir='dt_offset/training', 43 | pipeline=train_pipeline), 44 | val=dict( 45 | type=dataset_type, 46 | data_root=data_root, 47 | img_dir='images/validation', 48 | ann_dir='annotations/validation', 49 | dt_dir='dt_offset/validation', 50 | pipeline=test_pipeline), 51 | test=dict( 52 | type=dataset_type, 53 | data_root=data_root, 54 | img_dir='images/validation', 55 | ann_dir='annotations/validation', 56 | dt_dir='dt_offset/validation', 57 | pipeline=test_pipeline)) 58 | 59 | # model settings 60 | norm_cfg = dict(type='SyncBN', requires_grad=True) 61 | model = dict( 62 | type='UN_EPT', 63 | feat_dim=256, 64 | k=16, 65 | L=3, 66 | dropout=0.1, 67 | heads=8, 68 | hidden_dim=2048, 69 | depth=2, 70 | pretrained='deit_base_distilled_patch16_384-d0272ac0.pth', 71 | backbone_cfg=dict( 72 | type='DeiT', 73 | img_size=480, 74 | patch_size=16, 75 | embed_dim=768, 76 | bb_depth=12, 77 | num_heads=12, 78 | mlp_ratio=4), 79 | loss_decode=dict( 80 | type='CrossEntropyLoss', 81 | use_sigmoid=False, 82 | loss_weight=1.0)) 83 | # model training and 
testing settings 84 | train_cfg = dict() 85 | test_cfg = dict(mode='slide', num_classes=150, stride=(160,160), crop_size=(480, 480), num_queries=3600) 86 | 87 | # optimizer 88 | optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001, betas=(0.9, 0.999), eps=1e-8, 89 | paramwise_cfg = dict(custom_keys={'backbone': dict(lr_mult=0.1)})) 90 | 91 | optimizer_config = dict() 92 | # learning policy 93 | lr_config = dict(policy='step', step=126000, by_epoch=False) 94 | # runtime settings 95 | # total_iters = 640000 96 | runner = dict(type='IterBasedRunner', max_iters=160000) 97 | checkpoint_config = dict(by_epoch=False, interval=10000) 98 | evaluation = dict(interval=10000, metric='mIoU') 99 | 100 | 101 | # yapf:disable 102 | log_config = dict( 103 | interval=200, 104 | hooks=[ 105 | dict(type='TextLoggerHook', by_epoch=False), 106 | # dict(type='TensorboardLoggerHook') 107 | ]) 108 | # yapf:enable 109 | dist_params = dict(backend='nccl') 110 | log_level = 'INFO' 111 | load_from = None 112 | resume_from = None 113 | workflow = [('train', 1)] 114 | cudnn_benchmark = True -------------------------------------------------------------------------------- /configs/deit_unept_pcontext.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'PascalContextDataset' 3 | data_root = '/home/ubuntu/dataset/PASCAL_Context/' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | img_scale = (520, 520) 7 | crop_size = (480, 480) 8 | 9 | train_pipeline = [ 10 | dict(type='LoadImageFromFile'), 11 | dict(type='LoadAnnotations'), 12 | dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), 13 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 14 | dict(type='RandomFlip', prob=0.5), 15 | dict(type='PhotoMetricDistortion'), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 18 | dict(type='DefaultFormatBundle'), 19 | dict(type='Collect', keys=['img', 'gt_semantic_seg', 'distance_map', 'angle_map']), 20 | ] 21 | test_pipeline = [ 22 | dict(type='LoadImageFromFile'), 23 | dict( 24 | type='MultiScaleFlipAug', 25 | img_scale=img_scale, 26 | img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 27 | flip=True, 28 | transforms=[ 29 | dict(type='Resize', keep_ratio=True), 30 | dict(type='RandomFlip'), 31 | dict(type='Normalize', **img_norm_cfg), 32 | dict(type='ImageToTensor', keys=['img']), 33 | dict(type='Collect', keys=['img']), 34 | ]) 35 | ] 36 | 37 | data = dict( 38 | samples_per_gpu=2, 39 | workers_per_gpu=4, 40 | train=dict( 41 | type=dataset_type, 42 | data_root=data_root, 43 | img_dir='train/image', 44 | ann_dir='train/label', 45 | dt_dir='train/dt_offset', 46 | pipeline=train_pipeline), 47 | val=dict( 48 | type=dataset_type, 49 | data_root=data_root, 50 | img_dir='val/image', 51 | ann_dir='val/label', 52 | dt_dir='val/dt_offset', 53 | pipeline=test_pipeline), 54 | test=dict( 55 | type=dataset_type, 56 | data_root=data_root, 57 | img_dir='val/image', 58 | ann_dir='val/label', 59 | dt_dir='val/dt_offset', 60 | pipeline=test_pipeline)) 61 | 62 | # model settings 63 | norm_cfg = dict(type='SyncBN', requires_grad=True) 64 | model = dict( 65 | type='UN_EPT', 66 | feat_dim=256, 67 | k=16, 68 | L=3, 69 | dropout=0.1, 70 | heads=8, 71 | hidden_dim=2048, 72 | depth=2, 73 | pretrained='deit_base_distilled_patch16_384-d0272ac0.pth', 74 | backbone_cfg=dict( 75 | type='DeiT', 76 | img_size=480, 77 | patch_size=16, 78 | 
embed_dim=768, 79 | bb_depth=12, 80 | num_heads=12, 81 | mlp_ratio=4), 82 | loss_decode=dict( 83 | type='CrossEntropyLoss', 84 | use_sigmoid=False, 85 | loss_weight=1.0)) 86 | # model training and testing settings 87 | train_cfg = dict() 88 | test_cfg = dict(mode='slide', num_classes=60, stride=(160,160), crop_size=(480, 480), num_queries=3600) 89 | 90 | # optimizer 91 | optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001, betas=(0.9, 0.999), eps=1e-8, 92 | paramwise_cfg = dict(custom_keys={'backbone': dict(lr_mult=0.1)})) 93 | 94 | optimizer_config = dict() 95 | # learning policy 96 | lr_config = dict(policy='step', step=126000, by_epoch=False) 97 | # runtime settings 98 | 99 | runner = dict(type='IterBasedRunner', max_iters=160000) 100 | checkpoint_config = dict(by_epoch=False, interval=10000) 101 | evaluation = dict(interval=10000, metric='mIoU') 102 | 103 | 104 | # yapf:disable 105 | log_config = dict( 106 | interval=200, 107 | hooks=[ 108 | dict(type='TextLoggerHook', by_epoch=False), 109 | # dict(type='TensorboardLoggerHook') 110 | ]) 111 | # yapf:enable 112 | dist_params = dict(backend='nccl') 113 | log_level = 'INFO' 114 | load_from = None 115 | resume_from = None 116 | workflow = [('train', 1)] 117 | cudnn_benchmark = True -------------------------------------------------------------------------------- /configs/res50_unept_ade20k.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | dataset_type = 'ADE20KDataset' 3 | data_root = '/home/ubuntu/dataset/ADE20K/ADEChallengeData2016' 4 | img_norm_cfg = dict( 5 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) 6 | crop_size = (480, 480) 7 | train_pipeline = [ 8 | dict(type='LoadImageFromFile'), 9 | dict(type='LoadAnnotations', reduce_zero_label=True), 10 | dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), 11 | dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), 12 | dict(type='RandomFlip', prob=0.5), 13 | dict(type='PhotoMetricDistortion'), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), 16 | dict(type='DefaultFormatBundle'), 17 | dict(type='Collect', keys=['img', 'gt_semantic_seg', 'distance_map', 'angle_map']), 18 | ] 19 | test_pipeline = [ 20 | dict(type='LoadImageFromFile'), 21 | dict( 22 | type='MultiScaleFlipAug', 23 | img_scale=(2048, 512), 24 | img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], 25 | flip=True, 26 | transforms=[ 27 | dict(type='Resize', keep_ratio=True), 28 | dict(type='RandomFlip'), 29 | dict(type='Normalize', **img_norm_cfg), 30 | dict(type='ImageToTensor', keys=['img']), 31 | dict(type='Collect', keys=['img']), 32 | ]) 33 | ] 34 | data = dict( 35 | samples_per_gpu=2, 36 | workers_per_gpu=4, 37 | train=dict( 38 | type=dataset_type, 39 | data_root=data_root, 40 | img_dir='images/training', 41 | ann_dir='annotations/training', 42 | dt_dir='dt_offset/training', 43 | pipeline=train_pipeline), 44 | val=dict( 45 | type=dataset_type, 46 | data_root=data_root, 47 | img_dir='images/validation', 48 | ann_dir='annotations/validation', 49 | dt_dir='dt_offset/validation', 50 | pipeline=test_pipeline), 51 | test=dict( 52 | type=dataset_type, 53 | data_root=data_root, 54 | img_dir='images/validation', 55 | ann_dir='annotations/validation', 56 | dt_dir='dt_offset/validation', 57 | pipeline=test_pipeline)) 58 | 59 | # model settings 60 | norm_cfg = dict(type='SyncBN', requires_grad=True) 61 | model = dict( 62 | type='UN_EPT', 63 | feat_dim=256, 
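# k: sampled points per scale, L: number of feature scales; heads / hidden_dim / depth configure the context-branch transformer (see the UN_EPT docstring in models/UN_EPT.py)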
64 | k=16, 65 | L=3, 66 | dropout=0.1, 67 | heads=8, 68 | hidden_dim=2048, 69 | depth=2, 70 | pretrained='open-mmlab://resnet50_v1c', 71 | backbone_cfg=dict( 72 | type='ResNetV1c', 73 | depth=50, 74 | num_stages=4, 75 | out_indices=(0, 1, 2, 3), 76 | dilations=(1, 1, 2, 4), 77 | strides=(1, 2, 1, 1), 78 | norm_cfg=norm_cfg, 79 | norm_eval=False, 80 | style='pytorch', 81 | contract_dilation=True), 82 | loss_decode=dict( 83 | type='CrossEntropyLoss', 84 | use_sigmoid=False, 85 | loss_weight=1.0)) 86 | # model training and testing settings 87 | train_cfg = dict() 88 | test_cfg = dict(mode='slide', num_classes=150, stride=(160,160), crop_size=(480, 480), num_queries=3600) 89 | # test_cfg = dict(mode='whole', num_classes=150, num_queries=4096) 90 | 91 | 92 | # optimizer 93 | optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0001, betas=(0.9, 0.999), eps=1e-8, 94 | paramwise_cfg = dict(custom_keys={'backbone': dict(lr_mult=0.1)})) 95 | 96 | optimizer_config = dict() 97 | # learning policy 98 | lr_config = dict(policy='step', step=126000, by_epoch=False) 99 | # runtime settings 100 | # total_iters = 640000 101 | runner = dict(type='IterBasedRunner', max_iters=160000) 102 | checkpoint_config = dict(by_epoch=False, interval=10000) 103 | evaluation = dict(interval=10000, metric='mIoU') 104 | 105 | 106 | # yapf:disable 107 | log_config = dict( 108 | interval=200, 109 | hooks=[ 110 | dict(type='TextLoggerHook', by_epoch=False), 111 | # dict(type='TensorboardLoggerHook') 112 | ]) 113 | # yapf:enable 114 | dist_params = dict(backend='nccl') 115 | log_level = 'INFO' 116 | load_from = None 117 | resume_from = None 118 | workflow = [('train', 1)] 119 | cudnn_benchmark = True 120 | 121 | -------------------------------------------------------------------------------- /models/UN_EPT.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from mmsegmentation (https://github.com/open-mmlab/mmsegmentation) 3 | # Apache-2.0 License 4 | # Copyright (c) Open-MMLab. 
5 | # and 6 | # openseg.pytorch (https://github.com/openseg-group/openseg.pytorch) 7 | # MIT License 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | import math 11 | import numpy as np 12 | import warnings 13 | import logging 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.distributed as dist 19 | 20 | from mmseg.ops import resize 21 | from mmcv.runner import force_fp32 22 | from mmseg.models.builder import build_loss, build_backbone 23 | from mmseg.models.losses import accuracy 24 | 25 | from .vision_transformer import deit_base_distilled_patch16_384 26 | from .spatial_branch import spatial_branch 27 | from .context_branch import context_branch 28 | from .base import base_segmentor 29 | 30 | import sys 31 | sys.path.append("..") 32 | from builder import SEGMENTORS 33 | 34 | 35 | @SEGMENTORS.register_module() 36 | class UN_EPT(base_segmentor): 37 | def __init__(self, 38 | heads, 39 | feat_dim, 40 | k, 41 | L, 42 | dropout, 43 | depth, 44 | hidden_dim, 45 | backbone_cfg, 46 | loss_decode=dict( 47 | type='CrossEntropyLoss', 48 | use_sigmoid=False, 49 | loss_weight=1.0), 50 | ignore_index=255, 51 | activation="relu", 52 | train_cfg=None, 53 | test_cfg=None, 54 | pretrained=None, 55 | auxiliary_head=None): 56 | 57 | """ 58 | params: 59 | heads: head number of the transformer in the context branch; 60 | feat_dim: input feature dimension of the context branch; 61 | k: #points for each scale; 62 | L: #scale; 63 | depth: transformer encoder/decoder number in the context branch; 64 | hidden_dim: transforme hidden dimension in the context branch. 65 | 66 | """ 67 | 68 | super(UN_EPT, self).__init__() 69 | 70 | self.train_cfg = train_cfg 71 | self.test_cfg = test_cfg 72 | 73 | if self.test_cfg.mode == 'stride': 74 | self.test_cfg.stride = test_cfg.stride 75 | self.test_cfg.crop_size = test_cfg.crop_size 76 | self.num_classes = self.test_cfg.num_classes 77 | self.ignore_index = ignore_index 78 | self.align_corners = False 79 | self.feat_dim = feat_dim 80 | 81 | self.loss_decode = build_loss(loss_decode) 82 | 83 | if pretrained is not None: 84 | logger = logging.getLogger() 85 | logger.info(f'load model from: {pretrained}') 86 | 87 | if backbone_cfg.type == 'DeiT': 88 | self.backbone = deit_base_distilled_patch16_384( 89 | img_size=backbone_cfg.img_size, 90 | patch_size=backbone_cfg.patch_size, 91 | embed_dim=backbone_cfg.embed_dim, 92 | depth=backbone_cfg.bb_depth, 93 | num_heads=backbone_cfg.num_heads, 94 | mlp_ratio=backbone_cfg.mlp_ratio, 95 | pretrained=pretrained) 96 | elif backbone_cfg.type == 'ResNetV1c': 97 | self.backbone = build_backbone(backbone_cfg) 98 | self.backbone.init_weights(pretrained=pretrained) 99 | 100 | 101 | self.cls = nn.Conv2d(feat_dim, self.num_classes, kernel_size=1) 102 | 103 | # get pyramid features 104 | self.layers = nn.ModuleList([]) 105 | self.backbone_type = backbone_cfg.type 106 | if self.backbone_type == 'DeiT': 107 | self.layers.append(nn.Conv2d(backbone_cfg.embed_dim, feat_dim, kernel_size=1, stride=1)) 108 | self.layers.append(nn.Conv2d(backbone_cfg.embed_dim, feat_dim, kernel_size=1, stride=1)) 109 | self.layers.append(nn.Conv2d(backbone_cfg.embed_dim, feat_dim, kernel_size=1, stride=1)) 110 | elif self.backbone_type == 'ResNetV1c': 111 | self.layers.append(nn.Conv2d(512, feat_dim, kernel_size=1, stride=1)) 112 | self.layers.append(nn.Conv2d(1024, feat_dim, kernel_size=1, stride=1)) 113 | self.layers.append(nn.Conv2d(2048, feat_dim, 
kernel_size=1, stride=1)) 114 | 115 | self.context_branch = context_branch(d_model=feat_dim, nhead=heads, 116 | num_encoder_layers=depth, num_decoder_layers=depth, dim_feedforward=hidden_dim, dropout=dropout, 117 | activation=activation, num_feature_levels=L, dec_n_points=k, enc_n_points=k) 118 | 119 | self.num_queries = self.test_cfg.num_queries 120 | self.query_embed = nn.Embedding(self.num_queries, feat_dim) 121 | self.spatial_branch = spatial_branch() 122 | 123 | self.dir_head = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0, bias=False), 124 | nn.SyncBatchNorm(256), 125 | nn.ReLU(), 126 | nn.Conv2d(256, 8, kernel_size=1, stride=1, padding=0, bias=False)) 127 | 128 | self.mask_head = nn.Sequential(nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0, bias=False), 129 | nn.SyncBatchNorm(256), 130 | nn.ReLU(), 131 | nn.Conv2d(256, 2, kernel_size=1, stride=1, padding=0, bias=False)) 132 | 133 | 134 | def encode_decode(self, x): 135 | 136 | bsize, c, h, w = x.shape 137 | backbone_feats = self.backbone(x) 138 | if self.backbone_type == 'ResNetV1c': 139 | backbone_feats = backbone_feats[1:] 140 | 141 | context = self.spatial_branch(x) 142 | 143 | mask_map = self.mask_head(context) 144 | dir_map = self.dir_head(context) 145 | 146 | context = context.flatten(2).permute(2, 0, 1) 147 | 148 | pyramid_feats = [] 149 | for i, conv_layer in enumerate(self.layers): 150 | feature = conv_layer(backbone_feats[i]) 151 | pyramid_feats.append(feature) 152 | 153 | q_H = q_W = int(math.sqrt(self.num_queries)) 154 | out = self.context_branch(pyramid_feats, context, self.query_embed.weight, q_H, q_W) 155 | 156 | out = out.unsqueeze(0).reshape([h//8, w//8, bsize, self.feat_dim]).permute(2, 3, 0, 1) 157 | 158 | out = self.cls(out) 159 | 160 | seg_logits = resize( 161 | input=out, 162 | size=x.shape[2:], 163 | mode='bilinear', 164 | align_corners=self.align_corners) 165 | 166 | return seg_logits, mask_map, dir_map 167 | 168 | def distance_to_mask_label(self, distance_map, seg_label_map, return_tensor=False): 169 | 170 | if return_tensor: 171 | assert isinstance(distance_map, torch.Tensor) 172 | assert isinstance(seg_label_map, torch.Tensor) 173 | else: 174 | assert isinstance(distance_map, np.ndarray) 175 | assert isinstance(seg_label_map, np.ndarray) 176 | 177 | if return_tensor: 178 | mask_label_map = torch.zeros_like(seg_label_map).long().to(distance_map.device) 179 | else: 180 | mask_label_map = np.zeros(seg_label_map.shape, dtype=np.int) 181 | 182 | keep_mask = (distance_map <= 5) & (distance_map >= 0) 183 | mask_label_map[keep_mask] = 1 184 | mask_label_map[seg_label_map == -1] = -1 185 | 186 | return mask_label_map 187 | 188 | def calc_weights(self, label_map, num_classes): 189 | 190 | weights = [] 191 | for i in range(num_classes): 192 | weights.append((label_map == i).sum().data) 193 | weights = torch.FloatTensor(weights) 194 | weights_sum = weights.sum() 195 | return (1 - weights / weights_sum).cuda() 196 | 197 | def angle_to_direction_label(self, angle_map, seg_label_map=None, distance_map=None, num_classes=8, extra_ignore_mask=None, return_tensor=False): 198 | 199 | if return_tensor: 200 | assert isinstance(angle_map, torch.Tensor) 201 | assert isinstance(seg_label_map, torch.Tensor) or seg_label_map is None 202 | else: 203 | assert isinstance(angle_map, np.ndarray) 204 | assert isinstance(seg_label_map, np.ndarray) or seg_label_map is None 205 | 206 | _, label_map = self.align_angle(angle_map, num_classes=num_classes, return_tensor=return_tensor) 207 | if distance_map is not 
None: 208 | label_map[distance_map > 5] = num_classes 209 | if seg_label_map is None: 210 | if return_tensor: 211 | ignore_mask = torch.zeros(angle_map.shape, dtype=torch.uint8).to(angle_map.device) 212 | else: 213 | ignore_mask = np.zeros(angle_map.shape, dtype=np.bool) 214 | else: 215 | ignore_mask = seg_label_map == -1 216 | 217 | if extra_ignore_mask is not None: 218 | extra_ignore_mask = extra_ignore_mask.unsqueeze(1) 219 | ignore_mask = ignore_mask | extra_ignore_mask 220 | label_map[ignore_mask] = -1 221 | 222 | return label_map 223 | 224 | def align_angle(self, angle_map, 225 | num_classes=8, 226 | return_tensor=False): 227 | 228 | if return_tensor: 229 | assert isinstance(angle_map, torch.Tensor) 230 | else: 231 | assert isinstance(angle_map, np.ndarray) 232 | 233 | step = 360 / num_classes 234 | if return_tensor: 235 | new_angle_map = torch.zeros(angle_map.shape).float().to(angle_map.device) 236 | angle_index_map = torch.zeros(angle_map.shape).long().to(angle_map.device) 237 | else: 238 | new_angle_map = np.zeros(angle_map.shape, dtype=np.float) 239 | angle_index_map = np.zeros(angle_map.shape, dtype=np.int) 240 | mask = (angle_map <= (-180 + step/2)) | (angle_map > (180 - step/2)) 241 | new_angle_map[mask] = -180 242 | angle_index_map[mask] = 0 243 | 244 | for i in range(1, num_classes): 245 | middle = -180 + step * i 246 | mask = (angle_map > (middle - step / 2)) & (angle_map <= (middle + step / 2)) 247 | new_angle_map[mask] = middle 248 | angle_index_map[mask] = i 249 | 250 | return new_angle_map, angle_index_map 251 | 252 | def shift(self, x, offset): 253 | """ 254 | x: b x c x h x w 255 | offset: b x 2 x h x w 256 | """ 257 | def gen_coord_map(H, W): 258 | coord_vecs = [torch.arange(length, dtype=torch.float) for length in (H, W)] 259 | coord_h, coord_w = torch.meshgrid(coord_vecs) 260 | coord_h = coord_h.cuda() 261 | coord_w = coord_w.cuda() 262 | return coord_h, coord_w 263 | 264 | b, c, h, w = x.shape 265 | 266 | coord_map = gen_coord_map(h, w) 267 | norm_factor = torch.FloatTensor([(w-1)/2, (h-1)/2]).cuda() 268 | grid_h = offset[:, 0]+coord_map[0] 269 | grid_w = offset[:, 1]+coord_map[1] 270 | grid = torch.stack([grid_w, grid_h], dim=-1) / norm_factor - 1 271 | 272 | x = F.grid_sample(x.float(), grid, padding_mode='border', mode='bilinear', align_corners=True) 273 | 274 | return x 275 | 276 | def _get_offset(self, mask_logits, dir_logits): 277 | 278 | edge_mask = mask_logits[:, 1] > 0.5 279 | dir_logits = torch.softmax(dir_logits, dim=1) 280 | n, _, h, w = dir_logits.shape 281 | 282 | keep_mask = edge_mask 283 | 284 | dir_label = torch.argmax(dir_logits, dim=1).float() 285 | offset = self.label_to_vector(dir_label) 286 | offset = offset.permute(0, 2, 3, 1) 287 | offset[~keep_mask, :] = 0 288 | 289 | return offset 290 | 291 | def label_to_vector(self, labelmap, 292 | num_classes=8): 293 | 294 | assert isinstance(labelmap, torch.Tensor) 295 | 296 | label_to_vector_mapping = { 297 | 8: [ 298 | [0, -1], [-1, -1], [-1, 0], [-1, 1], 299 | [0, 1], [1, 1], [1, 0], [1, -1] 300 | ], 301 | 16: [ 302 | [0, -2], [-1, -2], [-2, -2], [-2, -1], 303 | [-2, 0], [-2, 1], [-2, 2], [-1, 2], 304 | [0, 2], [1, 2], [2, 2], [2, 1], 305 | [2, 0], [2, -1], [2, -2], [1, -2] 306 | ] 307 | } 308 | 309 | mapping = label_to_vector_mapping[num_classes] 310 | offset_h = torch.zeros_like(labelmap).long() 311 | offset_w = torch.zeros_like(labelmap).long() 312 | 313 | for idx, (hdir, wdir) in enumerate(mapping): 314 | mask = labelmap == idx 315 | offset_h[mask] = hdir 316 | offset_w[mask] = wdir 317 | 318 | 
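# Stack the per-pixel (h, w) displacement components and return them channel-first as (B, 2, H, W); _get_offset() later permutes this to (B, H, W, 2) and zeroes non-edge pixels.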
return torch.stack([offset_h, offset_w], dim=-1).permute(0, 3, 1, 2).to(labelmap.device) 319 | 320 | def forward_train(self, img, img_metas, gt_semantic_seg, distance_map, angle_map): 321 | 322 | seg_logits, pred_mask, pred_direction = self.encode_decode(img) 323 | losses = dict() 324 | 325 | loss_decode = self.losses(seg_logits, pred_mask, pred_direction, gt_semantic_seg, distance_map, angle_map) 326 | losses.update(loss_decode) 327 | 328 | return losses 329 | 330 | @force_fp32(apply_to=('seg_logit', )) 331 | def losses(self, seg_logit, pred_mask, pred_direction, seg_label, distance_map, angle_map): 332 | """Compute segmentation loss.""" 333 | loss = dict() 334 | 335 | seg_weight = None 336 | 337 | gt_mask = self.distance_to_mask_label(distance_map, seg_label, return_tensor=True) 338 | gt_size = gt_mask.shape[2:] 339 | mask_weights = self.calc_weights(gt_mask, 2) 340 | 341 | pred_direction = F.interpolate(pred_direction, size=gt_size, mode="bilinear", align_corners=True) 342 | pred_mask = F.interpolate(pred_mask, size=gt_size, mode="bilinear", align_corners=True) 343 | mask_loss = F.cross_entropy(pred_mask, gt_mask[:,0], weight=mask_weights, ignore_index=-1) 344 | 345 | mask_threshold = 0.5 346 | binary_pred_mask = torch.softmax(pred_mask, dim=1)[:, 1, :, :] > mask_threshold 347 | 348 | gt_direction = self.angle_to_direction_label( 349 | angle_map, 350 | seg_label_map=seg_label, 351 | extra_ignore_mask=(binary_pred_mask == 0), 352 | return_tensor=True 353 | ) 354 | 355 | direction_loss_mask = gt_direction != -1 356 | direction_weights = self.calc_weights(gt_direction[direction_loss_mask], pred_direction.size(1)) 357 | direction_loss = F.cross_entropy(pred_direction, gt_direction[:,0], weight=direction_weights, ignore_index=-1) 358 | 359 | offset = self._get_offset(pred_mask, pred_direction) 360 | refine_map = self.shift(seg_logit, offset.permute(0,3,1,2)) 361 | 362 | seg_label = seg_label.squeeze(1) 363 | 364 | loss['loss_seg'] = 0.8*self.loss_decode( 365 | seg_logit, 366 | seg_label, 367 | weight=seg_weight, 368 | ignore_index=self.ignore_index) + 5*mask_loss + 0.6*direction_loss + \ 369 | self.loss_decode( 370 | refine_map, 371 | seg_label, 372 | weight=seg_weight, 373 | ignore_index=self.ignore_index) 374 | 375 | loss['acc_seg'] = accuracy(seg_logit, seg_label) 376 | 377 | return loss -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .UN_EPT import UN_EPT 2 | -------------------------------------------------------------------------------- /models/base.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from mmsegmentation (https://github.com/open-mmlab/mmsegmentation) 3 | # Apache-2.0 License 4 | # Copyright (c) Open-MMLab. 
5 | # ------------------------------------------------------------------------------------------------ 6 | 7 | 8 | from abc import ABCMeta, abstractmethod 9 | from collections import OrderedDict 10 | import math 11 | import numpy as np 12 | import warnings 13 | import logging 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | import torch.distributed as dist 19 | 20 | from mmseg.ops import resize 21 | 22 | 23 | class base_segmentor(nn.Module, metaclass=ABCMeta): 24 | def __init__(self): 25 | 26 | super(base_segmentor, self).__init__() 27 | 28 | @staticmethod 29 | def _parse_losses(losses): 30 | """Parse the raw outputs (losses) of the network. 31 | Args: 32 | losses (dict): Raw output of the network, which usually contain 33 | losses and other necessary information. 34 | Returns: 35 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor 36 | which may be a weighted sum of all losses, log_vars contains 37 | all the variables to be sent to the logger. 38 | """ 39 | log_vars = OrderedDict() 40 | for loss_name, loss_value in losses.items(): 41 | if isinstance(loss_value, torch.Tensor): 42 | log_vars[loss_name] = loss_value.mean() 43 | elif isinstance(loss_value, list): 44 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 45 | else: 46 | raise TypeError( 47 | f'{loss_name} is not a tensor or list of tensors') 48 | 49 | loss = sum(_value for _key, _value in log_vars.items() 50 | if 'loss' in _key) 51 | 52 | log_vars['loss'] = loss 53 | 54 | for loss_name, loss_value in log_vars.items(): 55 | # reduce loss when distributed training 56 | if dist.is_available() and dist.is_initialized(): 57 | loss_value = loss_value.data.clone() 58 | dist.all_reduce(loss_value.div_(dist.get_world_size())) 59 | log_vars[loss_name] = loss_value.item() 60 | 61 | return loss, log_vars 62 | 63 | @abstractmethod 64 | def encode_decode(self, img): 65 | pass 66 | 67 | @abstractmethod 68 | def forward_train(self, img, img_metas, **kwargs): 69 | pass 70 | 71 | @abstractmethod 72 | def losses(self, seg_logit, seg_label): 73 | pass 74 | 75 | def train_step(self, data_batch, optimizer, **kwargs): 76 | """The iteration step during training. 77 | This method defines an iteration step during training, except for the 78 | back propagation and optimizer updating, which are done in an optimizer 79 | hook. Note that in some complicated cases or models, the whole process 80 | including back propagation and optimizer updating is also defined in 81 | this method, such as GAN. 82 | Args: 83 | data (dict): The output of dataloader. 84 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of 85 | runner is passed to ``train_step()``. This argument is unused 86 | and reserved. 87 | Returns: 88 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, 89 | ``num_samples``. 90 | ``loss`` is a tensor for back propagation, which can be a 91 | weighted sum of multiple losses. 92 | ``log_vars`` contains all the variables to be sent to the 93 | logger. 94 | ``num_samples`` indicates the batch size (when the model is 95 | DDP, it means the batch size on each GPU), which is used for 96 | averaging the logs. 
97 | """ 98 | losses = self(**data_batch) 99 | loss, log_vars = self._parse_losses(losses) 100 | 101 | outputs = dict( 102 | loss=loss, 103 | log_vars=log_vars, 104 | num_samples=len(data_batch['img'].data)) 105 | return outputs 106 | 107 | def forward(self, img, img_metas, return_loss=True, **kwargs): 108 | if return_loss: 109 | return self.forward_train(img, img_metas, **kwargs) 110 | else: 111 | return self.forward_test(img, img_metas, **kwargs) 112 | 113 | def forward_test(self, imgs, img_metas, **kwargs): 114 | """ 115 | Args: 116 | imgs (List[Tensor]): the outer list indicates test-time 117 | augmentations and inner Tensor should have a shape NxCxHxW, 118 | which contains all images in the batch. 119 | img_metas (List[List[dict]]): the outer list indicates test-time 120 | augs (multiscale, flip, etc.) and the inner list indicates 121 | images in a batch. 122 | """ 123 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: 124 | if not isinstance(var, list): 125 | raise TypeError(f'{name} must be a list, but got ' 126 | f'{type(var)}') 127 | 128 | num_augs = len(imgs) 129 | if num_augs != len(img_metas): 130 | raise ValueError(f'num of augmentations ({len(imgs)}) != ' 131 | f'num of image meta ({len(img_metas)})') 132 | # all images in the same aug batch all of the same ori_shape and pad 133 | # shape 134 | for img_meta in img_metas: 135 | ori_shapes = [_['ori_shape'] for _ in img_meta] 136 | assert all(shape == ori_shapes[0] for shape in ori_shapes) 137 | img_shapes = [_['img_shape'] for _ in img_meta] 138 | assert all(shape == img_shapes[0] for shape in img_shapes) 139 | pad_shapes = [_['pad_shape'] for _ in img_meta] 140 | assert all(shape == pad_shapes[0] for shape in pad_shapes) 141 | 142 | if num_augs == 1: 143 | return self.simple_test(imgs[0], img_metas[0], **kwargs) 144 | else: 145 | return self.aug_test(imgs, img_metas, **kwargs) 146 | 147 | def slide_inference(self, img, img_meta, rescale): 148 | """Inference by sliding-window with overlap.""" 149 | 150 | h_stride, w_stride = self.test_cfg.stride 151 | h_crop, w_crop = self.test_cfg.crop_size 152 | batch_size, _, h_img, w_img = img.size() 153 | num_classes = self.num_classes 154 | h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 155 | w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 156 | preds = img.new_zeros((batch_size, num_classes, h_img, w_img)) 157 | count_mat = img.new_zeros((batch_size, 1, h_img, w_img)) 158 | for h_idx in range(h_grids): 159 | for w_idx in range(w_grids): 160 | y1 = h_idx * h_stride 161 | x1 = w_idx * w_stride 162 | y2 = min(y1 + h_crop, h_img) 163 | x2 = min(x1 + w_crop, w_img) 164 | y1 = max(y2 - h_crop, 0) 165 | x1 = max(x2 - w_crop, 0) 166 | crop_img = img[:, :, y1:y2, x1:x2] 167 | pad_img = crop_img.new_zeros( 168 | (crop_img.size(0), crop_img.size(1), h_crop, w_crop)) 169 | pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img 170 | if len(self.encode_decode(pad_img)) != 1: 171 | pad_seg_logit, _, _ = self.encode_decode(pad_img) 172 | else: 173 | pad_seg_logit = self.encode_decode(pad_img) 174 | preds[:, :, y1:y2, 175 | x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1] 176 | count_mat[:, :, y1:y2, x1:x2] += 1 177 | assert (count_mat == 0).sum() == 0 178 | preds = preds / count_mat 179 | if rescale: 180 | preds = resize( 181 | preds, 182 | size=img_meta[0]['ori_shape'][:2], 183 | mode='bilinear', 184 | align_corners=self.align_corners, 185 | warning=False) 186 | 187 | return preds 188 | 189 | def whole_inference(self, img, img_meta, rescale): 190 | """Inference 
with full image.""" 191 | 192 | seg_logit = self.encode_decode(img) 193 | if rescale: 194 | seg_logit = resize( 195 | seg_logit, 196 | size=img_meta[0]['ori_shape'][:2], 197 | mode='bilinear', 198 | align_corners=self.align_corners, 199 | warning=False) 200 | 201 | return seg_logit 202 | 203 | def inference(self, img, img_meta, rescale): 204 | """Inference with slide/whole style. 205 | Args: 206 | img (Tensor): The input image of shape (N, 3, H, W). 207 | img_meta (dict): Image info dict where each dict has: 'img_shape', 208 | 'scale_factor', 'flip', and may also contain 209 | 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'. 210 | For details on the values of these keys see 211 | `mmseg/datasets/pipelines/formatting.py:Collect`. 212 | rescale (bool): Whether rescale back to original shape. 213 | Returns: 214 | Tensor: The output segmentation map. 215 | """ 216 | 217 | assert self.test_cfg.mode in ['slide', 'whole'] 218 | ori_shape = img_meta[0]['ori_shape'] 219 | assert all(_['ori_shape'] == ori_shape for _ in img_meta) 220 | if self.test_cfg.mode == 'slide': 221 | seg_logit = self.slide_inference(img, img_meta, rescale) 222 | else: 223 | seg_logit = self.whole_inference(img, img_meta, rescale) 224 | output = F.softmax(seg_logit, dim=1) 225 | flip = img_meta[0]['flip'] 226 | flip_direction = img_meta[0]['flip_direction'] 227 | if flip: 228 | assert flip_direction in ['horizontal', 'vertical'] 229 | if flip_direction == 'horizontal': 230 | output = output.flip(dims=(3, )) 231 | elif flip_direction == 'vertical': 232 | output = output.flip(dims=(2, )) 233 | 234 | return output 235 | 236 | def simple_test(self, img, img_meta, rescale=True): 237 | """Simple test with single image.""" 238 | seg_logit = self.inference(img, img_meta, rescale) 239 | seg_pred = seg_logit.argmax(dim=1) 240 | seg_pred = seg_pred.cpu().numpy() 241 | # unravel batch dim 242 | seg_pred = list(seg_pred) 243 | return seg_pred 244 | 245 | def aug_test(self, imgs, img_metas, rescale=True): 246 | """Test with augmentations. 247 | Only rescale=True is supported. 248 | """ 249 | # aug_test rescale all imgs back to ori_shape for now 250 | assert rescale 251 | # to save memory, we get augmented seg logit inplace 252 | seg_logit = self.inference(imgs[0], img_metas[0], rescale) 253 | for i in range(1, len(imgs)): 254 | cur_seg_logit = self.inference(imgs[i], img_metas[i], rescale) 255 | seg_logit += cur_seg_logit 256 | seg_logit /= len(imgs) 257 | seg_pred = seg_logit.argmax(dim=1) 258 | seg_pred = seg_pred.cpu().numpy() 259 | # unravel batch dim 260 | seg_pred = list(seg_pred) 261 | return seg_pred 262 | 263 | 264 | 265 | 266 | -------------------------------------------------------------------------------- /models/context_branch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) 3 | # Apache-2.0 License 4 | # Copyright (c) SenseTime, Inc. and its affiliates. 
5 | # ------------------------------------------------------------------------------------------------ 6 | 7 | 8 | import copy 9 | import math 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 15 | 16 | from .deformable_attn import MSDeformAttn 17 | 18 | 19 | ## EPT 20 | class PositionEmbeddingSine(nn.Module): 21 | """ 22 | This is a more standard version of the position embedding, very similar to the one 23 | used by the Attention is all you need paper, generalized to work on images. 24 | """ 25 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 26 | super().__init__() 27 | self.num_pos_feats = num_pos_feats 28 | self.temperature = temperature 29 | self.normalize = normalize 30 | if scale is not None and normalize is False: 31 | raise ValueError("normalize should be True if scale is passed") 32 | if scale is None: 33 | scale = 2 * math.pi 34 | self.scale = scale 35 | 36 | def forward(self, bsize, h, w): 37 | mask = torch.ones(bsize, h, w).bool().cuda() 38 | assert mask is not None 39 | y_embed = mask.cumsum(1, dtype=torch.float32) 40 | x_embed = mask.cumsum(2, dtype=torch.float32) 41 | if self.normalize: 42 | eps = 1e-6 43 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 44 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 45 | 46 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32).cuda() 47 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 48 | 49 | pos_x = x_embed[:, :, :, None] / dim_t 50 | 51 | pos_y = y_embed[:, :, :, None] / dim_t 52 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 53 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 54 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 55 | return pos 56 | 57 | 58 | 59 | class EncoderLayer(nn.Module): 60 | def __init__(self, d_model, d_ffn, dropout, activation, 61 | n_levels, n_heads, n_points): 62 | 63 | super().__init__() 64 | # self attention 65 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 66 | self.dropout1 = nn.Dropout(dropout) 67 | self.norm1 = nn.LayerNorm(d_model) 68 | 69 | # ffn 70 | self.linear1 = nn.Linear(d_model, d_ffn) 71 | self.activation = _get_activation_fn(activation) 72 | self.dropout2 = nn.Dropout(dropout) 73 | self.linear2 = nn.Linear(d_ffn, d_model) 74 | self.dropout3 = nn.Dropout(dropout) 75 | self.norm2 = nn.LayerNorm(d_model) 76 | 77 | @staticmethod 78 | def with_pos_embed(tensor, pos): 79 | return tensor if pos is None else tensor + pos 80 | 81 | def forward_ffn(self, src): 82 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 83 | src = src + self.dropout3(src2) 84 | src = self.norm2(src) 85 | return src 86 | 87 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index): 88 | # self attention 89 | src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index) 90 | src = src + self.dropout1(src2) 91 | src = self.norm1(src) 92 | 93 | # ffn 94 | src = self.forward_ffn(src) 95 | 96 | return src 97 | 98 | class Encoder(nn.Module): 99 | def __init__(self, encoder_layer, num_layers): 100 | super().__init__() 101 | self.layers = _get_clones(encoder_layer, num_layers) 102 | self.num_layers = num_layers 103 | 104 | @staticmethod 105 | def get_reference_points(spatial_shapes, device): 106 | 
107 | reference_points_list = [] 108 | for lvl, (H_, W_) in enumerate(spatial_shapes): 109 | 110 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 111 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 112 | ref_y = ref_y.reshape(-1)[None] / H_ 113 | ref_x = ref_x.reshape(-1)[None] / W_ 114 | ref = torch.stack((ref_x, ref_y), -1) 115 | reference_points_list.append(ref) 116 | 117 | reference_points = torch.cat(reference_points_list, 1) 118 | 119 | reference_points = reference_points[:, :, None].repeat(1, 1, len(spatial_shapes), 1) 120 | 121 | return reference_points 122 | 123 | def forward(self, src, spatial_shapes, level_start_index, pos): 124 | output = src 125 | 126 | reference_points = self.get_reference_points(spatial_shapes, device=src.device) 127 | 128 | for _, layer in enumerate(self.layers): 129 | output = layer(output, pos, reference_points, spatial_shapes, level_start_index) 130 | 131 | return output 132 | 133 | class DecoderLayer(nn.Module): 134 | def __init__(self, d_model, d_ffn, dropout, activation, n_levels, n_heads, n_points): 135 | 136 | super().__init__() 137 | 138 | # cross attention 139 | self.cross_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 140 | self.dropout1 = nn.Dropout(dropout) 141 | self.norm1 = nn.LayerNorm(d_model) 142 | 143 | # self attention 144 | self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) 145 | self.dropout2 = nn.Dropout(dropout) 146 | self.norm2 = nn.LayerNorm(d_model) 147 | 148 | # ffn 149 | self.linear1 = nn.Linear(d_model, d_ffn) 150 | self.activation = _get_activation_fn(activation) 151 | self.dropout3 = nn.Dropout(dropout) 152 | self.linear2 = nn.Linear(d_ffn, d_model) 153 | self.dropout4 = nn.Dropout(dropout) 154 | self.norm3 = nn.LayerNorm(d_model) 155 | 156 | @staticmethod 157 | def with_pos_embed(tensor, pos): 158 | return tensor if pos is None else tensor + pos 159 | 160 | @staticmethod 161 | def get_reference_points(spatial_shapes, device, h, w): 162 | 163 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, h - 0.5, h, dtype=torch.float32, device=device), 164 | torch.linspace(0.5, w - 0.5, w, dtype=torch.float32, device=device)) 165 | ref_y = ref_y.reshape(-1)[None] / h 166 | ref_x = ref_x.reshape(-1)[None] / w 167 | ref = torch.stack((ref_x, ref_y), -1) 168 | 169 | reference_points = ref[:, :, None].repeat(1, 1, len(spatial_shapes), 1) 170 | 171 | return reference_points 172 | 173 | 174 | def forward_ffn(self, tgt): 175 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) 176 | tgt = tgt + self.dropout4(tgt2) 177 | tgt = self.norm3(tgt) 178 | return tgt 179 | 180 | def forward(self, tgt, query_pos, src, src_spatial_shapes, level_start_index, h, w): 181 | 182 | # self attention 183 | q = k = self.with_pos_embed(tgt, query_pos) 184 | tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) 185 | tgt = tgt + self.dropout2(tgt2) 186 | tgt = self.norm2(tgt) 187 | 188 | reference_points = self.get_reference_points(src_spatial_shapes, device=src.device, h=h, w=w) 189 | # cross attention 190 | tgt = tgt.permute(1, 0, 2) 191 | tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), reference_points, src, src_spatial_shapes, level_start_index) 192 | tgt = tgt + self.dropout1(tgt2) 193 | tgt = self.norm1(tgt) 194 | 195 | # ffn 196 | tgt = self.forward_ffn(tgt) 197 | 198 | return tgt.permute(1,0,2) 199 | 200 | class Decoder(nn.Module): 201 | 202 | def __init__(self, 
decoder_layer, num_layers): 203 | super().__init__() 204 | self.layers = _get_clones(decoder_layer, num_layers) 205 | self.num_layers = num_layers 206 | 207 | def forward(self, tgt, src, src_spatial_shapes, src_level_start_index, h, w, query_pos=None): 208 | output = tgt 209 | 210 | for lid, layer in enumerate(self.layers): 211 | output = layer(output, query_pos, src, src_spatial_shapes, src_level_start_index, h, w) 212 | 213 | return output 214 | 215 | 216 | 217 | def _get_clones(module, N): 218 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 219 | 220 | 221 | def _get_activation_fn(activation): 222 | """Return an activation function given a string""" 223 | if activation == "relu": 224 | return F.relu 225 | if activation == "gelu": 226 | return F.gelu 227 | if activation == "glu": 228 | return F.glu 229 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 230 | 231 | 232 | class context_branch(nn.Module): 233 | def __init__(self, d_model, nhead, 234 | num_encoder_layers, num_decoder_layers, dim_feedforward, dropout, 235 | activation, num_feature_levels, dec_n_points, enc_n_points): 236 | 237 | super().__init__() 238 | self.d_model = d_model 239 | self.nhead = nhead 240 | self.num_feature_levels = num_feature_levels 241 | 242 | encoder_layer = EncoderLayer(d_model, dim_feedforward, 243 | dropout, activation, 244 | num_feature_levels, nhead, enc_n_points) 245 | self.encoder = Encoder(encoder_layer, num_encoder_layers) 246 | 247 | decoder_layer = DecoderLayer(d_model, dim_feedforward, dropout, activation, num_feature_levels, nhead, dec_n_points) 248 | 249 | self.decoder = Decoder(decoder_layer, num_decoder_layers) 250 | 251 | self.pos_embed = PositionEmbeddingSine(d_model//2, normalize=True) 252 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) 253 | self._reset_parameters() 254 | 255 | def _reset_parameters(self): 256 | for p in self.parameters(): 257 | if p.dim() > 1: 258 | nn.init.xavier_uniform_(p) 259 | for m in self.modules(): 260 | if isinstance(m, MSDeformAttn): 261 | m._reset_parameters() 262 | 263 | normal_(self.level_embed) 264 | 265 | def forward(self, ms_feats, context, query_embed, q_H, q_W): 266 | 267 | src_flatten = [] 268 | spatial_shapes = [] 269 | lvl_pos_embed_flatten = [] 270 | 271 | for lvl, src in enumerate(ms_feats): 272 | bs, c, h, w = src.shape 273 | spatial_shape = (h, w) 274 | spatial_shapes.append(spatial_shape) 275 | src = src.flatten(2).transpose(1, 2) 276 | pos_embed = self.pos_embed(bs, h, w).flatten(2).transpose(1, 2) 277 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 278 | lvl_pos_embed_flatten.append(lvl_pos_embed) 279 | src_flatten.append(src) 280 | 281 | src_flatten = torch.cat(src_flatten, 1) 282 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 283 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) 284 | level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) 285 | 286 | memory = self.encoder(src_flatten, spatial_shapes, level_start_index, lvl_pos_embed_flatten) 287 | 288 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 289 | context = context + query_embed 290 | out = self.decoder(context, memory, spatial_shapes, level_start_index, q_H, q_W) 291 | 292 | return out -------------------------------------------------------------------------------- /models/deformable_attn.py: -------------------------------------------------------------------------------- 1 | 
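# --- Editor's note: illustrative sketch, not part of the original file ---
# `context_branch.forward` above flattens every feature level to (B, H*W, C)
# and records its (H, W); `level_start_index` then marks where each level
# starts in the concatenated token sequence.  A self-contained example with
# made-up level sizes:
import torch
spatial_shapes = torch.as_tensor([(32, 32), (16, 16), (8, 8)], dtype=torch.long)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
# spatial_shapes.prod(1) -> tensor([1024,  256,   64])
# level_start_index      -> tensor([   0, 1024, 1280]); 1344 tokens in total,
# which is the sequence length MSDeformAttn (below) expects for `input_flatten`.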
# ------------------------------------------------------------------------ 2 | # Modified from Deformable-DETR (https://github.com/fundamentalvision/Deformable-DETR) 3 | # Apache-2.0 License 4 | # Copyright (c) SenseTime, Inc. and its affiliates. 5 | # ------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import print_function 9 | from __future__ import division 10 | 11 | import warnings 12 | import math 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from torch.nn.init import xavier_uniform_, constant_ 18 | 19 | from .ops.functions import MSDeformAttnFunction 20 | 21 | 22 | def _is_power_of_2(n): 23 | if (not isinstance(n, int)) or (n < 0): 24 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 25 | return (n & (n-1) == 0) and n != 0 26 | 27 | 28 | class MSDeformAttn(nn.Module): 29 | def __init__(self, d_model, n_levels, n_heads, n_points): 30 | """ 31 | Multi-Scale Deformable Attention Module 32 | :param d_model hidden dimension 33 | :param n_levels number of feature levels 34 | :param n_heads number of attention heads 35 | :param n_points number of sampling points per attention head per feature level 36 | """ 37 | super().__init__() 38 | if d_model % n_heads != 0: 39 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 40 | _d_per_head = d_model // n_heads 41 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 42 | if not _is_power_of_2(_d_per_head): 43 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 44 | "which is more efficient in our CUDA implementation.") 45 | 46 | self.im2col_step = 64 47 | 48 | self.d_model = d_model 49 | self.n_levels = n_levels 50 | self.n_heads = n_heads 51 | self.n_points = n_points 52 | 53 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 54 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 55 | self.value_proj = nn.Linear(d_model, d_model) 56 | self.output_proj = nn.Linear(d_model, d_model) 57 | 58 | self._reset_parameters() 59 | 60 | def _reset_parameters(self): 61 | constant_(self.sampling_offsets.weight.data, 0.) 62 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 63 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 64 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 65 | for i in range(self.n_points): 66 | grid_init[:, :, i, :] *= i + 1 67 | with torch.no_grad(): 68 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 69 | constant_(self.attention_weights.weight.data, 0.) 70 | constant_(self.attention_weights.bias.data, 0.) 71 | xavier_uniform_(self.value_proj.weight.data) 72 | constant_(self.value_proj.bias.data, 0.) 73 | xavier_uniform_(self.output_proj.weight.data) 74 | constant_(self.output_proj.bias.data, 0.) 
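# --- Editor's note: illustrative sketch, not part of the original file ---
# The bias initialisation in `_reset_parameters` above gives every attention
# head a distinct base direction (normalised so its largest coordinate is 1)
# and pushes the k-th sampling point k+1 steps along it.  For n_heads = 8 the
# per-head base directions look like this:
import math
import torch
n_heads = 8
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]
# grid_init[0] ~ (1, 0), grid_init[1] ~ (1, 1), grid_init[2] ~ (0, 1), ...
# so, before any training, the 8 heads sample along 8 evenly spread directions
# around each reference point.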
75 | 76 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index): 77 | """ 78 | :param query (N, Length_{query}, C) 79 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1) 80 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 81 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 82 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 83 | 84 | :return output (N, Length_{query}, C) 85 | """ 86 | N, Len_q, _ = query.shape 87 | N, Len_in, _ = input_flatten.shape 88 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 89 | 90 | value = self.value_proj(input_flatten) 91 | 92 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 93 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 94 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 95 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 96 | 97 | if reference_points.shape[-1] == 2: 98 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 99 | 100 | sampling_locations = reference_points[:, :, None, :, None, :] \ 101 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 102 | else: 103 | raise ValueError( 104 | 'Last dim of reference_points must be 2 , but get {} instead.'.format(reference_points.shape[-1])) 105 | 106 | output = MSDeformAttnFunction.apply( 107 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 108 | output = self.output_proj(output) 109 | return output 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
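# --- Editor's note: illustrative sketch, not part of the original file ---
# Shape check for the sampling-location broadcast in `MSDeformAttn.forward`
# above, with made-up sizes (batch of 2, Len_q=100, 8 heads, 3 levels,
# 4 points).  `offset_normalizer` holds (W_l, H_l) per level, so the learned
# pixel offsets are converted to the same [0, 1] space as the reference points.
import torch
N, Len_q, n_heads, n_levels, n_points = 2, 100, 8, 3, 4
reference_points = torch.rand(N, Len_q, n_levels, 2)
sampling_offsets = torch.randn(N, Len_q, n_heads, n_levels, n_points, 2)
offset_normalizer = torch.tensor([[32., 32.], [16., 16.], [8., 8.]])
sampling_locations = reference_points[:, :, None, :, None, :] \
    + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
assert sampling_locations.shape == (N, Len_q, n_heads, n_levels, n_points, 2)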
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | 45 | N_, S_, M_, D_ = value.shape 46 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 47 | 48 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 49 | sampling_grids = 2 * sampling_locations - 1 # N * Lq * M * L * P * 2 50 | sampling_value_list = [] 51 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 52 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 53 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 54 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 55 | sampling_grid_l_ = sampling_grids[:, :, :, lid_, :, :].transpose(1, 2).flatten(0, 1) 56 | # N_*M_, D_, Lq_, P_ 57 | 58 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 59 | mode='bilinear', padding_mode='zeros', align_corners=False) 60 | sampling_value_list.append(sampling_value_l_) 61 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 62 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 63 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 64 | return output.transpose(1, 2).contiguous() 65 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. 
All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | self._reset_parameters() 61 | 62 | def _reset_parameters(self): 63 | constant_(self.sampling_offsets.weight.data, 0.) 
64 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 65 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 66 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 67 | for i in range(self.n_points): 68 | grid_init[:, :, i, :] *= i + 1 69 | with torch.no_grad(): 70 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 71 | constant_(self.attention_weights.weight.data, 0.) 72 | constant_(self.attention_weights.bias.data, 0.) 73 | xavier_uniform_(self.value_proj.weight.data) 74 | constant_(self.value_proj.bias.data, 0.) 75 | xavier_uniform_(self.output_proj.weight.data) 76 | constant_(self.output_proj.bias.data, 0.) 77 | 78 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 79 | """ 80 | :param query (N, Length_{query}, C) 81 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 82 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 83 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 84 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 85 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 86 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 87 | 88 | :return output (N, Length_{query}, C) 89 | """ 90 | N, Len_q, _ = query.shape 91 | N, Len_in, _ = input_flatten.shape 92 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 93 | 94 | value = self.value_proj(input_flatten) 95 | if input_padding_mask is not None: 96 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 97 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 98 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 101 | # N, Len_q, n_heads, n_levels, n_points, 2 102 | if reference_points.shape[-1] == 2: 103 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 104 | sampling_locations = reference_points[:, :, None, :, None, :] \ 105 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 106 | elif reference_points.shape[-1] == 4: 107 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 108 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 109 | else: 110 | raise ValueError( 111 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 112 | # print('value shape: {}', value.shape) 113 | # print('input_spatial_shapes: {}', input_spatial_shapes) 114 | # print('input_level_start_index: {}', input_level_start_index) 115 | # print('sampling_locations shape: {}', sampling_locations.shape) 116 | # print('attention_weights: {}', attention_weights.shape) 117 | 118 | output = MSDeformAttnFunction.apply( 119 | value, input_spatial_shapes, input_level_start_index, 
sampling_locations, attention_weights, self.im2col_step) 120 | output = self.output_proj(output) 121 | return output 122 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | 
sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n 
* im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /models/spatial_branch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class ConvBlock(nn.Module): 5 | def __init__(self, in_channels, out_channels, activation=True, kernel_size=3, stride=2,padding=1): 6 | super().__init__() 7 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False) 8 | self.bn = nn.SyncBatchNorm(out_channels) 9 | self.activation = activation 10 | if activation: 11 | self.relu = nn.ReLU() 12 | 13 | def forward(self, input): 14 | x = self.conv1(input) 15 | if self.activation: 16 | return self.relu(self.bn(x)) 17 | else: 18 | return self.bn(x) 19 | 20 | 21 | class spatial_branch(nn.Module): 22 | def __init__(self): 23 | super().__init__() 24 | self.convblock1 = ConvBlock(in_channels=3, out_channels=64) 25 | self.convblock2 = ConvBlock(in_channels=64, out_channels=128) 26 | self.convblock3 = ConvBlock(in_channels=128, out_channels=256, activation=False) 27 | 28 | def forward(self, input): 29 | x = self.convblock1(input) 30 | x = self.convblock2(x) 31 | x = self.convblock3(x) 32 | return x 33 | -------------------------------------------------------------------------------- /models/vision_transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from pytorch-image-models (https://github.com/rwightman/pytorch-image-models) 3 | # ------------------------------------------------------------------------------------------------ 4 | 5 | """ Vision Transformer (ViT) in PyTorch 6 | A PyTorch implement of Vision Transformers as described in 7 | 'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929 8 | The official jax code is released and available at https://github.com/google-research/vision_transformer 9 | Acknowledgments: 10 | * The paper authors for releasing code and weights, thanks! 11 | * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out 12 | for some einops/einsum fun 13 | * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT 14 | * Bert reference code checks against Huggingface Transformers and Tensorflow Bert 15 | DeiT model defs and weights from https://github.com/facebookresearch/deit, 16 | paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 17 | Hacked together by / Copyright 2020 Ross Wightman 18 | """ 19 | 20 | import math 21 | import logging 22 | from functools import partial 23 | from collections import OrderedDict 24 | 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 30 | from timm.models.helpers import load_pretrained 31 | from timm.models.layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_ 32 | 33 | from timm.models import resnet26d, resnet50d, register_model 34 | from timm.models.resnetv2 import ResNetV2 35 | 36 | 37 | _logger = logging.getLogger(__name__) 38 | 39 | 40 | def _cfg(url='', **kwargs): 41 | return { 42 | 'url': url, 43 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 44 | 'crop_pct': .9, 'interpolation': 'bicubic', 45 | 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 46 | 'first_conv': 'patch_embed.proj', 'classifier': 'head', 47 | **kwargs 48 | } 49 | 50 | 51 | default_cfgs = { 52 | # patch models (my experiments) 53 | 'vit_small_patch16_224': _cfg( 54 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', 55 | ), 56 | 57 | # patch models (weights ported from official Google JAX impl) 58 | 'vit_base_patch16_224': _cfg( 59 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', 60 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), 61 | ), 62 | 'vit_base_patch32_224': _cfg( 63 | url='', # no official model weights for this combo, only for in21k 64 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 65 | 'vit_base_patch16_384': _cfg( 66 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', 67 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 68 | 'vit_base_patch32_384': _cfg( 69 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', 70 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 71 | 'vit_large_patch16_224': _cfg( 72 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', 73 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 74 | 'vit_large_patch32_224': _cfg( 75 | url='', # no official model weights for this combo, only for in21k 76 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 77 | 'vit_large_patch16_384': _cfg( 78 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', 79 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 80 | 'vit_large_patch32_384': _cfg( 81 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth', 82 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), 83 | 84 | # patch models, imagenet21k (weights ported from official Google JAX impl) 85 | 
'vit_base_patch16_224_in21k': _cfg( 86 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch16_224_in21k-e5005f0a.pth', 87 | num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 88 | 'vit_base_patch32_224_in21k': _cfg( 89 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_patch32_224_in21k-8db57226.pth', 90 | num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 91 | 'vit_large_patch16_224_in21k': _cfg( 92 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch16_224_in21k-606da67d.pth', 93 | num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 94 | 'vit_large_patch32_224_in21k': _cfg( 95 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', 96 | num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 97 | 'vit_huge_patch14_224_in21k': _cfg( 98 | url='', # FIXME I have weights for this but > 2GB limit for github release binaries 99 | num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 100 | 101 | # hybrid models (weights ported from official Google JAX impl) 102 | 'vit_base_resnet50_224_in21k': _cfg( 103 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', 104 | num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), 105 | 'vit_base_resnet50_384': _cfg( 106 | url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', 107 | input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), 108 | 109 | # hybrid models (my experiments) 110 | 'vit_small_resnet26d_224': _cfg(), 111 | 'vit_small_resnet50d_s3_224': _cfg(), 112 | 'vit_base_resnet26d_224': _cfg(), 113 | 'vit_base_resnet50d_224': _cfg(), 114 | 115 | # deit models (FB weights) 116 | 'vit_deit_tiny_patch16_224': _cfg( 117 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), 118 | 'vit_deit_small_patch16_224': _cfg( 119 | url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), 120 | 'vit_deit_base_patch16_224': _cfg( 121 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), 122 | 'vit_deit_base_patch16_384': _cfg( 123 | url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', 124 | input_size=(3, 384, 384), crop_pct=1.0), 125 | 'vit_deit_tiny_distilled_patch16_224': _cfg( 126 | url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'), 127 | 'vit_deit_small_distilled_patch16_224': _cfg( 128 | url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'), 129 | 'vit_deit_base_distilled_patch16_224': _cfg( 130 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ), 131 | 'vit_deit_base_distilled_patch16_384': _cfg( 132 | url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', 133 | input_size=(3, 384, 384), crop_pct=1.0), 134 | } 135 | 136 | 137 | class Mlp(nn.Module): 138 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 139 | super().__init__() 140 | out_features = out_features or in_features 141 | hidden_features = 
hidden_features or in_features 142 | self.fc1 = nn.Linear(in_features, hidden_features) 143 | self.act = act_layer() 144 | self.fc2 = nn.Linear(hidden_features, out_features) 145 | self.drop = nn.Dropout(drop) 146 | 147 | def forward(self, x): 148 | x = self.fc1(x) 149 | x = self.act(x) 150 | x = self.drop(x) 151 | x = self.fc2(x) 152 | x = self.drop(x) 153 | return x 154 | 155 | 156 | class Attention(nn.Module): 157 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 158 | super().__init__() 159 | self.num_heads = num_heads 160 | head_dim = dim // num_heads 161 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 162 | self.scale = qk_scale or head_dim ** -0.5 163 | 164 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 165 | self.attn_drop = nn.Dropout(attn_drop) 166 | self.proj = nn.Linear(dim, dim) 167 | self.proj_drop = nn.Dropout(proj_drop) 168 | 169 | def forward(self, x): 170 | B, N, C = x.shape 171 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 172 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 173 | 174 | attn = (q @ k.transpose(-2, -1)) * self.scale 175 | attn = attn.softmax(dim=-1) 176 | attn = self.attn_drop(attn) 177 | 178 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 179 | x = self.proj(x) 180 | x = self.proj_drop(x) 181 | return x 182 | 183 | 184 | class Block(nn.Module): 185 | 186 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 187 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 188 | super().__init__() 189 | self.norm1 = norm_layer(dim) 190 | self.attn = Attention( 191 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 192 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 193 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 194 | self.norm2 = norm_layer(dim) 195 | mlp_hidden_dim = int(dim * mlp_ratio) 196 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 197 | 198 | def forward(self, x): 199 | x = x + self.drop_path(self.attn(self.norm1(x))) 200 | x = x + self.drop_path(self.mlp(self.norm2(x))) 201 | return x 202 | 203 | 204 | class PatchEmbed(nn.Module): 205 | """ Image to Patch Embedding 206 | """ 207 | def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): 208 | super().__init__() 209 | img_size = to_2tuple(img_size) 210 | patch_size = to_2tuple(patch_size) 211 | num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) 212 | self.img_size = img_size 213 | self.patch_size = patch_size 214 | self.num_patches = num_patches 215 | 216 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 217 | 218 | def forward(self, x): 219 | B, C, H, W = x.shape 220 | # FIXME look at relaxing size constraints 221 | assert H == self.img_size[0] and W == self.img_size[1], \ 222 | f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 223 | x = self.proj(x).flatten(2).transpose(1, 2) 224 | return x 225 | 226 | 227 | class HybridEmbed(nn.Module): 228 | """ CNN Feature Map Embedding 229 | Extract feature map from CNN, flatten, project to embedding dim. 
230 | """ 231 | def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): 232 | super().__init__() 233 | assert isinstance(backbone, nn.Module) 234 | img_size = to_2tuple(img_size) 235 | self.img_size = img_size 236 | self.backbone = backbone 237 | if feature_size is None: 238 | with torch.no_grad(): 239 | # FIXME this is hacky, but most reliable way of determining the exact dim of the output feature 240 | # map for all networks, the feature metadata has reliable channel and stride info, but using 241 | # stride to calc feature dim requires info about padding of each stage that isn't captured. 242 | training = backbone.training 243 | if training: 244 | backbone.eval() 245 | o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) 246 | if isinstance(o, (list, tuple)): 247 | o = o[-1] # last feature if backbone outputs list/tuple of features 248 | feature_size = o.shape[-2:] 249 | feature_dim = o.shape[1] 250 | backbone.train(training) 251 | else: 252 | feature_size = to_2tuple(feature_size) 253 | if hasattr(self.backbone, 'feature_info'): 254 | feature_dim = self.backbone.feature_info.channels()[-1] 255 | else: 256 | feature_dim = self.backbone.num_features 257 | self.num_patches = feature_size[0] * feature_size[1] 258 | self.proj = nn.Conv2d(feature_dim, embed_dim, 1) 259 | 260 | def forward(self, x): 261 | x = self.backbone(x) 262 | if isinstance(x, (list, tuple)): 263 | x = x[-1] # last feature if backbone outputs list/tuple of features 264 | x = self.proj(x).flatten(2).transpose(1, 2) 265 | return x 266 | 267 | 268 | class VisionTransformer(nn.Module): 269 | """ Vision Transformer 270 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 271 | https://arxiv.org/abs/2010.11929 272 | """ 273 | def __init__(self, img_size=224, patch_size=16, seq_len=30, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 274 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 275 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): 276 | """ 277 | Args: 278 | img_size (int, tuple): input image size 279 | patch_size (int, tuple): patch size 280 | in_chans (int): number of input channels 281 | num_classes (int): number of classes for classification head 282 | embed_dim (int): embedding dimension 283 | depth (int): depth of transformer 284 | num_heads (int): number of attention heads 285 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 286 | qkv_bias (bool): enable bias for qkv if True 287 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 288 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 289 | drop_rate (float): dropout rate 290 | attn_drop_rate (float): attention dropout rate 291 | drop_path_rate (float): stochastic depth rate 292 | hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module 293 | norm_layer: (nn.Module): normalization layer 294 | """ 295 | super().__init__() 296 | self.num_classes = num_classes 297 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 298 | self.seq_len = seq_len 299 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 300 | 301 | if hybrid_backbone is not None: 302 | self.patch_embed = HybridEmbed( 303 | hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) 304 | else: 305 | self.patch_embed = PatchEmbed( 306 | 
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 307 | num_patches = self.patch_embed.num_patches 308 | 309 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 310 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) 311 | self.pos_drop = nn.Dropout(p=drop_rate) 312 | 313 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 314 | self.blocks = nn.ModuleList([ 315 | Block( 316 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 317 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 318 | for i in range(depth)]) 319 | 320 | trunc_normal_(self.pos_embed, std=.02) 321 | trunc_normal_(self.cls_token, std=.02) 322 | self.apply(self._init_weights) 323 | 324 | def _init_weights(self, m): 325 | if isinstance(m, nn.Linear): 326 | trunc_normal_(m.weight, std=.02) 327 | if isinstance(m, nn.Linear) and m.bias is not None: 328 | nn.init.constant_(m.bias, 0) 329 | elif isinstance(m, nn.LayerNorm): 330 | nn.init.constant_(m.bias, 0) 331 | nn.init.constant_(m.weight, 1.0) 332 | 333 | @torch.jit.ignore 334 | def no_weight_decay(self): 335 | return {'pos_embed', 'cls_token'} 336 | 337 | def get_classifier(self): 338 | return self.head 339 | 340 | def reset_classifier(self, num_classes, global_pool=''): 341 | self.num_classes = num_classes 342 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 343 | 344 | def forward_features(self, x): 345 | B = x.shape[0] 346 | x = self.patch_embed(x) 347 | 348 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 349 | x = torch.cat((cls_tokens, x), dim=1) 350 | x = x + self.pos_embed 351 | x = self.pos_drop(x) 352 | 353 | feats = [] 354 | cnt = 1 355 | for blk in self.blocks: 356 | x = blk(x) 357 | if cnt%4 == 0: 358 | feats.append(x[:,1:,]) 359 | cnt += 1 360 | 361 | return feats 362 | 363 | def forward(self, x): 364 | x = self.forward_features(x) 365 | x = self.head(x) 366 | return x 367 | 368 | class DistilledVisionTransformer(VisionTransformer): 369 | """ Vision Transformer with distillation token. 
370 | Paper: `Training data-efficient image transformers & distillation through attention` - 371 | https://arxiv.org/abs/2012.12877 372 | This impl of distilled ViT is taken from https://github.com/facebookresearch/deit 373 | """ 374 | def __init__(self, *args, **kwargs): 375 | super().__init__(*args, **kwargs) 376 | self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) 377 | num_patches = self.patch_embed.num_patches 378 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) 379 | 380 | trunc_normal_(self.pos_embed, std=.02) 381 | 382 | def forward_features(self, x): 383 | # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py 384 | # with slight modifications to add the dist_token 385 | B = x.shape[0] 386 | x = self.patch_embed(x) 387 | 388 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 389 | dist_token = self.dist_token.expand(B, -1, -1) 390 | x = torch.cat((cls_tokens, dist_token, x), dim=1) 391 | x = x + self.pos_embed 392 | x = self.pos_drop(x) 393 | 394 | feats = [] 395 | cnt = 1 396 | for blk in self.blocks: 397 | 398 | x = blk(x) 399 | if cnt%4 == 0: 400 | feat = x[:,2:,] 401 | feat = feat.permute(0,2,1).unsqueeze(2).reshape([x[:,2:,].shape[0], self.embed_dim, self.seq_len, self.seq_len]) 402 | feats.append(feat) 403 | cnt += 1 404 | 405 | return feats 406 | 407 | def forward(self, x): 408 | 409 | feats = self.forward_features(x) 410 | return feats 411 | 412 | 413 | 414 | def resize_pos_embed(posemb, posemb_new): 415 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from 416 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 417 | _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) 418 | ntok_new = posemb_new.shape[1] 419 | if True: 420 | posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] 421 | ntok_new -= 1 422 | else: 423 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 424 | gs_old = int(math.sqrt(len(posemb_grid))) 425 | gs_new = int(math.sqrt(ntok_new)) 426 | _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) 427 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 428 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') 429 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) 430 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 431 | return posemb 432 | 433 | 434 | def checkpoint_filter_fn(state_dict, model): 435 | """ convert patch embedding weight from manual patchify + linear proj to conv""" 436 | out_dict = {} 437 | if 'model' in state_dict: 438 | # For deit models 439 | state_dict = state_dict['model'] 440 | for k, v in state_dict.items(): 441 | if 'patch_embed.proj.weight' in k and len(v.shape) < 4: 442 | # For old models that I trained prior to conv based patchification 443 | O, I, H, W = model.patch_embed.proj.weight.shape 444 | v = v.reshape(O, -1, H, W) 445 | elif k == 'pos_embed' and v.shape != model.pos_embed.shape: 446 | # To resize pos embedding when using model at different size from pretrained weights 447 | v = resize_pos_embed(v, model.pos_embed) 448 | out_dict[k] = v 449 | return out_dict 450 | 451 | 452 | def deit_base_distilled_patch16_384( 453 | img_size, 454 | patch_size, 455 | embed_dim, 456 | depth, 457 | num_heads, 458 | mlp_ratio, 459 | 
pretrained=None, **kwargs): 460 | seq_len = img_size//patch_size 461 | model = DistilledVisionTransformer( 462 | img_size=img_size, patch_size=patch_size, seq_len=seq_len, embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, 463 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 464 | model.default_cfg = _cfg() 465 | if pretrained is not None: 466 | checkpoint = torch.load(pretrained, map_location='cpu') 467 | # checkpoint = torch.hub.load_state_dict_from_url( 468 | # url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth", 469 | # map_location="cpu", check_hash=True 470 | # ) 471 | 472 | checkpoint_model = checkpoint['model'] 473 | state_dict = model.state_dict() 474 | 475 | # interpolate position embedding 476 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 477 | embedding_size = pos_embed_checkpoint.shape[-1] 478 | num_patches = model.patch_embed.num_patches 479 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 480 | # height (== width) for the checkpoint position embedding 481 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 482 | # height (== width) for the new position embedding 483 | new_size = int(num_patches ** 0.5) 484 | # class_token and dist_token are kept unchanged 485 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 486 | # only the position tokens are interpolated 487 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 488 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 489 | pos_tokens = torch.nn.functional.interpolate( 490 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 491 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 492 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 493 | checkpoint_model['pos_embed'] = new_pos_embed 494 | model.load_state_dict(checkpoint["model"], strict=False) 495 | return model -------------------------------------------------------------------------------- /modified_mmseg/__init__.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from mmseg.version import __version__, version_info 4 | 5 | MMCV_MIN = '1.1.4' 6 | MMCV_MAX = '1.3.0' 7 | 8 | 9 | def digit_version(version_str): 10 | digit_version = [] 11 | for x in version_str.split('.'): 12 | if x.isdigit(): 13 | digit_version.append(int(x)) 14 | elif x.find('rc') != -1: 15 | patch_version = x.split('rc') 16 | digit_version.append(int(patch_version[0]) - 1) 17 | digit_version.append(int(patch_version[1])) 18 | return digit_version 19 | 20 | 21 | mmcv_min_version = digit_version(MMCV_MIN) 22 | mmcv_max_version = digit_version(MMCV_MAX) 23 | mmcv_version = digit_version(mmcv.__version__) 24 | 25 | 26 | assert (mmcv_min_version <= mmcv_version <= mmcv_max_version), \ 27 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 28 | f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.' 
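# For reference, digit_version() maps version strings as follows (illustrative
# values only, traced from the function above; they are not used by the check):
#   digit_version('1.1.4')    -> [1, 1, 4]
#   digit_version('1.3.0rc1') -> [1, 3, -1, 1]   # 'rc' pre-releases sort below the final release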
29 | 30 | __all__ = ['__version__', 'version_info'] -------------------------------------------------------------------------------- /modified_mmseg/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import get_root_logger, set_random_seed, train_segmentor 2 | 3 | __all__ = [ 4 | 'get_root_logger', 'set_random_seed', 'train_segmentor', 'init_segmentor', 5 | 'inference_segmentor', 'multi_gpu_test', 'single_gpu_test', 6 | 'show_result_pyplot' 7 | ] 8 | -------------------------------------------------------------------------------- /modified_mmseg/apis/train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 7 | from mmcv.runner import build_optimizer, build_runner 8 | 9 | from mmseg.core import DistEvalHook, EvalHook 10 | 11 | import sys 12 | sys.path.append("/home/ubuntu/work/unified-ept/modified_mmseg") 13 | from datasets import build_dataloader, build_dataset 14 | from mmseg.utils import get_root_logger 15 | 16 | 17 | def set_random_seed(seed, deterministic=False): 18 | """Set random seed. 19 | 20 | Args: 21 | seed (int): Seed to be used. 22 | deterministic (bool): Whether to set the deterministic option for 23 | CUDNN backend, i.e., set `torch.backends.cudnn.deterministic` 24 | to True and `torch.backends.cudnn.benchmark` to False. 25 | Default: False. 26 | """ 27 | random.seed(seed) 28 | np.random.seed(seed) 29 | torch.manual_seed(seed) 30 | torch.cuda.manual_seed_all(seed) 31 | if deterministic: 32 | torch.backends.cudnn.deterministic = True 33 | torch.backends.cudnn.benchmark = False 34 | 35 | 36 | def train_segmentor(model, 37 | dataset, 38 | cfg, 39 | distributed=False, 40 | validate=False, 41 | timestamp=None, 42 | meta=None): 43 | """Launch segmentor training.""" 44 | logger = get_root_logger(cfg.log_level) 45 | 46 | # prepare data loaders 47 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 48 | data_loaders = [ 49 | build_dataloader( 50 | ds, 51 | cfg.data.samples_per_gpu, 52 | cfg.data.workers_per_gpu, 53 | # cfg.gpus will be ignored if distributed 54 | len(cfg.gpu_ids), 55 | dist=distributed, 56 | seed=cfg.seed, 57 | drop_last=True) for ds in dataset 58 | ] 59 | 60 | # put model on gpus 61 | if distributed: 62 | find_unused_parameters = cfg.get('find_unused_parameters', False) 63 | # Sets the `find_unused_parameters` parameter in 64 | # torch.nn.parallel.DistributedDataParallel 65 | model = MMDistributedDataParallel( 66 | model.cuda(), 67 | device_ids=[torch.cuda.current_device()], 68 | broadcast_buffers=False, 69 | find_unused_parameters=find_unused_parameters) 70 | else: 71 | model = MMDataParallel( 72 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) 73 | 74 | # build runner 75 | optimizer = build_optimizer(model, cfg.optimizer) 76 | 77 | if cfg.get('runner') is None: 78 | cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters} 79 | warnings.warn( 80 | 'config is now expected to have a `runner` section, ' 81 | 'please set `runner` in your config.', UserWarning) 82 | 83 | runner = build_runner( 84 | cfg.runner, 85 | default_args=dict( 86 | model=model, 87 | batch_processor=None, 88 | optimizer=optimizer, 89 | work_dir=cfg.work_dir, 90 | logger=logger, 91 | meta=meta)) 92 | 93 | # register hooks 94 | runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config, 95 | 
                                   cfg.checkpoint_config, cfg.log_config,
96 |                                    cfg.get('momentum_config', None))
97 | 
98 |     # an ugly workaround to make the .log and .log.json filenames the same
99 |     runner.timestamp = timestamp
100 | 
101 |     # register eval hooks
102 |     if validate:
103 |         val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
104 |         val_dataloader = build_dataloader(
105 |             val_dataset,
106 |             samples_per_gpu=1,
107 |             workers_per_gpu=cfg.data.workers_per_gpu,
108 |             dist=distributed,
109 |             shuffle=False)
110 |         eval_cfg = cfg.get('evaluation', {})
111 |         eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
112 |         eval_hook = DistEvalHook if distributed else EvalHook
113 |         runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
114 | 
115 |     if cfg.resume_from:
116 |         runner.resume(cfg.resume_from)
117 |     elif cfg.load_from:
118 |         runner.load_checkpoint(cfg.load_from)
119 |     runner.run(data_loaders, cfg.workflow)
120 | 
--------------------------------------------------------------------------------
/modified_mmseg/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .ade import ADE20KDataset
2 | from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
3 | from .cityscapes import CityscapesDataset
4 | from .custom import CustomDataset
5 | from .dataset_wrappers import ConcatDataset, RepeatDataset
6 | from .pascal_context import PascalContextDataset
7 | 
8 | 
9 | __all__ = [
10 |     'CustomDataset', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
11 |     'DATASETS', 'build_dataset', 'PIPELINES', 'CityscapesDataset',
12 |     'ADE20KDataset', 'PascalContextDataset'
13 | ]
14 | 
--------------------------------------------------------------------------------
/modified_mmseg/datasets/ade.py:
--------------------------------------------------------------------------------
1 | from .builder import DATASETS
2 | from .custom import CustomDataset
3 | 
4 | 
5 | @DATASETS.register_module()
6 | class ADE20KDataset(CustomDataset):
7 |     """ADE20K dataset.
8 |     In segmentation map annotation for ADE20K, 0 stands for background, which
9 |     is not included in 150 categories. ``reduce_zero_label`` is fixed to True.
10 |     The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is fixed to
11 |     '.png'.
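
    A minimal, illustrative config entry (``data_root``/``img_dir``/``ann_dir``
    follow the standard mmseg ADE20K layout and ``train_pipeline`` is assumed to
    be defined elsewhere; these values are placeholders, not taken from this
    repository's configs)::

        train=dict(
            type='ADE20KDataset',
            data_root='data/ade/ADEChallengeData2016',
            img_dir='images/training',
            ann_dir='annotations/training',
            pipeline=train_pipeline)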
12 | """ 13 | CLASSES = ( 14 | 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', 15 | 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', 16 | 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', 17 | 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 18 | 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 19 | 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', 20 | 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 21 | 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', 22 | 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 23 | 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 24 | 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', 25 | 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 26 | 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', 27 | 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', 28 | 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 29 | 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', 30 | 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 31 | 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', 32 | 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', 33 | 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 34 | 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 35 | 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', 36 | 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 37 | 'clock', 'flag') 38 | 39 | PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 40 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 41 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 42 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 43 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 44 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 45 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 46 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 47 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 48 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 49 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 50 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 51 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 52 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 53 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], 54 | [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], 55 | [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], 56 | [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], 57 | [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], 58 | [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], 59 | [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], 60 | [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], 61 | [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], 62 | [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], 63 | [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], 64 | [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], 65 | [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], 66 | [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], 67 | [92, 255, 0], [0, 
224, 255], [112, 224, 255], [70, 184, 160], 68 | [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], 69 | [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], 70 | [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], 71 | [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], 72 | [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], 73 | [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], 74 | [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], 75 | [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], 76 | [102, 255, 0], [92, 0, 255]] 77 | 78 | def __init__(self, **kwargs): 79 | super(ADE20KDataset, self).__init__( 80 | img_suffix='.jpg', 81 | seg_map_suffix='.png', 82 | reduce_zero_label=True, 83 | **kwargs) -------------------------------------------------------------------------------- /modified_mmseg/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import platform 3 | import random 4 | from functools import partial 5 | 6 | import numpy as np 7 | from mmcv.parallel import collate 8 | from mmcv.runner import get_dist_info 9 | from mmcv.utils import Registry, build_from_cfg 10 | from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader 11 | from torch.utils.data import DistributedSampler 12 | 13 | if platform.system() != 'Windows': 14 | # https://github.com/pytorch/pytorch/issues/973 15 | import resource 16 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 17 | hard_limit = rlimit[1] 18 | soft_limit = min(4096, hard_limit) 19 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 20 | 21 | DATASETS = Registry('dataset') 22 | PIPELINES = Registry('pipeline') 23 | 24 | 25 | def _concat_dataset(cfg, default_args=None): 26 | """Build :obj:`ConcatDataset by.""" 27 | from .dataset_wrappers import ConcatDataset 28 | img_dir = cfg['img_dir'] 29 | ann_dir = cfg.get('ann_dir', None) 30 | split = cfg.get('split', None) 31 | num_img_dir = len(img_dir) if isinstance(img_dir, (list, tuple)) else 1 32 | if ann_dir is not None: 33 | num_ann_dir = len(ann_dir) if isinstance(ann_dir, (list, tuple)) else 1 34 | else: 35 | num_ann_dir = 0 36 | if split is not None: 37 | num_split = len(split) if isinstance(split, (list, tuple)) else 1 38 | else: 39 | num_split = 0 40 | if num_img_dir > 1: 41 | assert num_img_dir == num_ann_dir or num_ann_dir == 0 42 | assert num_img_dir == num_split or num_split == 0 43 | else: 44 | assert num_split == num_ann_dir or num_ann_dir <= 1 45 | num_dset = max(num_split, num_img_dir) 46 | 47 | datasets = [] 48 | for i in range(num_dset): 49 | data_cfg = copy.deepcopy(cfg) 50 | if isinstance(img_dir, (list, tuple)): 51 | data_cfg['img_dir'] = img_dir[i] 52 | if isinstance(ann_dir, (list, tuple)): 53 | data_cfg['ann_dir'] = ann_dir[i] 54 | if isinstance(split, (list, tuple)): 55 | data_cfg['split'] = split[i] 56 | datasets.append(build_dataset(data_cfg, default_args)) 57 | 58 | return ConcatDataset(datasets) 59 | 60 | 61 | def build_dataset(cfg, default_args=None): 62 | """Build datasets.""" 63 | from .dataset_wrappers import ConcatDataset, RepeatDataset 64 | if isinstance(cfg, (list, tuple)): 65 | dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg]) 66 | elif cfg['type'] == 'RepeatDataset': 67 | dataset = RepeatDataset( 68 | build_dataset(cfg['dataset'], default_args), cfg['times']) 69 | elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance( 70 | cfg.get('split', None), (list, tuple)): 71 | dataset = 
_concat_dataset(cfg, default_args) 72 | else: 73 | dataset = build_from_cfg(cfg, DATASETS, default_args) 74 | 75 | return dataset 76 | 77 | 78 | def build_dataloader(dataset, 79 | samples_per_gpu, 80 | workers_per_gpu, 81 | num_gpus=1, 82 | dist=True, 83 | shuffle=True, 84 | seed=None, 85 | drop_last=False, 86 | pin_memory=True, 87 | dataloader_type='PoolDataLoader', 88 | **kwargs): 89 | """Build PyTorch DataLoader. 90 | 91 | In distributed training, each GPU/process has a dataloader. 92 | In non-distributed training, there is only one dataloader for all GPUs. 93 | 94 | Args: 95 | dataset (Dataset): A PyTorch dataset. 96 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 97 | batch size of each GPU. 98 | workers_per_gpu (int): How many subprocesses to use for data loading 99 | for each GPU. 100 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 101 | dist (bool): Distributed training/test or not. Default: True. 102 | shuffle (bool): Whether to shuffle the data at every epoch. 103 | Default: True. 104 | seed (int | None): Seed to be used. Default: None. 105 | drop_last (bool): Whether to drop the last incomplete batch in epoch. 106 | Default: False 107 | pin_memory (bool): Whether to use pin_memory in DataLoader. 108 | Default: True 109 | dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader' 110 | kwargs: any keyword argument to be used to initialize DataLoader 111 | 112 | Returns: 113 | DataLoader: A PyTorch dataloader. 114 | """ 115 | rank, world_size = get_dist_info() 116 | if dist: 117 | sampler = DistributedSampler( 118 | dataset, world_size, rank, shuffle=shuffle) 119 | shuffle = False 120 | batch_size = samples_per_gpu 121 | num_workers = workers_per_gpu 122 | else: 123 | sampler = None 124 | batch_size = num_gpus * samples_per_gpu 125 | num_workers = num_gpus * workers_per_gpu 126 | 127 | init_fn = partial( 128 | worker_init_fn, num_workers=num_workers, rank=rank, 129 | seed=seed) if seed is not None else None 130 | 131 | assert dataloader_type in ( 132 | 'DataLoader', 133 | 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}' 134 | 135 | if dataloader_type == 'PoolDataLoader': 136 | dataloader = PoolDataLoader 137 | elif dataloader_type == 'DataLoader': 138 | dataloader = DataLoader 139 | 140 | data_loader = dataloader( 141 | dataset, 142 | batch_size=batch_size, 143 | sampler=sampler, 144 | num_workers=num_workers, 145 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 146 | pin_memory=pin_memory, 147 | shuffle=shuffle, 148 | worker_init_fn=init_fn, 149 | drop_last=drop_last, 150 | **kwargs) 151 | 152 | return data_loader 153 | 154 | 155 | def worker_init_fn(worker_id, num_workers, rank, seed): 156 | """Worker init func for dataloader. 157 | 158 | The seed of each worker equals to num_worker * rank + worker_id + user_seed 159 | 160 | Args: 161 | worker_id (int): Worker id. 162 | num_workers (int): Number of workers. 163 | rank (int): The rank of current process. 164 | seed (int): The random seed to use. 
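
    For example (with illustrative numbers only), ``num_workers=4``, ``rank=1``,
    ``worker_id=2`` and ``seed=42`` give a worker seed of 4 * 1 + 2 + 42 = 48.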
165 | """ 166 | 167 | worker_seed = num_workers * rank + worker_id + seed 168 | np.random.seed(worker_seed) 169 | random.seed(worker_seed) 170 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import tempfile 3 | 4 | import mmcv 5 | import numpy as np 6 | from mmcv.utils import print_log 7 | from PIL import Image 8 | 9 | from .builder import DATASETS 10 | from .custom import CustomDataset 11 | 12 | 13 | @DATASETS.register_module() 14 | class CityscapesDataset(CustomDataset): 15 | """Cityscapes dataset. 16 | The ``img_suffix`` is fixed to '_leftImg8bit.png' and ``seg_map_suffix`` is 17 | fixed to '_gtFine_labelTrainIds.png' for Cityscapes dataset. 18 | """ 19 | 20 | CLASSES = ('road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 21 | 'traffic light', 'traffic sign', 'vegetation', 'terrain', 'sky', 22 | 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle', 23 | 'bicycle') 24 | 25 | PALETTE = [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], 26 | [190, 153, 153], [153, 153, 153], [250, 170, 30], [220, 220, 0], 27 | [107, 142, 35], [152, 251, 152], [70, 130, 180], [220, 20, 60], 28 | [255, 0, 0], [0, 0, 142], [0, 0, 70], [0, 60, 100], 29 | [0, 80, 100], [0, 0, 230], [119, 11, 32]] 30 | 31 | def __init__(self, **kwargs): 32 | super(CityscapesDataset, self).__init__( 33 | img_suffix='_leftImg8bit.png', 34 | seg_map_suffix='_gtFine_labelTrainIds.png', 35 | **kwargs) 36 | 37 | @staticmethod 38 | def _convert_to_label_id(result): 39 | """Convert trainId to id for cityscapes.""" 40 | if isinstance(result, str): 41 | result = np.load(result) 42 | import cityscapesscripts.helpers.labels as CSLabels 43 | result_copy = result.copy() 44 | for trainId, label in CSLabels.trainId2label.items(): 45 | result_copy[result == trainId] = label.id 46 | 47 | return result_copy 48 | 49 | def results2img(self, results, imgfile_prefix, to_label_id): 50 | """Write the segmentation results to images. 51 | Args: 52 | results (list[list | tuple | ndarray]): Testing results of the 53 | dataset. 54 | imgfile_prefix (str): The filename prefix of the png files. 55 | If the prefix is "somepath/xxx", 56 | the png files will be named "somepath/xxx.png". 57 | to_label_id (bool): whether convert output to label_id for 58 | submission 59 | Returns: 60 | list[str: str]: result txt files which contains corresponding 61 | semantic segmentation images. 
62 | """ 63 | mmcv.mkdir_or_exist(imgfile_prefix) 64 | result_files = [] 65 | prog_bar = mmcv.ProgressBar(len(self)) 66 | for idx in range(len(self)): 67 | result = results[idx] 68 | if to_label_id: 69 | result = self._convert_to_label_id(result) 70 | filename = self.img_infos[idx]['filename'] 71 | basename = osp.splitext(osp.basename(filename))[0] 72 | 73 | png_filename = osp.join(imgfile_prefix, f'{basename}.png') 74 | 75 | output = Image.fromarray(result.astype(np.uint8)).convert('P') 76 | import cityscapesscripts.helpers.labels as CSLabels 77 | palette = np.zeros((len(CSLabels.id2label), 3), dtype=np.uint8) 78 | for label_id, label in CSLabels.id2label.items(): 79 | palette[label_id] = label.color 80 | 81 | output.putpalette(palette) 82 | output.save(png_filename) 83 | result_files.append(png_filename) 84 | prog_bar.update() 85 | 86 | return result_files 87 | 88 | def format_results(self, results, imgfile_prefix=None, to_label_id=True): 89 | """Format the results into dir (standard format for Cityscapes 90 | evaluation). 91 | Args: 92 | results (list): Testing results of the dataset. 93 | imgfile_prefix (str | None): The prefix of images files. It 94 | includes the file path and the prefix of filename, e.g., 95 | "a/b/prefix". If not specified, a temp file will be created. 96 | Default: None. 97 | to_label_id (bool): whether convert output to label_id for 98 | submission. Default: False 99 | Returns: 100 | tuple: (result_files, tmp_dir), result_files is a list containing 101 | the image paths, tmp_dir is the temporal directory created 102 | for saving json/png files when img_prefix is not specified. 103 | """ 104 | 105 | assert isinstance(results, list), 'results must be a list' 106 | assert len(results) == len(self), ( 107 | 'The length of results is not equal to the dataset len: ' 108 | f'{len(results)} != {len(self)}') 109 | 110 | if imgfile_prefix is None: 111 | tmp_dir = tempfile.TemporaryDirectory() 112 | imgfile_prefix = tmp_dir.name 113 | else: 114 | tmp_dir = None 115 | result_files = self.results2img(results, imgfile_prefix, to_label_id) 116 | 117 | return result_files, tmp_dir 118 | 119 | def evaluate(self, 120 | results, 121 | metric='mIoU', 122 | logger=None, 123 | imgfile_prefix=None, 124 | efficient_test=False): 125 | """Evaluation in Cityscapes/default protocol. 126 | Args: 127 | results (list): Testing results of the dataset. 128 | metric (str | list[str]): Metrics to be evaluated. 129 | logger (logging.Logger | None | str): Logger used for printing 130 | related information during evaluation. Default: None. 131 | imgfile_prefix (str | None): The prefix of output image file, 132 | for cityscapes evaluation only. It includes the file path and 133 | the prefix of filename, e.g., "a/b/prefix". 134 | If results are evaluated with cityscapes protocol, it would be 135 | the prefix of output png files. The output files would be 136 | png images under folder "a/b/prefix/xxx.png", where "xxx" is 137 | the image name of cityscapes. If not specified, a temp file 138 | will be created for evaluation. 139 | Default: None. 140 | Returns: 141 | dict[str, float]: Cityscapes/default metrics. 
142 | """ 143 | 144 | eval_results = dict() 145 | metrics = metric.copy() if isinstance(metric, list) else [metric] 146 | if 'cityscapes' in metrics: 147 | eval_results.update( 148 | self._evaluate_cityscapes(results, logger, imgfile_prefix)) 149 | metrics.remove('cityscapes') 150 | if len(metrics) > 0: 151 | eval_results.update( 152 | super(CityscapesDataset, 153 | self).evaluate(results, metrics, logger, efficient_test)) 154 | 155 | return eval_results 156 | 157 | def _evaluate_cityscapes(self, results, logger, imgfile_prefix): 158 | """Evaluation in Cityscapes protocol. 159 | Args: 160 | results (list): Testing results of the dataset. 161 | logger (logging.Logger | str | None): Logger used for printing 162 | related information during evaluation. Default: None. 163 | imgfile_prefix (str | None): The prefix of output image file 164 | Returns: 165 | dict[str: float]: Cityscapes evaluation results. 166 | """ 167 | try: 168 | import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa 169 | except ImportError: 170 | raise ImportError('Please run "pip install cityscapesscripts" to ' 171 | 'install cityscapesscripts first.') 172 | msg = 'Evaluating in Cityscapes style' 173 | if logger is None: 174 | msg = '\n' + msg 175 | print_log(msg, logger=logger) 176 | 177 | result_files, tmp_dir = self.format_results(results, imgfile_prefix) 178 | 179 | if tmp_dir is None: 180 | result_dir = imgfile_prefix 181 | else: 182 | result_dir = tmp_dir.name 183 | 184 | eval_results = dict() 185 | print_log(f'Evaluating results under {result_dir} ...', logger=logger) 186 | 187 | CSEval.args.evalInstLevelScore = True 188 | CSEval.args.predictionPath = osp.abspath(result_dir) 189 | CSEval.args.evalPixelAccuracy = True 190 | CSEval.args.JSONOutput = False 191 | 192 | seg_map_list = [] 193 | pred_list = [] 194 | 195 | # when evaluating with official cityscapesscripts, 196 | # **_gtFine_labelIds.png is used 197 | for seg_map in mmcv.scandir( 198 | self.ann_dir, 'gtFine_labelIds.png', recursive=True): 199 | seg_map_list.append(osp.join(self.ann_dir, seg_map)) 200 | pred_list.append(CSEval.getPrediction(CSEval.args, seg_map)) 201 | 202 | eval_results.update( 203 | CSEval.evaluateImgLists(pred_list, seg_map_list, CSEval.args)) 204 | 205 | if tmp_dir is not None: 206 | tmp_dir.cleanup() 207 | 208 | return eval_results -------------------------------------------------------------------------------- /modified_mmseg/datasets/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from functools import reduce 4 | 5 | import mmcv 6 | import numpy as np 7 | from mmcv.utils import print_log 8 | from terminaltables import AsciiTable 9 | from torch.utils.data import Dataset 10 | 11 | from mmseg.core import eval_metrics 12 | from mmseg.utils import get_root_logger 13 | from .builder import DATASETS 14 | from .pipelines import Compose 15 | 16 | 17 | @DATASETS.register_module() 18 | class CustomDataset(Dataset): 19 | """Custom dataset for semantic segmentation. An example of file structure 20 | is as followed. 21 | 22 | .. 
code-block:: none 23 | 24 | ├── data 25 | │ ├── my_dataset 26 | │ │ ├── img_dir 27 | │ │ │ ├── train 28 | │ │ │ │ ├── xxx{img_suffix} 29 | │ │ │ │ ├── yyy{img_suffix} 30 | │ │ │ │ ├── zzz{img_suffix} 31 | │ │ │ ├── val 32 | │ │ ├── ann_dir 33 | │ │ │ ├── train 34 | │ │ │ │ ├── xxx{seg_map_suffix} 35 | │ │ │ │ ├── yyy{seg_map_suffix} 36 | │ │ │ │ ├── zzz{seg_map_suffix} 37 | │ │ │ ├── val 38 | 39 | The img/gt_semantic_seg pair of CustomDataset should be of the same 40 | except suffix. A valid img/gt_semantic_seg filename pair should be like 41 | ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (extension is also included 42 | in the suffix). If split is given, then ``xxx`` is specified in txt file. 43 | Otherwise, all files in ``img_dir/``and ``ann_dir`` will be loaded. 44 | Please refer to ``docs/tutorials/new_dataset.md`` for more details. 45 | 46 | 47 | Args: 48 | pipeline (list[dict]): Processing pipeline 49 | img_dir (str): Path to image directory 50 | img_suffix (str): Suffix of images. Default: '.jpg' 51 | ann_dir (str, optional): Path to annotation directory. Default: None 52 | seg_map_suffix (str): Suffix of segmentation maps. Default: '.png' 53 | split (str, optional): Split txt file. If split is specified, only 54 | file with suffix in the splits will be loaded. Otherwise, all 55 | images in img_dir/ann_dir will be loaded. Default: None 56 | data_root (str, optional): Data root for img_dir/ann_dir. Default: 57 | None. 58 | test_mode (bool): If test_mode=True, gt wouldn't be loaded. 59 | ignore_index (int): The label index to be ignored. Default: 255 60 | reduce_zero_label (bool): Whether to mark label zero as ignored. 61 | Default: False 62 | classes (str | Sequence[str], optional): Specify classes to load. 63 | If is None, ``cls.CLASSES`` will be used. Default: None. 64 | palette (Sequence[Sequence[int]]] | np.ndarray | None): 65 | The palette of segmentation map. If None is given, and 66 | self.PALETTE is None, random palette will be generated. 
67 | Default: None 68 | """ 69 | 70 | CLASSES = None 71 | 72 | PALETTE = None 73 | 74 | def __init__(self, 75 | pipeline, 76 | img_dir, 77 | img_suffix='.jpg', 78 | ann_dir=None, 79 | dt_dir=None, 80 | seg_map_suffix='.png', 81 | split=None, 82 | data_root=None, 83 | test_mode=False, 84 | ignore_index=255, 85 | reduce_zero_label=False, 86 | classes=None, 87 | palette=None): 88 | self.pipeline = Compose(pipeline) 89 | self.img_dir = img_dir 90 | self.img_suffix = img_suffix 91 | self.ann_dir = ann_dir 92 | self.dt_dir = dt_dir 93 | self.seg_map_suffix = seg_map_suffix 94 | self.split = split 95 | self.data_root = data_root 96 | self.test_mode = test_mode 97 | self.ignore_index = ignore_index 98 | self.reduce_zero_label = reduce_zero_label 99 | self.label_map = None 100 | self.CLASSES, self.PALETTE = self.get_classes_and_palette( 101 | classes, palette) 102 | 103 | # join paths if data_root is specified 104 | if self.data_root is not None: 105 | if not osp.isabs(self.img_dir): 106 | self.img_dir = osp.join(self.data_root, self.img_dir) 107 | if not (self.ann_dir is None or osp.isabs(self.ann_dir)): 108 | self.ann_dir = osp.join(self.data_root, self.ann_dir) 109 | 110 | if not (self.dt_dir is None or osp.isabs(self.dt_dir)): 111 | self.dt_dir = osp.join(self.data_root, self.dt_dir) 112 | 113 | if not (self.split is None or osp.isabs(self.split)): 114 | self.split = osp.join(self.data_root, self.split) 115 | 116 | # load annotations 117 | self.img_infos = self.load_annotations(self.img_dir, self.img_suffix, 118 | self.ann_dir, self.dt_dir, 119 | self.seg_map_suffix, self.split) 120 | 121 | def __len__(self): 122 | """Total number of samples of data.""" 123 | return len(self.img_infos) 124 | 125 | def load_annotations(self, img_dir, img_suffix, ann_dir, dt_dir, seg_map_suffix, 126 | split): 127 | """Load annotation from directory. 128 | 129 | Args: 130 | img_dir (str): Path to image directory 131 | img_suffix (str): Suffix of images. 132 | ann_dir (str|None): Path to annotation directory. 133 | seg_map_suffix (str|None): Suffix of segmentation maps. 134 | split (str|None): Split txt file. If split is specified, only file 135 | with suffix in the splits will be loaded. Otherwise, all images 136 | in img_dir/ann_dir will be loaded. Default: None 137 | 138 | Returns: 139 | list[dict]: All image info of dataset. 140 | """ 141 | 142 | img_infos = [] 143 | 144 | if split is not None: 145 | with open(split) as f: 146 | for line in f: 147 | img_name = line.strip() 148 | img_info = dict(filename=img_name + img_suffix) 149 | if ann_dir is not None: 150 | seg_map = img_name + seg_map_suffix 151 | img_info['ann'] = dict(seg_map=seg_map) 152 | img_infos.append(img_info) 153 | else: 154 | for img in mmcv.scandir(img_dir, img_suffix, recursive=True): 155 | img_info = dict(filename=img) 156 | if ann_dir is not None: 157 | seg_map = img.replace(img_suffix, seg_map_suffix) 158 | img_info['ann'] = dict(seg_map=seg_map) 159 | 160 | if dt_dir is not None: 161 | dt_map = img.replace(img_suffix, '.mat') 162 | img_info['dt'] = dict(dt_map=dt_map) 163 | 164 | img_infos.append(img_info) 165 | 166 | print_log(f'Loaded {len(img_infos)} images', logger=get_root_logger()) 167 | return img_infos 168 | 169 | def get_ann_info(self, idx): 170 | """Get annotation by index. 171 | 172 | Args: 173 | idx (int): Index of data. 174 | 175 | Returns: 176 | dict: Annotation info of specified index. 
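
        For example, ``dataset.get_ann_info(0)`` (an illustrative call) returns
        something like ``dict(seg_map='xxx.png')``, where ``xxx.png`` stands in
        for the first sample's segmentation map file.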
177 | """ 178 | 179 | return self.img_infos[idx]['ann'] 180 | 181 | def pre_pipeline(self, results): 182 | """Prepare results dict for pipeline.""" 183 | results['seg_fields'] = [] 184 | results['img_prefix'] = self.img_dir 185 | results['seg_prefix'] = self.ann_dir 186 | results['dt_prefix'] = self.dt_dir 187 | if self.custom_classes: 188 | results['label_map'] = self.label_map 189 | 190 | def __getitem__(self, idx): 191 | """Get training/test data after pipeline. 192 | 193 | Args: 194 | idx (int): Index of data. 195 | 196 | Returns: 197 | dict: Training/test data (with annotation if `test_mode` is set 198 | False). 199 | """ 200 | 201 | if self.test_mode: 202 | return self.prepare_test_img(idx) 203 | else: 204 | return self.prepare_train_img(idx) 205 | 206 | def prepare_train_img(self, idx): 207 | """Get training data and annotations after pipeline. 208 | 209 | Args: 210 | idx (int): Index of data. 211 | 212 | Returns: 213 | dict: Training data and annotation after pipeline with new keys 214 | introduced by pipeline. 215 | """ 216 | 217 | img_info = self.img_infos[idx] # filename, ann 218 | ann_info = self.get_ann_info(idx) 219 | 220 | results = dict(img_info=img_info, ann_info=ann_info) 221 | 222 | self.pre_pipeline(results) 223 | return self.pipeline(results) 224 | 225 | def prepare_test_img(self, idx): 226 | """Get testing data after pipeline. 227 | 228 | Args: 229 | idx (int): Index of data. 230 | 231 | Returns: 232 | dict: Testing data after pipeline with new keys intorduced by 233 | piepline. 234 | """ 235 | 236 | img_info = self.img_infos[idx] 237 | results = dict(img_info=img_info) 238 | self.pre_pipeline(results) 239 | return self.pipeline(results) 240 | 241 | def format_results(self, results, **kwargs): 242 | """Place holder to format result to dataset specific output.""" 243 | pass 244 | 245 | def get_gt_seg_maps(self, efficient_test=False): 246 | """Get ground truth segmentation maps for evaluation.""" 247 | gt_seg_maps = [] 248 | for img_info in self.img_infos: 249 | seg_map = osp.join(self.ann_dir, img_info['ann']['seg_map']) 250 | if efficient_test: 251 | gt_seg_map = seg_map 252 | else: 253 | gt_seg_map = mmcv.imread( 254 | seg_map, flag='unchanged', backend='pillow') 255 | gt_seg_maps.append(gt_seg_map) 256 | return gt_seg_maps 257 | 258 | def get_classes_and_palette(self, classes=None, palette=None): 259 | """Get class names of current dataset. 260 | 261 | Args: 262 | classes (Sequence[str] | str | None): If classes is None, use 263 | default CLASSES defined by builtin dataset. If classes is a 264 | string, take it as a file name. The file contains the name of 265 | classes where each line contains one class name. If classes is 266 | a tuple or list, override the CLASSES defined by the dataset. 267 | palette (Sequence[Sequence[int]]] | np.ndarray | None): 268 | The palette of segmentation map. If None is given, random 269 | palette will be generated. 
Default: None 270 | """ 271 | if classes is None: 272 | self.custom_classes = False 273 | return self.CLASSES, self.PALETTE 274 | 275 | self.custom_classes = True 276 | if isinstance(classes, str): 277 | # take it as a file path 278 | class_names = mmcv.list_from_file(classes) 279 | elif isinstance(classes, (tuple, list)): 280 | class_names = classes 281 | else: 282 | raise ValueError(f'Unsupported type {type(classes)} of classes.') 283 | 284 | if self.CLASSES: 285 | if not set(classes).issubset(self.CLASSES): 286 | raise ValueError('classes is not a subset of CLASSES.') 287 | 288 | # dictionary, its keys are the old label ids and its values 289 | # are the new label ids. 290 | # used for changing pixel labels in load_annotations. 291 | self.label_map = {} 292 | for i, c in enumerate(self.CLASSES): 293 | if c not in class_names: 294 | self.label_map[i] = -1 295 | else: 296 | self.label_map[i] = classes.index(c) 297 | 298 | palette = self.get_palette_for_custom_classes(class_names, palette) 299 | 300 | return class_names, palette 301 | 302 | def get_palette_for_custom_classes(self, class_names, palette=None): 303 | 304 | if self.label_map is not None: 305 | # return subset of palette 306 | palette = [] 307 | for old_id, new_id in sorted( 308 | self.label_map.items(), key=lambda x: x[1]): 309 | if new_id != -1: 310 | palette.append(self.PALETTE[old_id]) 311 | palette = type(self.PALETTE)(palette) 312 | 313 | elif palette is None: 314 | if self.PALETTE is None: 315 | palette = np.random.randint(0, 255, size=(len(class_names), 3)) 316 | else: 317 | palette = self.PALETTE 318 | 319 | return palette 320 | 321 | def evaluate(self, 322 | results, 323 | metric='mIoU', 324 | logger=None, 325 | efficient_test=False, 326 | **kwargs): 327 | """Evaluate the dataset. 328 | 329 | Args: 330 | results (list): Testing results of the dataset. 331 | metric (str | list[str]): Metrics to be evaluated. 'mIoU' and 332 | 'mDice' are supported. 333 | logger (logging.Logger | None | str): Logger used for printing 334 | related information during evaluation. Default: None. 335 | 336 | Returns: 337 | dict[str, float]: Default metrics. 
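
        For example, ``dataset.evaluate(results, metric='mIoU')`` (an
        illustrative call) returns a dict with the keys ``'mIoU'``, ``'mAcc'``
        and ``'aAcc'``, each reported as a fraction in ``[0, 1]``.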
338 | """ 339 | 340 | if isinstance(metric, str): 341 | metric = [metric] 342 | allowed_metrics = ['mIoU', 'mDice'] 343 | if not set(metric).issubset(set(allowed_metrics)): 344 | raise KeyError('metric {} is not supported'.format(metric)) 345 | eval_results = {} 346 | gt_seg_maps = self.get_gt_seg_maps(efficient_test) 347 | if self.CLASSES is None: 348 | num_classes = len( 349 | reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) 350 | else: 351 | num_classes = len(self.CLASSES) 352 | ret_metrics = eval_metrics( 353 | results, 354 | gt_seg_maps, 355 | num_classes, 356 | self.ignore_index, 357 | metric, 358 | label_map=self.label_map, 359 | reduce_zero_label=self.reduce_zero_label) 360 | class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']] 361 | if self.CLASSES is None: 362 | class_names = tuple(range(num_classes)) 363 | else: 364 | class_names = self.CLASSES 365 | ret_metrics_round = [ 366 | np.round(ret_metric * 100, 2) for ret_metric in ret_metrics 367 | ] 368 | for i in range(num_classes): 369 | class_table_data.append([class_names[i]] + 370 | [m[i] for m in ret_metrics_round[2:]] + 371 | [ret_metrics_round[1][i]]) 372 | summary_table_data = [['Scope'] + 373 | ['m' + head 374 | for head in class_table_data[0][1:]] + ['aAcc']] 375 | ret_metrics_mean = [ 376 | np.round(np.nanmean(ret_metric) * 100, 2) 377 | for ret_metric in ret_metrics 378 | ] 379 | summary_table_data.append(['global'] + ret_metrics_mean[2:] + 380 | [ret_metrics_mean[1]] + 381 | [ret_metrics_mean[0]]) 382 | print_log('per class results:', logger) 383 | table = AsciiTable(class_table_data) 384 | print_log('\n' + table.table, logger=logger) 385 | print_log('Summary:', logger) 386 | table = AsciiTable(summary_table_data) 387 | print_log('\n' + table.table, logger=logger) 388 | 389 | for i in range(1, len(summary_table_data[0])): 390 | eval_results[summary_table_data[0] 391 | [i]] = summary_table_data[1][i] / 100.0 392 | if mmcv.is_list_of(results, str): 393 | for file_name in results: 394 | os.remove(file_name) 395 | return eval_results 396 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 2 | 3 | from .builder import DATASETS 4 | 5 | 6 | @DATASETS.register_module() 7 | class ConcatDataset(_ConcatDataset): 8 | """A wrapper of concatenated dataset. 9 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 10 | concat the group flag for image aspect ratio. 11 | Args: 12 | datasets (list[:obj:`Dataset`]): A list of datasets. 13 | """ 14 | 15 | def __init__(self, datasets): 16 | super(ConcatDataset, self).__init__(datasets) 17 | self.CLASSES = datasets[0].CLASSES 18 | self.PALETTE = datasets[0].PALETTE 19 | 20 | 21 | @DATASETS.register_module() 22 | class RepeatDataset(object): 23 | """A wrapper of repeated dataset. 24 | The length of repeated dataset will be `times` larger than the original 25 | dataset. This is useful when the data loading time is long but the dataset 26 | is small. Using RepeatDataset can reduce the data loading time between 27 | epochs. 28 | Args: 29 | dataset (:obj:`Dataset`): The dataset to be repeated. 30 | times (int): Repeat times. 
31 | """ 32 | 33 | def __init__(self, dataset, times): 34 | self.dataset = dataset 35 | self.times = times 36 | self.CLASSES = dataset.CLASSES 37 | self.PALETTE = dataset.PALETTE 38 | self._ori_len = len(self.dataset) 39 | 40 | def __getitem__(self, idx): 41 | """Get item from original dataset.""" 42 | return self.dataset[idx % self._ori_len] 43 | 44 | def __len__(self): 45 | """The length is multiplied by ``times``""" 46 | return self.times * self._ori_len -------------------------------------------------------------------------------- /modified_mmseg/datasets/pascal_context.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | from .builder import DATASETS 4 | from .custom import CustomDataset 5 | 6 | 7 | @DATASETS.register_module() 8 | class PascalContextDataset(CustomDataset): 9 | """PascalContext dataset. 10 | 11 | In segmentation map annotation for PascalContext, 0 stands for background, 12 | which is included in 60 categories. ``reduce_zero_label`` is fixed to 13 | False. The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is 14 | fixed to '.png'. 15 | 16 | Args: 17 | split (str): Split txt file for PascalContext. 18 | """ 19 | 20 | CLASSES = ('background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 21 | 'bus', 'car', 'cat', 'chair', 'cow', 'table', 'dog', 'horse', 22 | 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 23 | 'tvmonitor', 'bag', 'bed', 'bench', 'book', 'building', 24 | 'cabinet', 'ceiling', 'cloth', 'computer', 'cup', 'door', 25 | 'fence', 'floor', 'flower', 'food', 'grass', 'ground', 26 | 'keyboard', 'light', 'mountain', 'mouse', 'curtain', 'platform', 27 | 'sign', 'plate', 'road', 'rock', 'shelves', 'sidewalk', 'sky', 28 | 'snow', 'bedclothes', 'track', 'tree', 'truck', 'wall', 'water', 29 | 'window', 'wood') 30 | 31 | PALETTE = [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], 32 | [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], 33 | [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], 34 | [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], 35 | [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], 36 | [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], 37 | [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], 38 | [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], 39 | [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], 40 | [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], 41 | [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], 42 | [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], 43 | [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], 44 | [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], 45 | [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255]] 46 | 47 | def __init__(self, **kwargs): 48 | super(PascalContextDataset, self).__init__( 49 | img_suffix='.jpg', 50 | seg_map_suffix='.png', 51 | reduce_zero_label=False, 52 | **kwargs) 53 | assert osp.exists(self.img_dir) 54 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .compose import Compose 2 | from .formating import (Collect, ImageToTensor, ToDataContainer, ToTensor, 3 | Transpose, to_tensor) 4 | from .loading import LoadAnnotations, LoadImageFromFile 5 | from .test_time_aug import MultiScaleFlipAug 6 | from .transforms 
import (CLAHE, AdjustGamma, Normalize, Pad, 7 | PhotoMetricDistortion, RandomCrop, RandomFlip, 8 | RandomRotate, Rerange, Resize, RGB2Gray, SegRescale) 9 | 10 | __all__ = [ 11 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer', 12 | 'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile', 13 | 'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 14 | 'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate', 15 | 'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray' 16 | ] 17 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | import collections 2 | 3 | from mmcv.utils import build_from_cfg 4 | 5 | from ..builder import PIPELINES 6 | 7 | 8 | @PIPELINES.register_module() 9 | class Compose(object): 10 | """Compose multiple transforms sequentially. 11 | 12 | Args: 13 | transforms (Sequence[dict | callable]): Sequence of transform object or 14 | config dict to be composed. 15 | """ 16 | 17 | def __init__(self, transforms): 18 | assert isinstance(transforms, collections.abc.Sequence) 19 | self.transforms = [] 20 | for transform in transforms: 21 | if isinstance(transform, dict): 22 | transform = build_from_cfg(transform, PIPELINES) 23 | self.transforms.append(transform) 24 | elif callable(transform): 25 | self.transforms.append(transform) 26 | else: 27 | raise TypeError('transform must be callable or a dict') 28 | 29 | def __call__(self, data): 30 | """Call function to apply transforms sequentially. 31 | 32 | Args: 33 | data (dict): A result dict contains the data to transform. 34 | 35 | Returns: 36 | dict: Transformed data. 37 | """ 38 | 39 | for t in self.transforms: 40 | data = t(data) 41 | if data is None: 42 | return None 43 | return data 44 | 45 | def __repr__(self): 46 | format_string = self.__class__.__name__ + '(' 47 | for t in self.transforms: 48 | format_string += '\n' 49 | format_string += f' {t}' 50 | format_string += '\n)' 51 | return format_string 52 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/pipelines/formating.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | from mmcv.parallel import DataContainer as DC 7 | 8 | from ..builder import PIPELINES 9 | 10 | 11 | def to_tensor(data): 12 | """Convert objects of various python types to :obj:`torch.Tensor`. 13 | 14 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 15 | :class:`Sequence`, :class:`int` and :class:`float`. 16 | 17 | Args: 18 | data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to 19 | be converted. 20 | """ 21 | 22 | if isinstance(data, torch.Tensor): 23 | return data 24 | elif isinstance(data, np.ndarray): 25 | return torch.from_numpy(data) 26 | elif isinstance(data, Sequence) and not mmcv.is_str(data): 27 | return torch.tensor(data) 28 | elif isinstance(data, int): 29 | return torch.LongTensor([data]) 30 | elif isinstance(data, float): 31 | return torch.FloatTensor([data]) 32 | else: 33 | raise TypeError(f'type {type(data)} cannot be converted to tensor.') 34 | 35 | 36 | @PIPELINES.register_module() 37 | class ToTensor(object): 38 | """Convert some results to :obj:`torch.Tensor` by given keys. 39 | 40 | Args: 41 | keys (Sequence[str]): Keys that need to be converted to Tensor. 
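
    For example, ``ToTensor(keys=['gt_semantic_seg'])`` (an illustrative key
    choice) converts ``results['gt_semantic_seg']`` to a ``torch.Tensor`` and
    leaves every other entry of ``results`` unchanged.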
42 | """ 43 | 44 | def __init__(self, keys): 45 | self.keys = keys 46 | 47 | def __call__(self, results): 48 | """Call function to convert data in results to :obj:`torch.Tensor`. 49 | 50 | Args: 51 | results (dict): Result dict contains the data to convert. 52 | 53 | Returns: 54 | dict: The result dict contains the data converted 55 | to :obj:`torch.Tensor`. 56 | """ 57 | 58 | for key in self.keys: 59 | results[key] = to_tensor(results[key]) 60 | return results 61 | 62 | def __repr__(self): 63 | return self.__class__.__name__ + f'(keys={self.keys})' 64 | 65 | 66 | @PIPELINES.register_module() 67 | class ImageToTensor(object): 68 | """Convert image to :obj:`torch.Tensor` by given keys. 69 | 70 | The dimension order of input image is (H, W, C). The pipeline will convert 71 | it to (C, H, W). If only 2 dimension (H, W) is given, the output would be 72 | (1, H, W). 73 | 74 | Args: 75 | keys (Sequence[str]): Key of images to be converted to Tensor. 76 | """ 77 | 78 | def __init__(self, keys): 79 | self.keys = keys 80 | 81 | def __call__(self, results): 82 | """Call function to convert image in results to :obj:`torch.Tensor` and 83 | transpose the channel order. 84 | 85 | Args: 86 | results (dict): Result dict contains the image data to convert. 87 | 88 | Returns: 89 | dict: The result dict contains the image converted 90 | to :obj:`torch.Tensor` and transposed to (C, H, W) order. 91 | """ 92 | 93 | for key in self.keys: 94 | img = results[key] 95 | if len(img.shape) < 3: 96 | img = np.expand_dims(img, -1) 97 | results[key] = to_tensor(img.transpose(2, 0, 1)) 98 | return results 99 | 100 | def __repr__(self): 101 | return self.__class__.__name__ + f'(keys={self.keys})' 102 | 103 | 104 | @PIPELINES.register_module() 105 | class Transpose(object): 106 | """Transpose some results by given keys. 107 | 108 | Args: 109 | keys (Sequence[str]): Keys of results to be transposed. 110 | order (Sequence[int]): Order of transpose. 111 | """ 112 | 113 | def __init__(self, keys, order): 114 | self.keys = keys 115 | self.order = order 116 | 117 | def __call__(self, results): 118 | """Call function to convert image in results to :obj:`torch.Tensor` and 119 | transpose the channel order. 120 | 121 | Args: 122 | results (dict): Result dict contains the image data to convert. 123 | 124 | Returns: 125 | dict: The result dict contains the image converted 126 | to :obj:`torch.Tensor` and transposed to (C, H, W) order. 127 | """ 128 | 129 | for key in self.keys: 130 | results[key] = results[key].transpose(self.order) 131 | return results 132 | 133 | def __repr__(self): 134 | return self.__class__.__name__ + \ 135 | f'(keys={self.keys}, order={self.order})' 136 | 137 | 138 | @PIPELINES.register_module() 139 | class ToDataContainer(object): 140 | """Convert results to :obj:`mmcv.DataContainer` by given fields. 141 | 142 | Args: 143 | fields (Sequence[dict]): Each field is a dict like 144 | ``dict(key='xxx', **kwargs)``. The ``key`` in result will 145 | be converted to :obj:`mmcv.DataContainer` with ``**kwargs``. 146 | Default: ``(dict(key='img', stack=True), 147 | dict(key='gt_semantic_seg'))``. 148 | """ 149 | 150 | def __init__(self, 151 | fields=(dict(key='img', 152 | stack=True), dict(key='gt_semantic_seg'))): 153 | self.fields = fields 154 | 155 | def __call__(self, results): 156 | """Call function to convert data in results to 157 | :obj:`mmcv.DataContainer`. 158 | 159 | Args: 160 | results (dict): Result dict contains the data to convert. 
161 | 162 | Returns: 163 | dict: The result dict contains the data converted to 164 | :obj:`mmcv.DataContainer`. 165 | """ 166 | 167 | for field in self.fields: 168 | field = field.copy() 169 | key = field.pop('key') 170 | results[key] = DC(results[key], **field) 171 | return results 172 | 173 | def __repr__(self): 174 | return self.__class__.__name__ + f'(fields={self.fields})' 175 | 176 | 177 | @PIPELINES.register_module() 178 | class DefaultFormatBundle(object): 179 | """Default formatting bundle. 180 | 181 | It simplifies the pipeline of formatting common fields, including "img" 182 | and "gt_semantic_seg". These fields are formatted as follows. 183 | 184 | - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) 185 | - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, 186 | (3)to DataContainer (stack=True) 187 | """ 188 | 189 | def __call__(self, results): 190 | """Call function to transform and format common fields in results. 191 | 192 | Args: 193 | results (dict): Result dict contains the data to convert. 194 | 195 | Returns: 196 | dict: The result dict contains the data that is formatted with 197 | default bundle. 198 | """ 199 | 200 | if 'img' in results: 201 | img = results['img'] 202 | if len(img.shape) < 3: 203 | img = np.expand_dims(img, -1) 204 | img = np.ascontiguousarray(img.transpose(2, 0, 1)) 205 | results['img'] = DC(to_tensor(img), stack=True) 206 | if 'gt_semantic_seg' in results: 207 | # convert to long 208 | results['gt_semantic_seg'] = DC( 209 | to_tensor(results['gt_semantic_seg'][None, 210 | ...].astype(np.int64)), 211 | stack=True) 212 | 213 | if 'distance_map' in results: 214 | results['distance_map'] = DC(to_tensor(results['distance_map'][None, 215 | ...].astype(np.int64)), 216 | stack=True) 217 | 218 | if 'angle_map' in results: 219 | results['angle_map'] = DC(to_tensor(results['angle_map'][None, 220 | ...].astype(np.int64)), 221 | stack=True) 222 | 223 | return results 224 | 225 | def __repr__(self): 226 | return self.__class__.__name__ 227 | 228 | 229 | @PIPELINES.register_module() 230 | class Collect(object): 231 | """Collect data from the loader relevant to the specific task. 232 | 233 | This is usually the last stage of the data loader pipeline. Typically keys 234 | is set to some subset of "img", "gt_semantic_seg". 235 | 236 | The "img_meta" item is always populated. The contents of the "img_meta" 237 | dictionary depends on "meta_keys". By default this includes: 238 | 239 | - "img_shape": shape of the image input to the network as a tuple 240 | (h, w, c). Note that images may be zero padded on the bottom/right 241 | if the batch tensor is larger than this shape. 242 | 243 | - "scale_factor": a float indicating the preprocessing scale 244 | 245 | - "flip": a boolean indicating if image flip transform was used 246 | 247 | - "filename": path to the image file 248 | 249 | - "ori_shape": original shape of the image as a tuple (h, w, c) 250 | 251 | - "pad_shape": image shape after padding 252 | 253 | - "img_norm_cfg": a dict of normalization information: 254 | - mean - per channel mean subtraction 255 | - std - per channel std divisor 256 | - to_rgb - bool indicating if bgr was converted to rgb 257 | 258 | Args: 259 | keys (Sequence[str]): Keys of results to be collected in ``data``. 260 | meta_keys (Sequence[str], optional): Meta keys to be converted to 261 | ``mmcv.DataContainer`` and collected in ``data[img_metas]``. 
262 | Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape', 263 | 'pad_shape', 'scale_factor', 'flip', 'flip_direction', 264 | 'img_norm_cfg')`` 265 | """ 266 | 267 | def __init__(self, 268 | keys, 269 | meta_keys=('filename', 'ori_filename', 'ori_shape', 270 | 'img_shape', 'pad_shape', 'scale_factor', 'flip', 271 | 'flip_direction', 'img_norm_cfg')): 272 | self.keys = keys 273 | self.meta_keys = meta_keys 274 | 275 | def __call__(self, results): 276 | """Call function to collect keys in results. The keys in ``meta_keys`` 277 | will be converted to :obj:mmcv.DataContainer. 278 | 279 | Args: 280 | results (dict): Result dict contains the data to collect. 281 | 282 | Returns: 283 | dict: The result dict contains the following keys 284 | - keys in``self.keys`` 285 | - ``img_metas`` 286 | """ 287 | 288 | data = {} 289 | img_meta = {} 290 | for key in self.meta_keys: 291 | img_meta[key] = results[key] 292 | data['img_metas'] = DC(img_meta, cpu_only=True) 293 | for key in self.keys: 294 | data[key] = results[key] 295 | return data 296 | 297 | def __repr__(self): 298 | return self.__class__.__name__ + \ 299 | f'(keys={self.keys}, meta_keys={self.meta_keys})' 300 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | 3 | import mmcv 4 | import numpy as np 5 | import scipy.io as io 6 | 7 | from ..builder import PIPELINES 8 | 9 | 10 | @PIPELINES.register_module() 11 | class LoadImageFromFile(object): 12 | """Load an image from file. 13 | 14 | Required keys are "img_prefix" and "img_info" (a dict that must contain the 15 | key "filename"). Added or updated keys are "filename", "img", "img_shape", 16 | "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`), 17 | "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1). 18 | 19 | Args: 20 | to_float32 (bool): Whether to convert the loaded image to a float32 21 | numpy array. If set to False, the loaded image is an uint8 array. 22 | Defaults to False. 23 | color_type (str): The flag argument for :func:`mmcv.imfrombytes`. 24 | Defaults to 'color'. 25 | file_client_args (dict): Arguments to instantiate a FileClient. 26 | See :class:`mmcv.fileio.FileClient` for details. 27 | Defaults to ``dict(backend='disk')``. 28 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 29 | 'cv2' 30 | """ 31 | 32 | def __init__(self, 33 | to_float32=False, 34 | color_type='color', 35 | file_client_args=dict(backend='disk'), 36 | imdecode_backend='cv2'): 37 | self.to_float32 = to_float32 38 | self.color_type = color_type 39 | self.file_client_args = file_client_args.copy() 40 | self.file_client = None 41 | self.imdecode_backend = imdecode_backend 42 | 43 | def __call__(self, results): 44 | """Call functions to load image and get image meta information. 45 | 46 | Args: 47 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 48 | 49 | Returns: 50 | dict: The dict contains loaded image and meta information. 
51 | """ 52 | 53 | if self.file_client is None: 54 | self.file_client = mmcv.FileClient(**self.file_client_args) 55 | 56 | if results.get('img_prefix') is not None: 57 | filename = osp.join(results['img_prefix'], 58 | results['img_info']['filename']) 59 | else: 60 | filename = results['img_info']['filename'] 61 | img_bytes = self.file_client.get(filename) 62 | img = mmcv.imfrombytes( 63 | img_bytes, flag=self.color_type, backend=self.imdecode_backend) 64 | if self.to_float32: 65 | img = img.astype(np.float32) 66 | 67 | results['filename'] = filename 68 | results['ori_filename'] = results['img_info']['filename'] 69 | results['img'] = img 70 | results['img_shape'] = img.shape 71 | results['ori_shape'] = img.shape 72 | # Set initial values for default meta_keys 73 | results['pad_shape'] = img.shape 74 | results['scale_factor'] = 1.0 75 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 76 | results['img_norm_cfg'] = dict( 77 | mean=np.zeros(num_channels, dtype=np.float32), 78 | std=np.ones(num_channels, dtype=np.float32), 79 | to_rgb=False) 80 | return results 81 | 82 | def __repr__(self): 83 | repr_str = self.__class__.__name__ 84 | repr_str += f'(to_float32={self.to_float32},' 85 | repr_str += f"color_type='{self.color_type}'," 86 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 87 | return repr_str 88 | 89 | 90 | @PIPELINES.register_module() 91 | class LoadAnnotations(object): 92 | """Load annotations for semantic segmentation. 93 | 94 | Args: 95 | reduce_zero_label (bool): Whether reduce all label value by 1. 96 | Usually used for datasets where 0 is background label. 97 | Default: False. 98 | file_client_args (dict): Arguments to instantiate a FileClient. 99 | See :class:`mmcv.fileio.FileClient` for details. 100 | Defaults to ``dict(backend='disk')``. 101 | imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default: 102 | 'pillow' 103 | """ 104 | 105 | def __init__(self, 106 | reduce_zero_label=False, 107 | file_client_args=dict(backend='disk'), 108 | imdecode_backend='pillow'): 109 | self.reduce_zero_label = reduce_zero_label 110 | self.file_client_args = file_client_args.copy() 111 | self.file_client = None 112 | self.imdecode_backend = imdecode_backend 113 | 114 | def __call__(self, results): 115 | """Call function to load multiple types annotations. 116 | 117 | Args: 118 | results (dict): Result dict from :obj:`mmseg.CustomDataset`. 119 | 120 | Returns: 121 | dict: The dict contains loaded semantic segmentation annotations. 
122 | """ 123 | 124 | if self.file_client is None: 125 | self.file_client = mmcv.FileClient(**self.file_client_args) 126 | 127 | if results.get('seg_prefix', None) is not None: 128 | filename = osp.join(results['seg_prefix'], 129 | results['ann_info']['seg_map']) 130 | else: 131 | filename = results['ann_info']['seg_map'] 132 | img_bytes = self.file_client.get(filename) 133 | gt_semantic_seg = mmcv.imfrombytes( 134 | img_bytes, flag='unchanged', 135 | backend=self.imdecode_backend).squeeze().astype(np.uint8) 136 | # modify if custom classes 137 | if results.get('label_map', None) is not None: 138 | for old_id, new_id in results['label_map'].items(): 139 | gt_semantic_seg[gt_semantic_seg == old_id] = new_id 140 | # reduce zero_label 141 | if self.reduce_zero_label: 142 | # avoid using underflow conversion 143 | gt_semantic_seg[gt_semantic_seg == 0] = 255 144 | gt_semantic_seg = gt_semantic_seg - 1 145 | gt_semantic_seg[gt_semantic_seg == 254] = 255 146 | results['gt_semantic_seg'] = gt_semantic_seg 147 | results['seg_fields'].append('gt_semantic_seg') 148 | 149 | if results.get('dt_prefix', None) is not None: 150 | dt_filename = osp.join(results['dt_prefix'], 151 | results['img_info']['dt']['dt_map']) 152 | else: 153 | dt_filename = results['dt']['dt_map'] 154 | 155 | dct = self._load_mat(dt_filename) 156 | 157 | distance_map = dct['depth'].astype(np.int32) 158 | dir_deg = dct['dir_deg'].astype(np.float) # in [0, 360 / deg_reduce] 159 | deg_reduce = dct['deg_reduce'][0][0] 160 | 161 | dir_deg = deg_reduce * dir_deg - 180 # in [-180, 180] 162 | 163 | # import ipdb; ipdb.set_trace() 164 | results['distance_map'] = distance_map 165 | results['angle_map'] = dir_deg 166 | results['seg_fields'].append('distance_map') 167 | results['seg_fields'].append('angle_map') 168 | 169 | 170 | return results 171 | 172 | def _load_mat(self, filename): 173 | return io.loadmat(filename) 174 | 175 | def __repr__(self): 176 | repr_str = self.__class__.__name__ 177 | repr_str += f'(reduce_zero_label={self.reduce_zero_label},' 178 | repr_str += f"imdecode_backend='{self.imdecode_backend}')" 179 | return repr_str 180 | 181 | 182 | -------------------------------------------------------------------------------- /modified_mmseg/datasets/pipelines/test_time_aug.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import mmcv 4 | 5 | from ..builder import PIPELINES 6 | from .compose import Compose 7 | 8 | 9 | @PIPELINES.register_module() 10 | class MultiScaleFlipAug(object): 11 | """Test-time augmentation with multiple scales and flipping. 12 | 13 | An example configuration is as followed: 14 | 15 | .. code-block:: 16 | 17 | img_scale=(2048, 1024), 18 | img_ratios=[0.5, 1.0], 19 | flip=True, 20 | transforms=[ 21 | dict(type='Resize', keep_ratio=True), 22 | dict(type='RandomFlip'), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='Pad', size_divisor=32), 25 | dict(type='ImageToTensor', keys=['img']), 26 | dict(type='Collect', keys=['img']), 27 | ] 28 | 29 | After MultiScaleFLipAug with above configuration, the results are wrapped 30 | into lists of the same length as followed: 31 | 32 | .. code-block:: 33 | 34 | dict( 35 | img=[...], 36 | img_shape=[...], 37 | scale=[(1024, 512), (1024, 512), (2048, 1024), (2048, 1024)] 38 | flip=[False, True, False, True] 39 | ... 40 | ) 41 | 42 | Args: 43 | transforms (list[dict]): Transforms to apply in each augmentation. 44 | img_scale (None | tuple | list[tuple]): Images scales for resizing. 
45 | img_ratios (float | list[float]): Image ratios for resizing 46 | flip (bool): Whether apply flip augmentation. Default: False. 47 | flip_direction (str | list[str]): Flip augmentation directions, 48 | options are "horizontal" and "vertical". If flip_direction is list, 49 | multiple flip augmentations will be applied. 50 | It has no effect when flip == False. Default: "horizontal". 51 | """ 52 | 53 | def __init__(self, 54 | transforms, 55 | img_scale, 56 | img_ratios=None, 57 | flip=False, 58 | flip_direction='horizontal'): 59 | self.transforms = Compose(transforms) 60 | if img_ratios is not None: 61 | img_ratios = img_ratios if isinstance(img_ratios, 62 | list) else [img_ratios] 63 | assert mmcv.is_list_of(img_ratios, float) 64 | if img_scale is None: 65 | # mode 1: given img_scale=None and a range of image ratio 66 | self.img_scale = None 67 | assert mmcv.is_list_of(img_ratios, float) 68 | elif isinstance(img_scale, tuple) and mmcv.is_list_of( 69 | img_ratios, float): 70 | assert len(img_scale) == 2 71 | # mode 2: given a scale and a range of image ratio 72 | self.img_scale = [(int(img_scale[0] * ratio), 73 | int(img_scale[1] * ratio)) 74 | for ratio in img_ratios] 75 | else: 76 | # mode 3: given multiple scales 77 | self.img_scale = img_scale if isinstance(img_scale, 78 | list) else [img_scale] 79 | assert mmcv.is_list_of(self.img_scale, tuple) or self.img_scale is None 80 | self.flip = flip 81 | self.img_ratios = img_ratios 82 | self.flip_direction = flip_direction if isinstance( 83 | flip_direction, list) else [flip_direction] 84 | assert mmcv.is_list_of(self.flip_direction, str) 85 | if not self.flip and self.flip_direction != ['horizontal']: 86 | warnings.warn( 87 | 'flip_direction has no effect when flip is set to False') 88 | if (self.flip 89 | and not any([t['type'] == 'RandomFlip' for t in transforms])): 90 | warnings.warn( 91 | 'flip has no effect when RandomFlip is not in transforms') 92 | 93 | def __call__(self, results): 94 | """Call function to apply test time augment transforms on results. 95 | 96 | Args: 97 | results (dict): Result dict contains the data to transform. 98 | 99 | Returns: 100 | dict[str: list]: The augmented data, where each value is wrapped 101 | into a list. 
102 | """ 103 | 104 | aug_data = [] 105 | if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): 106 | h, w = results['img'].shape[:2] 107 | img_scale = [(int(w * ratio), int(h * ratio)) 108 | for ratio in self.img_ratios] 109 | else: 110 | img_scale = self.img_scale 111 | flip_aug = [False, True] if self.flip else [False] 112 | for scale in img_scale: 113 | for flip in flip_aug: 114 | for direction in self.flip_direction: 115 | _results = results.copy() 116 | _results['scale'] = scale 117 | _results['flip'] = flip 118 | _results['flip_direction'] = direction 119 | data = self.transforms(_results) 120 | aug_data.append(data) 121 | # list of dict to dict of list 122 | aug_data_dict = {key: [] for key in aug_data[0]} 123 | for data in aug_data: 124 | for key, val in data.items(): 125 | aug_data_dict[key].append(val) 126 | return aug_data_dict 127 | 128 | def __repr__(self): 129 | repr_str = self.__class__.__name__ 130 | repr_str += f'(transforms={self.transforms}, ' 131 | repr_str += f'img_scale={self.img_scale}, flip={self.flip})' 132 | repr_str += f'flip_direction={self.flip_direction}' 133 | return repr_str 134 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | mmcv 2 | mmsegmentation 3 | timm -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from mmsegmentation (https://github.com/open-mmlab/mmsegmentation) 3 | # Apache-2.0 License 4 | # Copyright (c) Open-MMLab. 5 | # ------------------------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import os 9 | 10 | import mmcv 11 | import torch 12 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 13 | from mmcv.runner import get_dist_info, init_dist, load_checkpoint 14 | from mmcv.utils import DictAction 15 | 16 | from mmseg.apis import multi_gpu_test, single_gpu_test 17 | from modified_mmseg.datasets import build_dataset 18 | from mmseg.datasets import build_dataloader 19 | from builder import build_segmentor 20 | from models import * 21 | 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser( 26 | description='test/eval a model') 27 | parser.add_argument('--config', type=str, help='test config file path') 28 | parser.add_argument('--checkpoint', help='checkpoint file') 29 | parser.add_argument('--aug-test', action='store_true', help='Use Flip and Multi scale aug') 30 | parser.add_argument('--out', help='output result file in pickle format') 31 | parser.add_argument('--format-only', action='store_true', help='Format the output results without perform evaluation. 
32 |                         ' useful when you want to format the result to a specific format and '
33 |                         'submit it to the test server')
34 |     parser.add_argument('--eval', type=str, nargs='+', help='evaluation metrics, which depend on the dataset, e.g., "mIoU"'
35 |                         ' for generic datasets, and "cityscapes" for Cityscapes')
36 |     parser.add_argument('--show', action='store_true', help='show results')
37 |     parser.add_argument('--show-dir', help='directory where painted images will be saved')
38 |     parser.add_argument('--gpu-collect', action='store_true', help='whether to use gpu to collect results.')
39 |     parser.add_argument('--tmpdir', help='tmp directory used for collecting results from multiple '
40 |                         'workers, available when gpu_collect is not specified')
41 |     parser.add_argument('--options', nargs='+', action=DictAction, help='custom options')
42 |     parser.add_argument('--eval-options', nargs='+', action=DictAction, help='custom options for evaluation')
43 |     parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher')
44 |     parser.add_argument('--local_rank', type=int, default=0)
45 |     args = parser.parse_args()
46 |     if 'LOCAL_RANK' not in os.environ:
47 |         os.environ['LOCAL_RANK'] = str(args.local_rank)
48 |     return args
49 | 
50 | 
51 | def main():
52 |     args = parse_args()
53 | 
54 |     assert args.out or args.eval or args.format_only or args.show \
55 |         or args.show_dir, \
56 |         ('Please specify at least one operation (save/eval/format/show the '
57 |          'results) with the argument "--out", "--eval", '
58 |          '"--format-only", "--show" or "--show-dir"')
59 | 
60 |     if args.eval and args.format_only:
61 |         raise ValueError('--eval and --format-only cannot both be specified')
62 | 
63 |     if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
64 |         raise ValueError('The output file must be a pkl file.')
65 | 
66 |     cfg = mmcv.Config.fromfile(args.config)
67 |     if args.options is not None:
68 |         cfg.merge_from_dict(args.options)
69 |     # set cudnn_benchmark
70 |     if cfg.get('cudnn_benchmark', False):
71 |         torch.backends.cudnn.benchmark = True
72 |     if args.aug_test:
73 |         # hard-coded index: assumes pipeline[1] is the MultiScaleFlipAug step
74 |         cfg.data.test.pipeline[1].img_ratios = [
75 |             0.5, 0.75, 1.0, 1.25, 1.5, 1.75
76 |         ]
77 |         cfg.data.test.pipeline[1].flip = True
78 |     # cfg.model.pretrained = None
79 |     cfg.data.test.test_mode = True
80 | 
81 | 
82 |     # init distributed env first, since logger depends on the dist info.
83 | if args.launcher == 'none': 84 | distributed = False 85 | else: 86 | distributed = True 87 | init_dist(args.launcher, **cfg.dist_params) 88 | 89 | # build the dataloader 90 | # TODO: support multiple images per gpu (only minor changes are needed) 91 | dataset = build_dataset(cfg.data.test) 92 | 93 | data_loader = build_dataloader( 94 | dataset, 95 | samples_per_gpu=1, 96 | workers_per_gpu=cfg.data.workers_per_gpu, 97 | dist=distributed, 98 | shuffle=False) 99 | # build the model and load checkpoint 100 | model = build_segmentor(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) 101 | 102 | # later will be replaced 103 | checkpoint = torch.load(args.checkpoint, map_location="cpu") 104 | state_dict = checkpoint['state_dict'] 105 | 106 | model.load_state_dict(state_dict) 107 | 108 | model.CLASSES = dataset.CLASSES 109 | model.PALETTE = dataset.PALETTE 110 | 111 | if not distributed: 112 | model = MMDataParallel(model, device_ids=[0]) 113 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir) 114 | else: 115 | model = MMDistributedDataParallel( 116 | model.cuda(), 117 | device_ids=[torch.cuda.current_device()], 118 | broadcast_buffers=False) 119 | outputs = multi_gpu_test(model, data_loader, args.tmpdir, 120 | args.gpu_collect) 121 | 122 | rank, _ = get_dist_info() 123 | if rank == 0: 124 | if args.out: 125 | print(f'\nwriting results to {args.out}') 126 | mmcv.dump(outputs, args.out) 127 | kwargs = {} if args.eval_options is None else args.eval_options 128 | if args.format_only: 129 | dataset.format_results(outputs, **kwargs) 130 | if args.eval: 131 | dataset.evaluate(outputs, args.eval, **kwargs) 132 | 133 | 134 | if __name__ == '__main__': 135 | main() -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Modified from mmsegmentation (https://github.com/open-mmlab/mmsegmentation) 3 | # Apache-2.0 License 4 | # Copyright (c) Open-MMLab. 
5 | # ------------------------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import copy 9 | import os 10 | import os.path as osp 11 | import time 12 | 13 | import mmcv 14 | import torch 15 | from mmcv.runner import init_dist 16 | from mmcv.utils import Config, DictAction 17 | 18 | from mmseg import __version__ 19 | 20 | 21 | from modified_mmseg.apis import set_random_seed, train_segmentor 22 | from modified_mmseg.datasets import build_dataset 23 | from mmseg.utils import collect_env, get_root_logger 24 | from models import * 25 | from builder import build_segmentor 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description='Train UN_EPT') 30 | parser.add_argument('--config', help='train config file path') 31 | parser.add_argument('--work_dir', help='the dir to save logs and models') 32 | parser.add_argument('--load-from', help='the checkpoint file to load weights from') 33 | parser.add_argument('--resume-from', help='the checkpoint file to resume from') 34 | parser.add_argument('--no-validate', action='store_true', help='whether not to evaluate the checkpoint during training') 35 | group_gpus = parser.add_mutually_exclusive_group() 36 | group_gpus.add_argument('--gpus', type=int, help='number of gpus to use ' 37 | '(only applicable to non-distributed training)') 38 | group_gpus.add_argument('--gpu-ids', type=int, nargs='+', help='ids of gpus to use ' 39 | '(only applicable to non-distributed training)') 40 | parser.add_argument('--seed', type=int, default=None, help='random seed') 41 | parser.add_argument('--deterministic', action='store_true', help='whether to set deterministic options for CUDNN backend.') 42 | parser.add_argument('--options', nargs='+', action=DictAction, help='custom options') 43 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm', 'mpi'], default='none', help='job launcher') 44 | parser.add_argument('--local_rank', type=int, default=0) 45 | args = parser.parse_args() 46 | if 'LOCAL_RANK' not in os.environ: 47 | os.environ['LOCAL_RANK'] = str(args.local_rank) 48 | 49 | return args 50 | 51 | 52 | def main(): 53 | args = parse_args() 54 | 55 | cfg = Config.fromfile(args.config) 56 | if args.options is not None: 57 | cfg.merge_from_dict(args.options) 58 | # set cudnn_benchmark 59 | if cfg.get('cudnn_benchmark', False): 60 | torch.backends.cudnn.benchmark = True 61 | 62 | # work_dir is determined in this priority: CLI > segment in file > filename 63 | if args.work_dir is not None: 64 | # update configs according to CLI args if args.work_dir is not None 65 | cfg.work_dir = args.work_dir 66 | elif cfg.get('work_dir', None) is None: 67 | # use config filename as default work_dir if cfg.work_dir is None 68 | cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) 69 | if args.load_from is not None: 70 | cfg.load_from = args.load_from 71 | if args.resume_from is not None: 72 | cfg.resume_from = args.resume_from 73 | if args.gpu_ids is not None: 74 | cfg.gpu_ids = args.gpu_ids 75 | else: 76 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 77 | 78 | # init distributed env first, since logger depends on the dist info. 
79 | if args.launcher == 'none': 80 | distributed = False 81 | else: 82 | distributed = True 83 | init_dist(args.launcher, **cfg.dist_params) 84 | 85 | # create work_dir 86 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 87 | # dump config 88 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 89 | # init the logger before other steps 90 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 91 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 92 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 93 | 94 | # init the meta dict to record some important information such as 95 | # environment info and seed, which will be logged 96 | meta = dict() 97 | # log env info 98 | env_info_dict = collect_env() 99 | env_info = '\n'.join([f'{k}: {v}' for k, v in env_info_dict.items()]) 100 | dash_line = '-' * 60 + '\n' 101 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 102 | dash_line) 103 | meta['env_info'] = env_info 104 | 105 | # log some basic info 106 | logger.info(f'Distributed training: {distributed}') 107 | logger.info(f'Config:\n{cfg.pretty_text}') 108 | 109 | # set random seeds 110 | if args.seed is not None: 111 | logger.info(f'Set random seed to {args.seed}, deterministic: ' 112 | f'{args.deterministic}') 113 | set_random_seed(args.seed, deterministic=args.deterministic) 114 | cfg.seed = args.seed 115 | meta['seed'] = args.seed 116 | meta['exp_name'] = osp.basename(args.config) 117 | 118 | model = build_segmentor( 119 | cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg) 120 | logger.info(model) 121 | 122 | datasets = [build_dataset(cfg.data.train)] 123 | 124 | if len(cfg.workflow) == 2: 125 | val_dataset = copy.deepcopy(cfg.data.val) 126 | val_dataset.pipeline = cfg.data.train.pipeline 127 | datasets.append(build_dataset(val_dataset)) 128 | if cfg.checkpoint_config is not None: 129 | # save mmseg version, config file content and class names in 130 | # checkpoints as meta data 131 | cfg.checkpoint_config.meta = dict( 132 | mmseg_version=__version__, 133 | config=cfg.pretty_text, 134 | CLASSES=datasets[0].CLASSES, 135 | PALETTE=datasets[0].PALETTE) 136 | # add an attribute for visualization convenience 137 | model.CLASSES = datasets[0].CLASSES 138 | 139 | train_segmentor( 140 | model, 141 | datasets, 142 | cfg, 143 | distributed=distributed, 144 | validate=(not args.no_validate), 145 | timestamp=timestamp, 146 | meta=meta) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | --------------------------------------------------------------------------------