├── LICENSE ├── README.md ├── configs ├── _base_ │ ├── datasets │ │ └── pannuke.py │ ├── default_runtime.py │ ├── models │ │ ├── attunet_res34_d5.py │ │ ├── attunet_vgg32_d5.py │ │ ├── cenet_res34_d5.py │ │ ├── cenet_vgg32_d5.py │ │ ├── mctrans_res34_d5.py │ │ ├── mctrans_vgg32_d5.py │ │ ├── nonlocal_res34_d5.py │ │ ├── nonlocal_vgg32_d5.py │ │ ├── transunet_res34_d5.py │ │ ├── transunet_vgg32_d5.py │ │ ├── unet++_res34_d5.py │ │ ├── unet++_vgg32_d5.py │ │ ├── unet_res34_d5.py │ │ └── unet_vgg32_d5.py │ └── schedules │ │ └── pannuke_bs32_ep400.py └── pannuke-vgg32 │ ├── attunet_vgg32_d5_256x256_400ep_pannuke.py │ ├── cenet_vgg32_d5_256x256_400ep_pannuke.py │ ├── mctrans_vgg32_d5_256x256_400ep_pannuke.py │ ├── nonlocal_vgg32_d5_256x256_400ep_pannuke.py │ ├── transunet_vgg32_d5_256x256_400ep_pannuke.py │ ├── unet++_vgg32_d5_256x256_400ep_pannuke.py │ └── unet_vgg32_d5_256x256_400ep_pannuke.py ├── docs └── guidance.md ├── imgs ├── logo.png └── overview.png ├── mctrans ├── __init__.py ├── data │ ├── __init__.py │ ├── builder.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── pannuke.py │ │ └── polyp.py │ └── transforms │ │ ├── __init__.py │ │ ├── base.py │ │ ├── monai.py │ │ └── utils.py ├── metrics │ ├── __init__.py │ ├── base.py │ ├── builder.py │ └── hausdorff_distance.py ├── models │ ├── __init__.py │ ├── builder.py │ ├── centers │ │ ├── __init__.py │ │ ├── cenet.py │ │ ├── mctrans.py │ │ ├── non_local.py │ │ └── vit.py │ ├── decoders │ │ ├── __init__.py │ │ ├── unet_decoder.py │ │ └── unet_plus_plus_decoder.py │ ├── encoders │ │ ├── __init__.py │ │ ├── resnet.py │ │ └── vgg.py │ ├── heads │ │ ├── __init__.py │ │ ├── basic_seg_head.py │ │ └── mctrans_aux_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── cross_entropy_loss.py │ │ ├── debug_focal.py │ │ └── monai.py │ ├── ops │ │ ├── __init__.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── make.sh │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ ├── setup.py │ │ ├── src │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── ms_deform_attn.h │ │ │ └── vision.cpp │ │ └── test.py │ ├── segmentors │ │ ├── __init__.py │ │ ├── base.py │ │ └── encoder_decoder.py │ ├── trans │ │ ├── __init__.py │ │ ├── transformer.py │ │ └── utils.py │ └── utils.py ├── pipline │ ├── __init__.py │ └── segpipline.py └── utils │ ├── __init__.py │ ├── logger.py │ └── misc.py ├── requirements.txt ├── setup.py └── tools ├── train.py └── train.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 SenseTime. All Rights Reserved. 2 | 3 | Apache License 4 | Version 2.0, January 2004 5 | http://www.apache.org/licenses/ 6 | 7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 8 | 9 | 1. Definitions. 10 | 11 | "License" shall mean the terms and conditions for use, reproduction, 12 | and distribution as defined by Sections 1 through 9 of this document. 13 | 14 | "Licensor" shall mean the copyright owner or entity authorized by 15 | the copyright owner that is granting the License. 16 | 17 | "Legal Entity" shall mean the union of the acting entity and all 18 | other entities that control, are controlled by, or are under common 19 | control with that entity. 
For the purposes of this definition, 20 | "control" means (i) the power, direct or indirect, to cause the 21 | direction or management of such entity, whether by contract or 22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity 26 | exercising permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, 29 | including but not limited to software source code, documentation 30 | source, and configuration files. 31 | 32 | "Object" form shall mean any form resulting from mechanical 33 | transformation or translation of a Source form, including but 34 | not limited to compiled object code, generated documentation, 35 | and conversions to other media types. 36 | 37 | "Work" shall mean the work of authorship, whether in Source or 38 | Object form, made available under the License, as indicated by a 39 | copyright notice that is included in or attached to the work 40 | (an example is provided in the Appendix below). 41 | 42 | "Derivative Works" shall mean any work, whether in Source or Object 43 | form, that is based on (or derived from) the Work and for which the 44 | editorial revisions, annotations, elaborations, or other modifications 45 | represent, as a whole, an original work of authorship. For the purposes 46 | of this License, Derivative Works shall not include works that remain 47 | separable from, or merely link (or bind by name) to the interfaces of, 48 | the Work and Derivative Works thereof. 49 | 50 | "Contribution" shall mean any work of authorship, including 51 | the original version of the Work and any modifications or additions 52 | to that Work or Derivative Works thereof, that is intentionally 53 | submitted to Licensor for inclusion in the Work by the copyright owner 54 | or by an individual or Legal Entity authorized to submit on behalf of 55 | the copyright owner. For the purposes of this definition, "submitted" 56 | means any form of electronic, verbal, or written communication sent 57 | to the Licensor or its representatives, including but not limited to 58 | communication on electronic mailing lists, source code control systems, 59 | and issue tracking systems that are managed by, or on behalf of, the 60 | Licensor for the purpose of discussing and improving the Work, but 61 | excluding communication that is conspicuously marked or otherwise 62 | designated in writing by the copyright owner as "Not a Contribution." 63 | 64 | "Contributor" shall mean Licensor and any individual or Legal Entity 65 | on behalf of whom a Contribution has been received by Licensor and 66 | subsequently incorporated within the Work. 67 | 68 | 2. Grant of Copyright License. Subject to the terms and conditions of 69 | this License, each Contributor hereby grants to You a perpetual, 70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 71 | copyright license to reproduce, prepare Derivative Works of, 72 | publicly display, publicly perform, sublicense, and distribute the 73 | Work and such Derivative Works in Source or Object form. 74 | 75 | 3. Grant of Patent License. 
Subject to the terms and conditions of 76 | this License, each Contributor hereby grants to You a perpetual, 77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 78 | (except as stated in this section) patent license to make, have made, 79 | use, offer to sell, sell, import, and otherwise transfer the Work, 80 | where such license applies only to those patent claims licensable 81 | by such Contributor that are necessarily infringed by their 82 | Contribution(s) alone or by combination of their Contribution(s) 83 | with the Work to which such Contribution(s) was submitted. If You 84 | institute patent litigation against any entity (including a 85 | cross-claim or counterclaim in a lawsuit) alleging that the Work 86 | or a Contribution incorporated within the Work constitutes direct 87 | or contributory patent infringement, then any patent licenses 88 | granted to You under this License for that Work shall terminate 89 | as of the date such litigation is filed. 90 | 91 | 4. Redistribution. You may reproduce and distribute copies of the 92 | Work or Derivative Works thereof in any medium, with or without 93 | modifications, and in Source or Object form, provided that You 94 | meet the following conditions: 95 | 96 | (a) You must give any other recipients of the Work or 97 | Derivative Works a copy of this License; and 98 | 99 | (b) You must cause any modified files to carry prominent notices 100 | stating that You changed the files; and 101 | 102 | (c) You must retain, in the Source form of any Derivative Works 103 | that You distribute, all copyright, patent, trademark, and 104 | attribution notices from the Source form of the Work, 105 | excluding those notices that do not pertain to any part of 106 | the Derivative Works; and 107 | 108 | (d) If the Work includes a "NOTICE" text file as part of its 109 | distribution, then any Derivative Works that You distribute must 110 | include a readable copy of the attribution notices contained 111 | within such NOTICE file, excluding those notices that do not 112 | pertain to any part of the Derivative Works, in at least one 113 | of the following places: within a NOTICE text file distributed 114 | as part of the Derivative Works; within the Source form or 115 | documentation, if provided along with the Derivative Works; or, 116 | within a display generated by the Derivative Works, if and 117 | wherever such third-party notices normally appear. The contents 118 | of the NOTICE file are for informational purposes only and 119 | do not modify the License. You may add Your own attribution 120 | notices within Derivative Works that You distribute, alongside 121 | or as an addendum to the NOTICE text from the Work, provided 122 | that such additional attribution notices cannot be construed 123 | as modifying the License. 124 | 125 | You may add Your own copyright statement to Your modifications and 126 | may provide additional or different license terms and conditions 127 | for use, reproduction, or distribution of Your modifications, or 128 | for any such Derivative Works as a whole, provided Your use, 129 | reproduction, and distribution of the Work otherwise complies with 130 | the conditions stated in this License. 131 | 132 | 5. Submission of Contributions. Unless You explicitly state otherwise, 133 | any Contribution intentionally submitted for inclusion in the Work 134 | by You to the Licensor shall be under the terms and conditions of 135 | this License, without any additional terms or conditions. 
136 | Notwithstanding the above, nothing herein shall supersede or modify 137 | the terms of any separate license agreement you may have executed 138 | with Licensor regarding such Contributions. 139 | 140 | 6. Trademarks. This License does not grant permission to use the trade 141 | names, trademarks, service marks, or product names of the Licensor, 142 | except as required for reasonable and customary use in describing the 143 | origin of the Work and reproducing the content of the NOTICE file. 144 | 145 | 7. Disclaimer of Warranty. Unless required by applicable law or 146 | agreed to in writing, Licensor provides the Work (and each 147 | Contributor provides its Contributions) on an "AS IS" BASIS, 148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 149 | implied, including, without limitation, any warranties or conditions 150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 151 | PARTICULAR PURPOSE. You are solely responsible for determining the 152 | appropriateness of using or redistributing the Work and assume any 153 | risks associated with Your exercise of permissions under this License. 154 | 155 | 8. Limitation of Liability. In no event and under no legal theory, 156 | whether in tort (including negligence), contract, or otherwise, 157 | unless required by applicable law (such as deliberate and grossly 158 | negligent acts) or agreed to in writing, shall any Contributor be 159 | liable to You for damages, including any direct, indirect, special, 160 | incidental, or consequential damages of any character arising as a 161 | result of this License or out of the use or inability to use the 162 | Work (including but not limited to damages for loss of goodwill, 163 | work stoppage, computer failure or malfunction, or any and all 164 | other commercial damages or losses), even if such Contributor 165 | has been advised of the possibility of such damages. 166 | 167 | 9. Accepting Warranty or Additional Liability. While redistributing 168 | the Work or Derivative Works thereof, You may choose to offer, 169 | and charge a fee for, acceptance of support, warranty, indemnity, 170 | or other liability obligations and/or rights consistent with this 171 | License. However, in accepting such obligations, You may act only 172 | on Your own behalf and on Your sole responsibility, not on behalf 173 | of any other Contributor, and only if You agree to indemnify, 174 | defend, and hold each Contributor harmless for any liability 175 | incurred by, or claims asserted against, such Contributor by reason 176 | of your accepting any such warranty or additional liability. 177 | 178 | END OF TERMS AND CONDITIONS 179 | 180 | APPENDIX: How to apply the Apache License to your work. 181 | 182 | To apply the Apache License to your work, attach the following 183 | boilerplate notice, with the fields enclosed by brackets "[]" 184 | replaced with your own identifying information. (Don't include 185 | the brackets!) The text should be enclosed in the appropriate 186 | comment syntax for the file format. We also recommend that a 187 | file or class name and description of purpose be included on the 188 | same "printed page" as the copyright notice for easier 189 | identification within third-party archives. 190 | 191 | Copyright 2020 SenseTime 192 | 193 | Licensed under the Apache License, Version 2.0 (the "License"); 194 | you may not use this file except in compliance with the License. 
195 | You may obtain a copy of the License at 196 | 197 | http://www.apache.org/licenses/LICENSE-2.0 198 | 199 | Unless required by applicable law or agreed to in writing, software 200 | distributed under the License is distributed on an "AS IS" BASIS, 201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 202 | See the License for the specific language governing permissions and 203 | limitations under the License. 204 | 205 | 206 | DETR 207 | 208 | Copyright 2020 - present, Facebook, Inc 209 | 210 | Licensed under the Apache License, Version 2.0 (the "License"); 211 | you may not use this file except in compliance with the License. 212 | You may obtain a copy of the License at 213 | 214 | http://www.apache.org/licenses/LICENSE-2.0 215 | 216 | Unless required by applicable law or agreed to in writing, software 217 | distributed under the License is distributed on an "AS IS" BASIS, 218 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 219 | See the License for the specific language governing permissions and 220 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | 6 | ## News 7 | The code of MCTrans has been released. If you are interested in contributing to the standardization of the medical image analysis community, please feel free to contact me. 8 | 9 | ## Introduction 10 | 11 | - This repository provides code for "**Multi-Compound Transformer for Accurate Biomedical Image Segmentation**" [[paper](https://arxiv.org/pdf/2106.14385.pdf)]. 12 | 13 | - The MCTrans repository heavily references and uses the packages of [MMSegmentation](https://github.com/open-mmlab/mmsegmentation), [MMCV](https://github.com/open-mmlab/mmcv), and [MONAI](https://monai.io/). We thank them for their selfless contributions. 14 | 15 | 16 | 17 | ## Highlights 18 | 19 | - A comprehensive toolbox for medical image segmentation, including flexible data loading, processing, modular network construction, and more. 20 | 21 | - Supports representative and popular medical image segmentation methods, e.g., UNet, UNet++, CENet, and AttentionUNet. 22 | 23 | 24 | 25 | ## Changelog 26 | The first version was released on 2021.7.16. 27 | 28 | ## Model Zoo 29 | 30 | Supported backbones: 31 | 32 | - [x] VGG 33 | - [x] ResNet 34 | 35 | Supported methods: 36 | 37 | - [x] UNet 38 | - [x] UNet++ 39 | - [x] AttentionUNet 40 | - [x] CENet 41 | - [x] TransUNet 42 | - [x] NonLocalUNet 43 | 44 | ## Installation and Usage 45 | 46 | Please see [guidance.md](docs/guidance.md). 47 | 48 | 49 | 50 | ## Citation 51 | 52 | If you find this project useful in your research, please consider citing: 53 | 54 | ```latex 55 | @article{ji2021multi, 56 | title={Multi-Compound Transformer for Accurate Biomedical Image Segmentation}, 57 | author={Ji, Yuanfeng and Zhang, Ruimao and Wang, Huijie and Li, Zhen and Wu, Lingyun and Zhang, Shaoting and Luo, Ping}, 58 | journal={arXiv preprint arXiv:2106.14385}, 59 | year={2021} 60 | } 61 | ``` 62 | 63 | 64 | 65 | ## Contribution 66 | 67 | I don't have a lot of time to improve the code base at this stage, so if you have some free time and are interested in contributing to the standardization of the medical image analysis community, please feel free to contact me (jyuanfeng8@gmail.com). 68 | 69 | 70 | 71 | ## License 72 | 73 | This project is released under the [Apache 2.0 license](LICENSE). 
74 | 75 | -------------------------------------------------------------------------------- /configs/_base_/datasets/pannuke.py: -------------------------------------------------------------------------------- 1 | dataset_type = 'PanNukeDataset' 2 | patch_size = [256, 256] 3 | 4 | keyword = ["img", "seg_label"] 5 | 6 | train_transforms = [ 7 | dict(type="LoadImage", 8 | keys=keyword, 9 | meta_key_postfix="metas"), 10 | dict(type="AsChannelFirst", 11 | keys=keyword[0]), 12 | dict(type="AddChannel", 13 | keys=keyword[1]), 14 | dict(type='Resize', 15 | keys=keyword, 16 | spatial_size=patch_size, 17 | mode=("bilinear", "nearest")), 18 | dict(type="ScaleIntensity", 19 | keys=keyword[0]), 20 | dict(type='RandMirror'), 21 | dict(type='ToTensor', 22 | keys=keyword) 23 | ] 24 | 25 | test_transforms = [ 26 | dict(type="LoadImage", 27 | keys=keyword, 28 | meta_key_postfix="metas"), 29 | dict(type="AsChannelFirst", 30 | keys=keyword[0]), 31 | dict(type="AddChannel", 32 | keys=keyword[1]), 33 | dict(type='Resize', 34 | keys=keyword, 35 | spatial_size=patch_size, 36 | mode=("bilinear", "nearest")), 37 | dict(type="ScaleIntensity", 38 | keys=keyword[0]), 39 | dict(type='ToTensor', 40 | keys=keyword) 41 | ] 42 | data = dict( 43 | samples_per_gpu=32, 44 | workers_per_gpu=8, 45 | train=dict( 46 | type=dataset_type, 47 | img_dirs=["data/medical/pannuke/split-images-npy/0", 48 | "data/medical/pannuke/split-images-npy/1"], 49 | img_suffix=".npy", 50 | label_dirs=["data/medical/pannuke/split-masks-npy/0", 51 | "data/medical/pannuke/split-masks-npy/1"], 52 | label_suffix=".npy", 53 | phase="train", 54 | transforms=train_transforms, 55 | ), 56 | val=dict( 57 | type=dataset_type, 58 | img_dirs=["data/medical/pannuke/split-images-npy/2"], 59 | img_suffix=".npy", 60 | label_dirs=["data/medical/pannuke/split-masks-npy/2"], 61 | label_suffix=".npy", 62 | phase="val", 63 | transforms=test_transforms 64 | ), 65 | ) 66 | #support data with different formats 67 | # 68 | # data = dict( 69 | # samples_per_gpu=32, 70 | # workers_per_gpu=8, 71 | # train=dict( 72 | # type=dataset_type, 73 | # img_dirs=["data/medical/pannuke/split-images/0", 74 | # "data/medical/pannuke/split-images/1"], 75 | # img_suffix=".png", 76 | # label_dirs=["data/medical/pannuke/split-masks/0", 77 | # "data/medical/pannuke/split-masks/1"], 78 | # label_suffix=".png", 79 | # phase="train", 80 | # transforms=train_transforms, 81 | # ), 82 | # val=dict( 83 | # type=dataset_type, 84 | # img_dirs=["data/medical/pannuke/split-images/2"], 85 | # img_suffix=".png", 86 | # label_dirs=["data/medical/pannuke/split-masks/2"], 87 | # label_suffix=".png", 88 | # phase="val", 89 | # transforms=test_transforms 90 | # ), 91 | # ) -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # yapf:disable 2 | log_config = dict( 3 | interval=10, 4 | hooks=[ 5 | dict(type='TextLoggerHook', by_epoch=True), 6 | dict(type='TensorboardLoggerHook') 7 | ]) 8 | # yapf:enable 9 | dist_params = dict(backend='nccl') 10 | log_level = 'INFO' 11 | load_from = None 12 | resume_from = None 13 | workflow = [('train', 1)] 14 | cudnn_benchmark = True -------------------------------------------------------------------------------- /configs/_base_/models/attunet_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="ResNet", 
6 | depth=34, 7 | in_channels=3), 8 | decoder=dict( 9 | type="UNetDecoder", 10 | in_channels=[64, 64, 128, 256, 512], 11 | att=True 12 | ), 13 | seg_head=dict( 14 | type="BasicSegHead", 15 | in_channels=64, 16 | num_classes=6, 17 | post_trans=[dict(type="Activations", softmax=True), 18 | dict(type="AsDiscrete", argmax=True)], 19 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 20 | dict(type="FocalLoss", to_onehot_y=True)]) 21 | ) 22 | -------------------------------------------------------------------------------- /configs/_base_/models/attunet_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | decoder=dict( 10 | type="UNetDecoder", 11 | in_channels=[32, 64, 128, 256, 512], 12 | att=True 13 | ), 14 | seg_head=dict( 15 | type="BasicSegHead", 16 | in_channels=32, 17 | num_classes=6, 18 | post_trans=[dict(type="Activations", softmax=True), 19 | dict(type="AsDiscrete", argmax=True)], 20 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 21 | dict(type="CrossEntropyLoss"), 22 | dict(type="FLoss")] 23 | ) 24 | ) 25 | -------------------------------------------------------------------------------- /configs/_base_/models/cenet_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="ResNet", 6 | depth=34, 7 | in_channels=3), 8 | center=dict( 9 | type="CEncoder", 10 | in_channels=[64, 64, 128, 256, 512]), 11 | decoder=dict( 12 | type="UNetDecoder", 13 | in_channels=[64, 64, 128, 256, 512 + 4]), 14 | seg_head=dict( 15 | type="BasicSegHead", 16 | in_channels=64, 17 | num_classes=6, 18 | post_trans=[dict(type="Activations", softmax=True), 19 | dict(type="AsDiscrete", argmax=True)], 20 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 21 | dict(type="FocalLoss", to_onehot_y=True)]) 22 | ) 23 | -------------------------------------------------------------------------------- /configs/_base_/models/cenet_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | center=dict( 10 | type="CEncoder", 11 | in_channels=[32, 64, 128, 256, 512]), 12 | decoder=dict( 13 | type="UNetDecoder", 14 | in_channels=[32, 64, 128, 256, 512+4], 15 | att=True), 16 | seg_head=dict( 17 | type="BasicSegHead", 18 | in_channels=32, 19 | num_classes=6, 20 | post_trans=[dict(type="Activations", softmax=True), 21 | dict(type="AsDiscrete", argmax=True)], 22 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 23 | dict(type="CrossEntropyLoss"), 24 | dict(type="FLoss")] 25 | ) 26 | ) 27 | -------------------------------------------------------------------------------- /configs/_base_/models/mctrans_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="ResNet", 6 | depth=34, 7 | in_channels=3), 8 | center=dict( 9 | type="MCTrans", 10 | d_model=128, 11 | nhead=8, 12 | d_ffn=512, 13 | dropout=0.1, 14 | act="relu", 15 | n_levels=3, 16 | n_points=4, 17 | n_sa_layers=6), 18 | decoder=dict( 19 | type="UNetDecoder", 20 | 
in_channels=[64, 64, 128, 128, 128], 21 | ), 22 | seg_head=dict( 23 | type="BasicSegHead", 24 | in_channels=64, 25 | num_classes=6, 26 | post_trans=[dict(type="Activations", softmax=True), 27 | dict(type="AsDiscrete", argmax=True)], 28 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 29 | dict(type="FocalLoss", to_onehot_y=True)]), 30 | aux_head=dict( 31 | type="MCTransAuxHead", 32 | d_model=128, 33 | d_ffn=512, 34 | act="relu", 35 | num_classes=6, 36 | in_channles=[64, 64, 128, 128, 128], 37 | losses=[dict(type="MCTransAuxLoss", sigmoid=True, loss_weight=0.1)]), 38 | ) 39 | -------------------------------------------------------------------------------- /configs/_base_/models/mctrans_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | center=dict( 10 | type="MCTrans", 11 | d_model=128, 12 | nhead=8, 13 | d_ffn=512, 14 | dropout=0.1, 15 | act="relu", 16 | n_levels=3, 17 | n_points=4, 18 | n_sa_layers=6), 19 | decoder=dict( 20 | type="UNetDecoder", 21 | in_channels=[32, 64, 128, 128, 128], 22 | ), 23 | seg_head=dict( 24 | type="BasicSegHead", 25 | in_channels=32, 26 | num_classes=6, 27 | post_trans=[dict(type="Activations", softmax=True), 28 | dict(type="AsDiscrete", argmax=True)], 29 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 30 | dict(type="CrossEntropyLoss"), 31 | dict(type="FLoss")]), 32 | aux_head=dict( 33 | type="MCTransAuxHead", 34 | d_model=128, 35 | d_ffn=512, 36 | act="relu", 37 | num_classes=6, 38 | in_channles=[32, 64, 128, 128, 128], 39 | losses=[dict(type="MCTransAuxLoss", sigmoid=True, loss_weight=0.1)]), 40 | ) 41 | -------------------------------------------------------------------------------- /configs/_base_/models/nonlocal_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="ResNet", 6 | depth=34, 7 | in_channels=3), 8 | center=dict( 9 | type="NonLocal", 10 | in_channels=[64, 64, 128, 256, 512]), 11 | decoder=dict( 12 | type="UNetDecoder", 13 | in_channels=[64, 64, 128, 256, 512]), 14 | seg_head=dict( 15 | type="BasicSegHead", 16 | in_channels=64, 17 | num_classes=6, 18 | post_trans=[dict(type="Activations", softmax=True), 19 | dict(type="AsDiscrete", argmax=True)], 20 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 21 | dict(type="FocalLoss", to_onehot_y=True)]) 22 | ) 23 | -------------------------------------------------------------------------------- /configs/_base_/models/nonlocal_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | center=dict( 10 | type="NonLocal", 11 | in_channels=[32, 64, 128, 256, 512]), 12 | decoder=dict( 13 | type="UNetDecoder", 14 | in_channels=[32, 64, 128, 256, 512]), 15 | seg_head=dict( 16 | type="BasicSegHead", 17 | in_channels=32, 18 | num_classes=6, 19 | post_trans=[dict(type="Activations", softmax=True), 20 | dict(type="AsDiscrete", argmax=True)], 21 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 22 | dict(type="CrossEntropyLoss"), 23 | dict(type="FLoss")]) 24 | ) 25 | 
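Note: the `model = dict(...)` fragments in these config files are plain Python dicts that an mmcv-style registry dispatches on their `type` key. The repository's own builder (`mctrans/models/builder.py`) is not reproduced in this listing, so the sketch below uses a hypothetical registry name and a stripped-down `EncoderDecoder` purely to illustrate the mechanism; only the mmcv calls are real API.

```python
# Minimal sketch of mmcv-style config dispatch. `SEGMENTORS` and this toy
# `EncoderDecoder` are hypothetical stand-ins, not the project's classes.
from mmcv.utils import Registry, build_from_cfg

SEGMENTORS = Registry('segmentor')  # assumed registry name


@SEGMENTORS.register_module()
class EncoderDecoder:
    def __init__(self, encoder, decoder, seg_head, center=None,
                 aux_head=None, pretrained=None):
        # The real model would build each sub-config from its own registry
        # (encoders, centers, decoders, heads); here we simply keep the dicts.
        self.encoder = encoder
        self.decoder = decoder
        self.seg_head = seg_head
        self.center = center
        self.aux_head = aux_head


cfg = dict(
    type='EncoderDecoder',
    encoder=dict(type='VGG', in_channel=3, init_channels=32, num_blocks=2),
    decoder=dict(type='UNetDecoder', in_channels=[32, 64, 128, 256, 512]),
    seg_head=dict(type='BasicSegHead', in_channels=32, num_classes=6),
)
model = build_from_cfg(cfg, SEGMENTORS)  # pops 'type', passes the rest as kwargs
print(type(model).__name__)              # -> EncoderDecoder
```

Every file under `configs/_base_/models/` follows this shape, differing only in which encoder, center, decoder, and head fragments it combines.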
-------------------------------------------------------------------------------- /configs/_base_/models/transunet_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="ResNet", 6 | depth=34, 7 | in_channels=3), 8 | center=dict( 9 | type="Vit", 10 | input_size=(512, 16, 16)), 11 | decoder=dict( 12 | type="UNetDecoder", 13 | in_channels=[64, 64, 128, 256, 512]), 14 | seg_head=dict( 15 | type="BasicSegHead", 16 | in_channels=64, 17 | num_classes=6, 18 | post_trans=[dict(type="Activations", softmax=True), 19 | dict(type="AsDiscrete", argmax=True)], 20 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 21 | dict(type="FocalLoss", to_onehot_y=True)]) 22 | ) 23 | -------------------------------------------------------------------------------- /configs/_base_/models/transunet_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | center=dict( 10 | type="Vit", 11 | input_size=(512, 16, 16)), 12 | decoder=dict( 13 | type="UNetDecoder", 14 | in_channels=[32, 64, 128, 256, 512]), 15 | seg_head=dict( 16 | type="BasicSegHead", 17 | in_channels=32, 18 | num_classes=6, 19 | post_trans=[dict(type="Activations", softmax=True), 20 | dict(type="AsDiscrete", argmax=True)], 21 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 22 | dict(type="CrossEntropyLoss"), 23 | dict(type="FLoss")]) 24 | ) 25 | 26 | -------------------------------------------------------------------------------- /configs/_base_/models/unet++_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="ResNet", 6 | depth=34, 7 | in_channels=3, 8 | out_indices=(0, 1, 2, 3, 4)), 9 | decoder=dict( 10 | type="UNetPlusPlusDecoder", 11 | in_channels=[64, 64, 128, 256, 512], 12 | ), 13 | seg_head=dict( 14 | type="BasicSegHead", 15 | in_channels=64, 16 | num_classes=6, 17 | post_trans=[dict(type="Activations", softmax=True), 18 | dict(type="AsDiscrete", argmax=True)], 19 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 20 | dict(type="FocalLoss", to_onehot_y=True)]) 21 | ) 22 | -------------------------------------------------------------------------------- /configs/_base_/models/unet++_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | decoder=dict( 10 | type="UNetPlusPlusDecoder", 11 | in_channels=[32, 64, 128, 256, 512], 12 | ), 13 | seg_head=dict( 14 | type="BasicSegHead", 15 | in_channels=32, 16 | num_classes=6, 17 | post_trans=[dict(type="Activations", softmax=True), 18 | dict(type="AsDiscrete", argmax=True)], 19 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 20 | dict(type="CrossEntropyLoss"), 21 | dict(type="FLoss")]) 22 | ) 23 | -------------------------------------------------------------------------------- /configs/_base_/models/unet_res34_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | 
type="ResNet", 6 | depth=34, 7 | in_channels=3), 8 | decoder=dict( 9 | type="UNetDecoder", 10 | in_channels=[64, 64, 128, 256, 512], 11 | ), 12 | seg_head=dict( 13 | type="BasicSegHead", 14 | in_channels=64, 15 | num_classes=6, 16 | post_trans=[dict(type="Activations", softmax=True), 17 | dict(type="AsDiscrete", argmax=True)], 18 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True), 19 | dict(type="FocalLoss", to_onehot_y=True)]) 20 | ) 21 | -------------------------------------------------------------------------------- /configs/_base_/models/unet_vgg32_d5.py: -------------------------------------------------------------------------------- 1 | model = dict( 2 | type='EncoderDecoder', 3 | pretrained=None, 4 | encoder=dict( 5 | type="VGG", 6 | in_channel=3, 7 | init_channels=32, 8 | num_blocks=2), 9 | decoder=dict( 10 | type="UNetDecoder", 11 | in_channels=[32, 64, 128, 256, 512], 12 | ), 13 | seg_head=dict( 14 | type="BasicSegHead", 15 | in_channels=32, 16 | num_classes=6, 17 | post_trans=[dict(type="Activations", softmax=True), 18 | dict(type="AsDiscrete", argmax=True)], 19 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True), 20 | dict(type="CrossEntropyLoss"), 21 | dict(type="FLoss")] 22 | ) 23 | ) 24 | -------------------------------------------------------------------------------- /configs/_base_/schedules/pannuke_bs32_ep400.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict(type='Adam', lr=0.0003, betas=[0.9, 0.99]) 3 | optimizer_config = dict() 4 | # learning policy 5 | lr_config = dict(policy='CosineAnnealing', min_lr=0.0) 6 | # runtime settings 7 | max_epochs = 400 8 | runner = dict(type='EpochBasedRunner', max_epochs=max_epochs) 9 | checkpoint_config = dict(by_epoch=True, interval=5) 10 | evaluation = dict(interval=1, save_best="mDice") -------------------------------------------------------------------------------- /configs/pannuke-vgg32/attunet_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/attunet_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | -------------------------------------------------------------------------------- /configs/pannuke-vgg32/cenet_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/cenet_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | -------------------------------------------------------------------------------- /configs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/mctrans_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | -------------------------------------------------------------------------------- /configs/pannuke-vgg32/nonlocal_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/nonlocal_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | 
-------------------------------------------------------------------------------- /configs/pannuke-vgg32/transunet_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/transunet_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | -------------------------------------------------------------------------------- /configs/pannuke-vgg32/unet++_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/unet++_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | -------------------------------------------------------------------------------- /configs/pannuke-vgg32/unet_vgg32_d5_256x256_400ep_pannuke.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/unet_vgg32_d5.py', '../_base_/datasets/pannuke.py', 3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py' 4 | ] 5 | -------------------------------------------------------------------------------- /docs/guidance.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | a. Create a conda virtual environment and install the required packages. 4 | 5 | ```shell 6 | conda create -n mctrans pip python=3.7 7 | conda activate mctrans 8 | git clone https://github.com/JiYuanFeng/MCTrans.git 9 | cd MCTrans 10 | python setup.py develop 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | b. Compile the CUDA operators, such as [MultiScaleDeformableAttention](https://github.com/fundamentalvision/Deformable-DETR). 15 | 16 | ```shell 17 | cd mctrans/models/ops/ 18 | bash make.sh 19 | ``` 20 | 21 | c. Create a data folder under MCTrans and link the actual dataset path ($DATA_ROOT). 22 | 23 | ```shell 24 | mkdir data 25 | ln -s $DATA_ROOT data 26 | ``` 27 | 28 | 29 | 30 | ## Datasets Preparation 31 | 32 | - It is recommended to convert your dataset (especially the labels) to a standard format; for example, a binary segmentation label should only contain `0,1` or `0,255` (a minimal check is sketched after this list). 33 | 34 | - If your folder structure is different, you may need to change the corresponding paths in the config files. 35 | 36 | - We have uploaded some preprocessed datasets to [drive](https://drive.google.com/file/d/1mcD7Grx2bUQhAL9ClTrCtKv6FyX03Ehd/view?usp=sharing); you can download and unpack them under the data folder. 
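As noted in the first bullet above, the loaders expect clean binary label values. A minimal sanity-check/conversion sketch (not part of the repository; the mask directory and suffix are placeholders to adapt to your own layout):

```python
# Rewrite any mask whose values stray outside {0, 1, 255} to a clean 0/255 mask.
import glob
import os.path as osp

import numpy as np
from PIL import Image

mask_dir = 'data/cvc-clinic/masks'  # placeholder path
for path in glob.glob(osp.join(mask_dir, '*.png')):
    mask = np.array(Image.open(path))
    values = set(np.unique(mask).tolist())
    if not values <= {0, 1, 255}:
        binary = np.where(mask > 0, 255, 0).astype(np.uint8)  # non-zero -> foreground
        Image.fromarray(binary).save(path)
        print(f'rewrote {path}: {sorted(values)} -> [0, 255]')
```

The same check applies to the other binary datasets (`cvc-colondb`, `kvasir`) in the folder layout below.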
37 | 38 | ```none 39 | MCTrans 40 | ├── mctrans 41 | ├── data 42 | │ ├── pannuke 43 | │ │ ├── split-images 44 | │ │ ├── split-masks 45 | │ │ ├── split-images-npy 46 | │ │ ├── split-masks-npy 47 | │ ├── cvc-clinic 48 | │ │ ├── images 49 | │ │ ├── masks 50 | │ ├── cvc-colondb 51 | │ │ ├── images 52 | │ │ ├── masks 53 | │ ├── kvasir 54 | │ │ ├── images 55 | │ │ ├── masks 56 | ``` 57 | 58 | ## Single GPU Training 59 | ```shell 60 | bash tools/train.sh 61 | ``` 62 | 63 | ## Multi GPU Training 64 | 65 | ```none 66 | TO DO 67 | ``` 68 | 69 | -------------------------------------------------------------------------------- /imgs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/imgs/logo.png -------------------------------------------------------------------------------- /imgs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/imgs/overview.png -------------------------------------------------------------------------------- /mctrans/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/mctrans/__init__.py -------------------------------------------------------------------------------- /mctrans/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .builder import build_dataset, build_dataloader, DATASETS, TRANSFORMS 3 | -------------------------------------------------------------------------------- /mctrans/data/builder.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import random 3 | from collections import Sequence, Mapping 4 | from functools import partial 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from mmcv.parallel import collate, DataContainer 10 | from mmcv.runner import get_dist_info 11 | from mmcv.utils import Registry, build_from_cfg 12 | from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader 13 | from torch.utils.data import DistributedSampler 14 | 15 | if platform.system() != 'Windows': 16 | # https://github.com/pytorch/pytorch/issues/973 17 | import resource 18 | 19 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 20 | hard_limit = rlimit[1] 21 | soft_limit = min(4096, hard_limit) 22 | 23 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 24 | 25 | DATASETS = Registry('dataset') 26 | TRANSFORMS = Registry('transform') 27 | 28 | 29 | def collate_fn(batch, samples_per_gpu=1): 30 | """Puts each data field into a tensor/DataContainer with outer dimension 31 | batch size. 32 | 33 | Extend default_collate to add support for 34 | :type:`~mmcv.parallel.DataContainer`. There are 3 cases. 35 | 36 | 1. cpu_only = True, e.g., meta data 37 | 2. cpu_only = False, stack = True, e.g., images tensors 38 | 3. 
cpu_only = False, stack = False, e.g., gt bboxes 39 | """ 40 | if not isinstance(batch, Sequence): 41 | raise TypeError(f'{batch.dtype} is not supported.') 42 | 43 | if isinstance(batch[0], list): 44 | batch = [item for _ in batch for item in _] 45 | 46 | if isinstance(batch[0], DataContainer): 47 | assert len(batch) % samples_per_gpu == 0 48 | stacked = [] 49 | if batch[0].cpu_only: 50 | for i in range(0, len(batch), samples_per_gpu): 51 | stacked.append( 52 | [sample.data for sample in batch[i:i + samples_per_gpu]]) 53 | return DataContainer( 54 | stacked, batch[0].stack, batch[0].padding_value, cpu_only=True) 55 | elif batch[0].stack: 56 | for i in range(0, len(batch), samples_per_gpu): 57 | assert isinstance(batch[i].data, torch.Tensor) 58 | 59 | if batch[i].pad_dims is not None: 60 | ndim = batch[i].dim() 61 | assert ndim > batch[i].pad_dims 62 | max_shape = [0 for _ in range(batch[i].pad_dims)] 63 | for dim in range(1, batch[i].pad_dims + 1): 64 | max_shape[dim - 1] = batch[i].size(-dim) 65 | for sample in batch[i:i + samples_per_gpu]: 66 | for dim in range(0, ndim - batch[i].pad_dims): 67 | assert batch[i].size(dim) == sample.size(dim) 68 | for dim in range(1, batch[i].pad_dims + 1): 69 | max_shape[dim - 1] = max(max_shape[dim - 1], 70 | sample.size(-dim)) 71 | padded_samples = [] 72 | for sample in batch[i:i + samples_per_gpu]: 73 | pad = [0 for _ in range(batch[i].pad_dims * 2)] 74 | for dim in range(1, batch[i].pad_dims + 1): 75 | pad[2 * dim - 76 | 1] = max_shape[dim - 1] - sample.size(-dim) 77 | padded_samples.append( 78 | F.pad( 79 | sample.data, pad, value=sample.padding_value)) 80 | stacked.append(collate(padded_samples)) 81 | elif batch[i].pad_dims is None: 82 | stacked.append( 83 | collate([ 84 | sample.data 85 | for sample in batch[i:i + samples_per_gpu] 86 | ])) 87 | else: 88 | raise ValueError( 89 | 'pad_dims should be either None or integers (1-3)') 90 | 91 | else: 92 | for i in range(0, len(batch), samples_per_gpu): 93 | stacked.append( 94 | [sample.data for sample in batch[i:i + samples_per_gpu]]) 95 | return DataContainer(stacked, batch[0].stack, batch[0].padding_value) 96 | elif isinstance(batch[0], Sequence): 97 | transposed = zip(*batch) 98 | return [collate(samples, samples_per_gpu) for samples in transposed] 99 | 100 | elif isinstance(batch[0], Mapping): 101 | res = dict() 102 | for key in batch[0]: 103 | if isinstance(batch[0][key], torch.Tensor): 104 | res.update({key: collate([d[key] for d in batch], samples_per_gpu)}) 105 | else: 106 | res.update({key: [d[key] for d in batch]}) 107 | 108 | return res 109 | # return { 110 | # key: collate([d[key] for d in batch], samples_per_gpu) 111 | # for key in batch[0] 112 | # } 113 | else: 114 | return collate(batch) 115 | 116 | 117 | def worker_init_fn(worker_id, num_workers, rank, seed): 118 | """Worker init func for dataloader. 119 | 120 | The seed of each worker equals to num_worker * rank + worker_id + user_seed 121 | 122 | Args: 123 | worker_id (int): Worker id. 124 | num_workers (int): Number of workers. 125 | rank (int): The rank of current process. 126 | seed (int): The random seed to use. 
127 | """ 128 | 129 | worker_seed = num_workers * rank + worker_id + seed 130 | np.random.seed(worker_seed) 131 | random.seed(worker_seed) 132 | 133 | 134 | def build_dataset(cfg, default_args=None): 135 | """Build datasets.""" 136 | dataset = build_from_cfg(cfg, DATASETS, default_args) 137 | return dataset 138 | 139 | 140 | def build_dataloader(dataset, 141 | samples_per_gpu, 142 | workers_per_gpu, 143 | num_gpus=1, 144 | dist=True, 145 | shuffle=True, 146 | seed=None, 147 | drop_last=False, 148 | pin_memory=True, 149 | dataloader_type='PoolDataLoader', 150 | **kwargs): 151 | """Build PyTorch DataLoader. 152 | 153 | In distributed training, each GPU/process has a dataloader. 154 | In non-distributed training, there is only one dataloader for all GPUs. 155 | 156 | Args: 157 | dataset (Dataset): A PyTorch dataset. 158 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 159 | batch size of each GPU. 160 | workers_per_gpu (int): How many subprocesses to use for data loading 161 | for each GPU. 162 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 163 | dist (bool): Distributed training/test or not. Default: True. 164 | shuffle (bool): Whether to shuffle the data at every epoch. 165 | Default: True. 166 | seed (int | None): Seed to be used. Default: None. 167 | drop_last (bool): Whether to drop the last incomplete batch in epoch. 168 | Default: False 169 | pin_memory (bool): Whether to use pin_memory in DataLoader. 170 | Default: True 171 | dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader' 172 | kwargs: any keyword argument to be used to initialize DataLoader 173 | 174 | Returns: 175 | DataLoader: A PyTorch dataloader. 176 | """ 177 | rank, world_size = get_dist_info() 178 | if dist: 179 | sampler = DistributedSampler( 180 | dataset, world_size, rank, shuffle=shuffle) 181 | shuffle = False 182 | batch_size = samples_per_gpu 183 | num_workers = workers_per_gpu 184 | else: 185 | sampler = None 186 | batch_size = num_gpus * samples_per_gpu 187 | num_workers = num_gpus * workers_per_gpu 188 | 189 | init_fn = partial( 190 | worker_init_fn, num_workers=num_workers, rank=rank, 191 | seed=seed) if seed is not None else None 192 | 193 | assert dataloader_type in ( 194 | 'DataLoader', 195 | 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}' 196 | 197 | if dataloader_type == 'PoolDataLoader': 198 | dataloader = PoolDataLoader 199 | elif dataloader_type == 'DataLoader': 200 | dataloader = DataLoader 201 | 202 | data_loader = dataloader( 203 | dataset, 204 | batch_size=batch_size, 205 | sampler=sampler, 206 | num_workers=num_workers, 207 | collate_fn=partial(collate_fn, samples_per_gpu=samples_per_gpu), 208 | pin_memory=pin_memory, 209 | shuffle=shuffle, 210 | worker_init_fn=init_fn, 211 | drop_last=drop_last, 212 | **kwargs) 213 | 214 | return data_loader 215 | -------------------------------------------------------------------------------- /mctrans/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDataset 2 | from .polyp import LesionDataset 3 | from .pannuke import PanNukeDataset 4 | 5 | __all__ = ["BaseDataset", "LesionDataset", "PanNukeDataset"] 6 | -------------------------------------------------------------------------------- /mctrans/data/datasets/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from functools import reduce 4 | from collections import 
OrderedDict 5 | 6 | import mmcv 7 | import monai 8 | from mmcv import print_log 9 | import numpy as np 10 | 11 | from prettytable import PrettyTable 12 | from sklearn.model_selection import KFold 13 | from torch.utils.data import Dataset 14 | 15 | from ..builder import DATASETS 16 | from ..transforms import build_transforms 17 | from ...metrics.base import eval_metrics 18 | from ...utils import get_root_logger 19 | 20 | 21 | @DATASETS.register_module() 22 | class BaseDataset(Dataset): 23 | """ Custom Dataset for medical image segmentation""" 24 | CLASSES = None 25 | PALETTE = None 26 | 27 | def __init__(self, 28 | transforms, 29 | img_dirs, 30 | img_suffix=".jpg", 31 | label_dirs=None, 32 | label_suffix=".png", 33 | phase=None, 34 | cross_valid=False, 35 | fold_idx=0, 36 | fold_nums=5, 37 | data_root=None, 38 | ignore_index=None, 39 | binnary_label=False, 40 | exclude_backgroud=True, 41 | label_map=None 42 | ): 43 | self.transforms = build_transforms(transforms) 44 | 45 | self.img_dirs = img_dirs if isinstance(img_dirs, (list, tuple)) else [img_dirs] 46 | self.label_dirs = label_dirs if isinstance(label_dirs, (list, tuple)) else [label_dirs] 47 | self.img_suffix = img_suffix 48 | self.label_suffix = label_suffix 49 | self.phase = phase 50 | self.cross_valid = cross_valid 51 | self.fold_idx = fold_idx 52 | self.fold_nums = fold_nums 53 | self.data_root = data_root 54 | self.label_map = label_map 55 | self.binary_label = binnary_label 56 | self.ignore_index = ignore_index 57 | self.exclude_backgroud = exclude_backgroud 58 | self.data_list = self.generate_data_list(self.img_dirs, self.img_suffix, self.label_dirs, self.label_suffix, 59 | self.cross_valid, self.phase, self.fold_idx, self.fold_nums) 60 | 61 | def generate_data_list(self, 62 | img_dirs, 63 | img_suffix, 64 | label_dirs, 65 | label_suffix, 66 | cross_valid=False, 67 | phase="Train", 68 | fold_idx=0, 69 | fold_nums=5, 70 | img_key="img", 71 | label_key="seg_label"): 72 | 73 | if label_dirs is not None: 74 | assert len(img_dirs) == len(label_dirs) 75 | 76 | data_list = [] 77 | for idx, img_dir in enumerate(img_dirs): 78 | for img in mmcv.scandir(img_dir, img_suffix, recursive=True): 79 | data_info = {} 80 | data_info[img_key] = osp.join(img_dir, img) 81 | if label_dirs is not None: 82 | label = img.replace(img_suffix, label_suffix) 83 | data_info[label_key] = osp.join(label_dirs[idx], label) 84 | data_list.append(data_info) 85 | 86 | if cross_valid: 87 | assert isinstance(fold_idx, int) and isinstance(fold_nums, int) 88 | splits = [] 89 | kfold = KFold(n_splits=fold_nums, shuffle=True) 90 | 91 | for tr_idx, te_idx in kfold.split(data_list): 92 | splits.append(dict()) 93 | splits[-1]['train'] = [item for idx, item in enumerate(data_list) if idx in tr_idx] 94 | splits[-1]['val'] = [item for idx, item in enumerate(data_list) if idx in te_idx] 95 | data_list = splits[fold_idx][phase] 96 | 97 | print_log("Phase {} : Loaded {} images".format(phase, len(data_list)), logger=get_root_logger()) 98 | 99 | return data_list 100 | 101 | def get_gt_seg_maps(self): 102 | """Get ground truth segmentation maps for evaluation.""" 103 | gt_seg_maps = [] 104 | reader = monai.transforms.LoadImage(image_only=True) 105 | for img_info in self.data_list: 106 | seg_map = osp.join(img_info['seg_label']) 107 | # gt_seg_map = mmcv.imread( 108 | # seg_map, flag='unchanged', backend='pillow') 109 | gt_seg_map = reader(seg_map) 110 | # binary the mask if need 111 | if self.binary_label: 112 | gt_seg_map[gt_seg_map > 0] = 255 113 | # modify if custom classes 114 | 
if self.label_map is not None: 115 | for old_id, new_id in self.label_map.items(): 116 | gt_seg_map[gt_seg_map == old_id] = new_id 117 | gt_seg_maps.append(gt_seg_map) 118 | 119 | return gt_seg_maps 120 | 121 | def evaluate(self, 122 | results, 123 | metric=['mDice', "mIoU", "mFscore", "mHd95"], 124 | logger=None, 125 | **kwargs): 126 | """Evaluate the dataset. 127 | Args: 128 | results (list): Testing results of the dataset. 129 | metric (str | list[str]): Metrics to be evaluated. 'mIoU', 130 | 'mDice' and 'mFscore' are supported. 131 | logger (logging.Logger | None | str): Logger used for printing 132 | related information during evaluation. Default: None. 133 | Returns: 134 | dict[str, float]: Default metrics. 135 | """ 136 | 137 | if isinstance(metric, str): 138 | metric = [metric] 139 | allowed_metrics = ['mIoU', 'mDice', 'mFscore', "mHd95"] 140 | if not set(metric).issubset(set(allowed_metrics)): 141 | raise KeyError('metric {} is not supported'.format(metric)) 142 | eval_results = {} 143 | gt_seg_maps = self.get_gt_seg_maps() 144 | 145 | if self.CLASSES is None: 146 | num_classes = len( 147 | reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps])) 148 | else: 149 | num_classes = len(self.CLASSES) 150 | 151 | ret_metrics = eval_metrics( 152 | results, 153 | gt_seg_maps, 154 | num_classes, 155 | self.ignore_index, 156 | metric) 157 | 158 | if self.CLASSES is None: 159 | class_names = tuple(range(num_classes)) 160 | else: 161 | class_names = self.CLASSES 162 | 163 | # each class table 164 | ret_metrics.pop('aAcc', None) 165 | ret_metrics_class = OrderedDict({ 166 | ret_metric: np.round(ret_metric_value * 100, 2) 167 | for ret_metric, ret_metric_value in ret_metrics.items() 168 | }) 169 | ret_metrics_class.update({'Class': class_names}) 170 | ret_metrics_class.move_to_end('Class', last=False) 171 | 172 | # summary table 173 | # exclude some ignore idx 174 | if self.exclude_backgroud: 175 | for ret_metric, ret_metric_value in ret_metrics.items(): 176 | ret_metrics[ret_metric] = ret_metric_value[1:] 177 | 178 | ret_metrics_summary = OrderedDict({ 179 | ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2) 180 | for ret_metric, ret_metric_value in ret_metrics.items() 181 | }) 182 | 183 | # for logger 184 | class_table_data = PrettyTable() 185 | for key, val in ret_metrics_class.items(): 186 | class_table_data.add_column(key, val) 187 | 188 | summary_table_data = PrettyTable() 189 | for key, val in ret_metrics_summary.items(): 190 | if key == 'aAcc': 191 | summary_table_data.add_column(key, [val]) 192 | else: 193 | summary_table_data.add_column('m' + key, [val]) 194 | 195 | print_log('per class results:', logger) 196 | print_log('\n' + class_table_data.get_string(), logger=logger) 197 | print_log('Summary:', logger) 198 | print_log('\n' + summary_table_data.get_string(), logger=logger) 199 | 200 | # each metric dict 201 | for key, value in ret_metrics_summary.items(): 202 | if key == 'aAcc': 203 | eval_results[key] = value / 100.0 204 | else: 205 | eval_results['m' + key] = value / 100.0 206 | 207 | ret_metrics_class.pop('Class', None) 208 | for key, value in ret_metrics_class.items(): 209 | eval_results.update({ 210 | key + '.' 
+ str(name): value[idx] / 100.0 211 | for idx, name in enumerate(class_names) 212 | }) 213 | 214 | if mmcv.is_list_of(results, str): 215 | for file_name in results: 216 | os.remove(file_name) 217 | return eval_results 218 | 219 | def __len__(self): 220 | """Total number of samples of data.""" 221 | return len(self.data_list) 222 | 223 | def __getitem__(self, idx): 224 | data = self.data_list[idx] 225 | data = self.transforms(data) 226 | return data 227 | -------------------------------------------------------------------------------- /mctrans/data/datasets/pannuke.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDataset 2 | from ..builder import DATASETS 3 | 4 | @DATASETS.register_module() 5 | class PanNukeDataset(BaseDataset): 6 | CLASSES = ('Background', 'Neoplastic', "Inflammatory", "Connective", "Dead", "Non-Neoplastic Epithelial") 7 | def __init__(self, **kwargs): 8 | super(PanNukeDataset, self).__init__(**kwargs) 9 | -------------------------------------------------------------------------------- /mctrans/data/datasets/polyp.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from .base import BaseDataset 3 | from ..builder import DATASETS 4 | 5 | 6 | @DATASETS.register_module() 7 | class LesionDataset(BaseDataset): 8 | CLASSES = ('background', 'lesion') 9 | PALETTE = [[120, 120, 120], [180, 120, 120]] 10 | 11 | def __init__(self, **kwargs): 12 | super(LesionDataset, self).__init__( 13 | label_map={255: 1}, **kwargs) 14 | assert osp.exists(self.img_dir) and self.phase is not None 15 | -------------------------------------------------------------------------------- /mctrans/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .monai import * 2 | from .base import BinrayLabel, build_transforms 3 | -------------------------------------------------------------------------------- /mctrans/data/transforms/base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Mapping, Hashable, Dict 3 | 4 | from mmcv.utils import build_from_cfg 5 | from monai.config import KeysCollection 6 | from monai.transforms import apply_transform, MapTransform 7 | 8 | from ..builder import TRANSFORMS 9 | 10 | 11 | class Compose(object): 12 | """Composes several transforms together. 13 | Args: 14 | transforms (list of ``Transform`` objects): list of transforms to compose. 15 | Example: 16 | >>> transforms.Compose([ 17 | >>> transforms.CenterCrop(10), 18 | >>> transforms.ToTensor(), 19 | >>> ]) 20 | """ 21 | 22 | def __init__(self, transforms): 23 | self.transforms = transforms 24 | 25 | def __call__(self, data): 26 | for t in self.transforms: 27 | data = apply_transform(t, data) 28 | return data 29 | 30 | def __repr__(self): 31 | format_string = self.__class__.__name__ + '(' 32 | for t in self.transforms: 33 | format_string += '\n' 34 | format_string += ' {0}'.format(t) 35 | format_string += '\n)' 36 | return format_string 37 | 38 | 39 | def build_transforms(cfg): 40 | """Build a transformer. 41 | 42 | Args: 43 | cfg (dict, list[dict]): The config of tranforms, is is either a dict 44 | or a list of configs. 45 | Returns: 46 | nn.Module: A built nn module. 
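        Example (illustrative only; the transform names follow the registrations
        in mctrans/data/transforms/monai.py and the pannuke dataset config):
            >>> transforms = build_transforms([
            >>>     dict(type="LoadImage", keys=["img"], meta_key_postfix="metas"),
            >>>     dict(type="ToTensor", keys=["img"]),
            >>> ])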
47 | """ 48 | if isinstance(cfg, list): 49 | transforms = [ 50 | build_from_cfg(cfg_, TRANSFORMS) for cfg_ in cfg 51 | ] 52 | return Compose(transforms) 53 | else: 54 | return build_from_cfg(cfg, TRANSFORMS) 55 | 56 | 57 | @TRANSFORMS.register_module() 58 | class BinrayLabel(MapTransform): 59 | def __init__(self, 60 | keys: KeysCollection, 61 | allow_missing_keys: bool = False, 62 | ) -> None: 63 | super().__init__(keys, allow_missing_keys) 64 | 65 | def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]: 66 | d = dict(data) 67 | for key in self.key_iterator(d): 68 | d[key][d[key] > 0] = 1 69 | return d 70 | -------------------------------------------------------------------------------- /mctrans/data/transforms/monai.py: -------------------------------------------------------------------------------- 1 | from monai.transforms import RandSpatialCropd, SpatialPadd, CropForegroundd, RandCropByPosNegLabeld, Transform 2 | from monai.transforms import Resized, Spacingd, Orientationd, RandRotated, RandZoomd, RandAxisFlipd, RandAffined 3 | from monai.transforms import NormalizeIntensityd, ScaleIntensityRanged, ScaleIntensityd 4 | from monai.transforms import AddChanneld, AsChannelFirstd, MapLabelValued, Lambdad, ToTensord 5 | from monai.transforms import LoadImaged, RemoveRepeatedChanneld 6 | from monai.transforms import Activations, AsDiscrete 7 | 8 | import numpy as np 9 | from ..builder import TRANSFORMS 10 | 11 | # cropped 12 | TRANSFORMS.register_module(name="RandCrop", module=RandSpatialCropd) 13 | TRANSFORMS.register_module(name="SpatialPad", module=SpatialPadd) 14 | TRANSFORMS.register_module(name="CropForeground", module=CropForegroundd) 15 | TRANSFORMS.register_module(name="RandCropByPosNegLabel", module=RandCropByPosNegLabeld) 16 | # spatial 17 | TRANSFORMS.register_module(name='Resize', module=Resized) 18 | TRANSFORMS.register_module(name="Spacing", module=Spacingd) 19 | TRANSFORMS.register_module(name="Orientation", module=Orientationd) 20 | TRANSFORMS.register_module(name="RandRotate", module=RandRotated) 21 | TRANSFORMS.register_module(name="RandZoom", module=RandZoomd) 22 | TRANSFORMS.register_module(name="RandAxisFlip", module=RandAxisFlipd) 23 | TRANSFORMS.register_module(name="RandAffine", module=RandAffined) 24 | # intensity 25 | TRANSFORMS.register_module(name='NormalizeIntensity', module=NormalizeIntensityd) 26 | TRANSFORMS.register_module(name='ScaleIntensityRange', module=ScaleIntensityRanged) 27 | TRANSFORMS.register_module(name='ScaleIntensity', module=ScaleIntensityd) 28 | # utility 29 | TRANSFORMS.register_module(name="AddChannel", module=AddChanneld) 30 | TRANSFORMS.register_module(name="AsChannelFirst", module=AsChannelFirstd) 31 | TRANSFORMS.register_module(name="MapLabelValue", module=MapLabelValued) 32 | TRANSFORMS.register_module(name="Lambda", module=Lambdad) 33 | TRANSFORMS.register_module(name="ToTensor", module=ToTensord) 34 | # io 35 | TRANSFORMS.register_module(name="LoadImage", module=LoadImaged) 36 | TRANSFORMS.register_module(name="RemoveRepeatedChannel", module=RemoveRepeatedChanneld) 37 | # post-process 38 | TRANSFORMS.register_module(name="Activations", module=Activations) 39 | TRANSFORMS.register_module(name="AsDiscrete", module=AsDiscrete) 40 | 41 | 42 | @TRANSFORMS.register_module() 43 | class RandMirror(Transform): 44 | """ Mirror the data randomly along each specified axis according to the probability. 45 | 46 | Args: 47 | axis(None or int or tuple of ints): Along which axis to flip. 
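# Standalone sketch (not repository code) of the flip performed for axis 0 described above:
# for a channel-first array, mirroring the first spatial axis is a reversed slice on array
# index 1, exactly as done in augment_mirroring further below.
import numpy as np

image = np.arange(2 * 4 * 4).reshape(2, 4, 4)   # [channels, x, y]
mirrored = image[:, ::-1]                       # flip along the first spatial axis
assert (mirrored[:, 0] == image[:, -1]).all()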
48 | prob(float): Probability of flipping. 49 | """ 50 | 51 | def __init__(self, axis=(0, 1, 2), prob=0.5, image_key='img', label_key='seg_label'): 52 | self.axis = axis 53 | self.prob = prob 54 | self.label_key = label_key 55 | self.image_key = image_key 56 | 57 | def __call__(self, data): 58 | data = dict(data) 59 | image = data[self.image_key] 60 | seg_label = data[self.label_key] if self.label_key else None 61 | image, seg_label = self.augment_mirroring(image, seg_label) 62 | 63 | data[self.image_key] = image 64 | 65 | if self.label_key is not None: 66 | data[self.label_key] = seg_label 67 | 68 | return data 69 | 70 | def augment_mirroring(self, image, seg_label=None): 71 | if (len(image.shape) != 3) and (len(image.shape) != 4): 72 | raise Exception( 73 | "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either " 74 | "[channels, x, y] or [channels, x, y, z]") 75 | if 0 in self.axis and np.random.uniform() < self.prob: 76 | image[:, :] = image[:, ::-1] 77 | if seg_label is not None: 78 | seg_label[:, :] = seg_label[:, ::-1] 79 | if 1 in self.axis and np.random.uniform() < self.prob: 80 | image[:, :, :] = image[:, :, ::-1] 81 | if seg_label is not None: 82 | seg_label[:, :, :] = seg_label[:, :, ::-1] 83 | if 2 in self.axis and len(image.shape) == 4: 84 | if np.random.uniform() < self.prob: 85 | image[:, :, :, :] = image[:, :, :, ::-1] 86 | if seg_label is not None: 87 | seg_label[:, :, :, :] = seg_label[:, :, :, ::-1] 88 | return image, seg_label -------------------------------------------------------------------------------- /mctrans/data/transforms/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | import numpy 7 | 8 | 9 | def resize(input, 10 | size=None, 11 | scale_factor=None, 12 | mode='nearest', 13 | align_corners=None, 14 | warning=True): 15 | if warning: 16 | if size is not None and align_corners: 17 | input_h, input_w = tuple(int(x) for x in input.shape[2:]) 18 | output_h, output_w = tuple(int(x) for x in size) 19 | if output_h > input_h or output_w > output_h: 20 | if ((output_h > 1 and output_w > 1 and input_h > 1 21 | and input_w > 1) and (output_h - 1) % (input_h - 1) 22 | and (output_w - 1) % (input_w - 1)): 23 | warnings.warn( 24 | f'When align_corners={align_corners}, ' 25 | 'the output would more aligned if ' 26 | f'input size {(input_h, input_w)} is `x+1` and ' 27 | f'out size {(output_h, output_w)} is `nx+1`') 28 | if isinstance(size, torch.Size): 29 | size = tuple(int(x) for x in size) 30 | if isinstance(size, (numpy.ndarray, numpy.generic) ): 31 | size = tuple(int(x) for x in size) 32 | return F.interpolate(input, size, scale_factor, mode, align_corners) -------------------------------------------------------------------------------- /mctrans/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_metrics, METRICS, Metric -------------------------------------------------------------------------------- /mctrans/metrics/base.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def f_score(precision, recall, beta=1): 9 | """calcuate the f-score value. 10 | Args: 11 | precision (float | torch.Tensor): The precision value. 12 | recall (float | torch.Tensor): The recall value. 
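# Worked example of the F-score formula below (a standalone sketch, not repository code):
# with precision 0.8, recall 0.5 and beta=1, the F1 score is 2 * 0.8 * 0.5 / (0.8 + 0.5) ~= 0.615.
precision, recall, beta = 0.8, 0.5, 1
f1 = (1 + beta ** 2) * (precision * recall) / ((beta ** 2 * precision) + recall)
assert abs(f1 - 0.6153846153846154) < 1e-12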
13 | beta (int): Determines the weight of recall in the combined score. 14 | Default: False. 15 | Returns: 16 | [torch.tensor]: The f-score value. 17 | """ 18 | score = (1 + beta ** 2) * (precision * recall) / ( 19 | (beta ** 2 * precision) + recall) 20 | return score 21 | 22 | 23 | def intersect_and_union(pred_label, 24 | label, 25 | num_classes, 26 | ignore_index, 27 | label_map=dict(), 28 | reduce_zero_label=False): 29 | """Calculate intersection and Union. 30 | Args: 31 | pred_label (ndarray | str): Prediction segmentation map 32 | or predict result filename. 33 | label (ndarray | str): Ground truth segmentation map 34 | or label filename. 35 | num_classes (int): Number of categories. 36 | ignore_index (int): Index that will be ignored in evaluation. 37 | label_map (dict): Mapping old labels to new labels. The parameter will 38 | work only when label is str. Default: dict(). 39 | reduce_zero_label (bool): Wether ignore zero label. The parameter will 40 | work only when label is str. Default: False. 41 | Returns: 42 | torch.Tensor: The intersection of prediction and ground truth 43 | histogram on all classes. 44 | torch.Tensor: The union of prediction and ground truth histogram on 45 | all classes. 46 | torch.Tensor: The prediction histogram on all classes. 47 | torch.Tensor: The ground truth histogram on all classes. 48 | """ 49 | 50 | if isinstance(pred_label, str): 51 | pred_label = torch.from_numpy(np.load(pred_label)) 52 | else: 53 | pred_label = torch.from_numpy((pred_label)) 54 | 55 | if isinstance(label, str): 56 | label = torch.from_numpy( 57 | mmcv.imread(label, flag='unchanged', backend='pillow')) 58 | else: 59 | label = torch.from_numpy(label) 60 | 61 | if label_map is not None: 62 | for old_id, new_id in label_map.items(): 63 | label[label == old_id] = new_id 64 | if reduce_zero_label: 65 | label[label == 0] = 255 66 | label = label - 1 67 | label[label == 254] = 255 68 | 69 | if len(label.shape) != len(pred_label.shape): 70 | pred_label = torch.squeeze(pred_label, dim=0) 71 | mask = (label != ignore_index) 72 | pred_label = pred_label[mask] 73 | label = label[mask] 74 | 75 | intersect = pred_label[pred_label == label] 76 | area_intersect = torch.histc( 77 | intersect.float(), bins=(num_classes), min=0, max=num_classes - 1) 78 | area_pred_label = torch.histc( 79 | pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1) 80 | area_label = torch.histc( 81 | label.float(), bins=(num_classes), min=0, max=num_classes - 1) 82 | area_union = area_pred_label + area_label - area_intersect 83 | return area_intersect, area_union, area_pred_label, area_label 84 | 85 | 86 | def total_intersect_and_union(results, 87 | gt_seg_maps, 88 | num_classes, 89 | ignore_index, 90 | label_map=dict(), 91 | reduce_zero_label=False): 92 | """Calculate Total Intersection and Union. 93 | Args: 94 | results (list[ndarray] | list[str]): List of prediction segmentation 95 | maps or list of prediction result filenames. 96 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth 97 | segmentation maps or list of label filenames. 98 | num_classes (int): Number of categories. 99 | ignore_index (int): Index that will be ignored in evaluation. 100 | label_map (dict): Mapping old labels to new labels. Default: dict(). 101 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 102 | Returns: 103 | ndarray: The intersection of prediction and ground truth histogram 104 | on all classes. 105 | ndarray: The union of prediction and ground truth histogram on all 106 | classes. 
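# Tiny numeric sketch of the area computation below (illustrative, not repository code):
# with 3 classes, a 2x2 prediction and label give per-class intersection and union counts
# via histogramming, mirroring the torch.histc calls in intersect_and_union.
import torch

pred = torch.tensor([[0, 1], [2, 2]])
gt = torch.tensor([[0, 1], [1, 2]])
num_classes = 3

inter = pred[pred == gt]
area_inter = torch.histc(inter.float(), bins=num_classes, min=0, max=num_classes - 1)
area_pred = torch.histc(pred.float(), bins=num_classes, min=0, max=num_classes - 1)
area_gt = torch.histc(gt.float(), bins=num_classes, min=0, max=num_classes - 1)
area_union = area_pred + area_gt - area_inter
# area_inter == [1, 1, 1], area_union == [1, 2, 2]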
107 | ndarray: The prediction histogram on all classes. 108 | ndarray: The ground truth histogram on all classes. 109 | """ 110 | num_imgs = len(results) 111 | assert len(gt_seg_maps) == num_imgs 112 | total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64) 113 | total_area_union = torch.zeros((num_classes,), dtype=torch.float64) 114 | total_area_pred_label = torch.zeros((num_classes,), dtype=torch.float64) 115 | total_area_label = torch.zeros((num_classes,), dtype=torch.float64) 116 | for i in range(num_imgs): 117 | area_intersect, area_union, area_pred_label, area_label = \ 118 | intersect_and_union( 119 | results[i], gt_seg_maps[i], num_classes, ignore_index, 120 | label_map, reduce_zero_label) 121 | total_area_intersect += area_intersect 122 | total_area_union += area_union 123 | total_area_pred_label += area_pred_label 124 | total_area_label += area_label 125 | return total_area_intersect, total_area_union, total_area_pred_label, \ 126 | total_area_label 127 | 128 | 129 | def mean_iou(results, 130 | gt_seg_maps, 131 | num_classes, 132 | ignore_index, 133 | nan_to_num=None, 134 | label_map=dict(), 135 | reduce_zero_label=False): 136 | """Calculate Mean Intersection and Union (mIoU) 137 | Args: 138 | results (list[ndarray] | list[str]): List of prediction segmentation 139 | maps or list of prediction result filenames. 140 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth 141 | segmentation maps or list of label filenames. 142 | num_classes (int): Number of categories. 143 | ignore_index (int): Index that will be ignored in evaluation. 144 | nan_to_num (int, optional): If specified, NaN values will be replaced 145 | by the numbers defined by the user. Default: None. 146 | label_map (dict): Mapping old labels to new labels. Default: dict(). 147 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 148 | Returns: 149 | dict[str, float | ndarray]: 150 | float: Overall accuracy on all images. 151 | ndarray: Per category accuracy, shape (num_classes, ). 152 | ndarray: Per category IoU, shape (num_classes, ). 153 | """ 154 | iou_result = eval_metrics( 155 | results=results, 156 | gt_seg_maps=gt_seg_maps, 157 | num_classes=num_classes, 158 | ignore_index=ignore_index, 159 | metrics=['mIoU'], 160 | nan_to_num=nan_to_num, 161 | label_map=label_map, 162 | reduce_zero_label=reduce_zero_label) 163 | return iou_result 164 | 165 | 166 | def mean_dice(results, 167 | gt_seg_maps, 168 | num_classes, 169 | ignore_index, 170 | nan_to_num=None, 171 | label_map=dict(), 172 | reduce_zero_label=False): 173 | """Calculate Mean Dice (mDice) 174 | Args: 175 | results (list[ndarray] | list[str]): List of prediction segmentation 176 | maps or list of prediction result filenames. 177 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth 178 | segmentation maps or list of label filenames. 179 | num_classes (int): Number of categories. 180 | ignore_index (int): Index that will be ignored in evaluation. 181 | nan_to_num (int, optional): If specified, NaN values will be replaced 182 | by the numbers defined by the user. Default: None. 183 | label_map (dict): Mapping old labels to new labels. Default: dict(). 184 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 185 | Returns: 186 | dict[str, float | ndarray]: Default metrics. 187 | float: Overall accuracy on all images. 188 | ndarray: Per category accuracy, shape (num_classes, ). 189 | ndarray: Per category dice, shape (num_classes, ). 
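# Small numeric sketch (not repository code): Dice is derived from the same area counts as
# IoU; for one class, Dice = 2*I / (P + G), which relates to IoU by Dice = 2*IoU / (1 + IoU).
intersect, pred_area, gt_area = 6.0, 10.0, 8.0
union = pred_area + gt_area - intersect
iou = intersect / union                          # 6 / 12 = 0.5
dice = 2 * intersect / (pred_area + gt_area)     # 12 / 18 ~= 0.667
assert abs(dice - 2 * iou / (1 + iou)) < 1e-12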
190 | """ 191 | 192 | dice_result = eval_metrics( 193 | results=results, 194 | gt_seg_maps=gt_seg_maps, 195 | num_classes=num_classes, 196 | ignore_index=ignore_index, 197 | metrics=['mDice'], 198 | nan_to_num=nan_to_num, 199 | label_map=label_map, 200 | reduce_zero_label=reduce_zero_label) 201 | return dice_result 202 | 203 | 204 | def mean_fscore(results, 205 | gt_seg_maps, 206 | num_classes, 207 | ignore_index, 208 | nan_to_num=None, 209 | label_map=dict(), 210 | reduce_zero_label=False, 211 | beta=1): 212 | """Calculate Mean Intersection and Union (mIoU) 213 | Args: 214 | results (list[ndarray] | list[str]): List of prediction segmentation 215 | maps or list of prediction result filenames. 216 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth 217 | segmentation maps or list of label filenames. 218 | num_classes (int): Number of categories. 219 | ignore_index (int): Index that will be ignored in evaluation. 220 | nan_to_num (int, optional): If specified, NaN values will be replaced 221 | by the numbers defined by the user. Default: None. 222 | label_map (dict): Mapping old labels to new labels. Default: dict(). 223 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 224 | beta (int): Determines the weight of recall in the combined score. 225 | Default: False. 226 | Returns: 227 | dict[str, float | ndarray]: Default metrics. 228 | float: Overall accuracy on all images. 229 | ndarray: Per category recall, shape (num_classes, ). 230 | ndarray: Per category precision, shape (num_classes, ). 231 | ndarray: Per category f-score, shape (num_classes, ). 232 | """ 233 | fscore_result = eval_metrics( 234 | results=results, 235 | gt_seg_maps=gt_seg_maps, 236 | num_classes=num_classes, 237 | ignore_index=ignore_index, 238 | metrics=['mFscore'], 239 | nan_to_num=nan_to_num, 240 | label_map=label_map, 241 | reduce_zero_label=reduce_zero_label, 242 | beta=beta) 243 | return fscore_result 244 | 245 | 246 | def eval_metrics(results, 247 | gt_seg_maps, 248 | num_classes, 249 | ignore_index, 250 | metrics=['mIoU'], 251 | nan_to_num=None, 252 | label_map=dict(), 253 | reduce_zero_label=False, 254 | beta=1): 255 | """Calculate evaluation metrics 256 | Args: 257 | results (list[ndarray] | list[str]): List of prediction segmentation 258 | maps or list of prediction result filenames. 259 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth 260 | segmentation maps or list of label filenames. 261 | num_classes (int): Number of categories. 262 | ignore_index (int): Index that will be ignored in evaluation. 263 | metrics (list[str] | str): Metrics to be evaluated, 'mIoU' and 'mDice'. 264 | nan_to_num (int, optional): If specified, NaN values will be replaced 265 | by the numbers defined by the user. Default: None. 266 | label_map (dict): Mapping old labels to new labels. Default: dict(). 267 | reduce_zero_label (bool): Wether ignore zero label. Default: False. 268 | Returns: 269 | float: Overall accuracy on all images. 270 | ndarray: Per category accuracy, shape (num_classes, ). 271 | ndarray: Per category evaluation metrics, shape (num_classes, ). 
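# Usage sketch (illustrative, not repository code): eval_metrics takes lists of prediction
# and ground-truth maps plus the class count, and returns an OrderedDict of per-class arrays.
import numpy as np
from mctrans.metrics.base import eval_metrics

preds = [np.array([[0, 1], [2, 2]])]
gts = [np.array([[0, 1], [1, 2]])]
ret = eval_metrics(preds, gts, num_classes=3, ignore_index=255, metrics=["mDice", "mIoU"])
# ret["Dice"], ret["IoU"] and ret["Acc"] are arrays of shape (3,); ret["aAcc"] is the
# overall pixel accuracy over all images.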
272 | """ 273 | if isinstance(metrics, str): 274 | metrics = [metrics] 275 | allowed_metrics = ['mIoU', 'mDice', 'mFscore', "mHd95"] 276 | if not set(metrics).issubset(set(allowed_metrics)): 277 | raise KeyError('metrics {} is not supported'.format(metrics)) 278 | 279 | total_area_intersect, total_area_union, total_area_pred_label, \ 280 | total_area_label = total_intersect_and_union( 281 | results, gt_seg_maps, num_classes, ignore_index, label_map, 282 | reduce_zero_label) 283 | all_acc = total_area_intersect.sum() / total_area_label.sum() 284 | ret_metrics = OrderedDict({'aAcc': all_acc}) 285 | for metric in metrics: 286 | if metric == 'mIoU': 287 | iou = total_area_intersect / total_area_union 288 | acc = total_area_intersect / total_area_label 289 | ret_metrics['IoU'] = iou 290 | ret_metrics['Acc'] = acc 291 | elif metric == 'mDice': 292 | dice = 2 * total_area_intersect / ( 293 | total_area_pred_label + total_area_label) 294 | acc = total_area_intersect / total_area_label 295 | ret_metrics['Dice'] = dice 296 | ret_metrics['Acc'] = acc 297 | elif metric == 'mFscore': 298 | precision = total_area_intersect / total_area_pred_label 299 | recall = total_area_intersect / total_area_label 300 | f_value = torch.tensor( 301 | [f_score(x[0], x[1], beta) for x in zip(precision, recall)]) 302 | ret_metrics['Fscore'] = f_value 303 | ret_metrics['Precision'] = precision 304 | ret_metrics['Recall'] = recall 305 | elif metric == "mHd95": 306 | pass 307 | 308 | ret_metrics = { 309 | metric: value.numpy() 310 | for metric, value in ret_metrics.items() 311 | } 312 | if nan_to_num is not None: 313 | ret_metrics = OrderedDict({ 314 | metric: np.nan_to_num(metric_value, nan=nan_to_num) 315 | for metric, metric_value in ret_metrics.items() 316 | }) 317 | return ret_metrics 318 | -------------------------------------------------------------------------------- /mctrans/metrics/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | from abc import ABC, abstractmethod 3 | 4 | METRICS = Registry('metric') 5 | 6 | def build_metrics(cfg): 7 | """Build metric.""" 8 | class_names = cfg.class_names 9 | cfg_types = cfg.metric_types 10 | metrics = [build_from_cfg(_cfg, METRICS) for _cfg in cfg_types] 11 | for metric in metrics: 12 | metric.set_class_name(class_names) 13 | return metrics 14 | 15 | class Metric(ABC): 16 | 17 | @abstractmethod 18 | def __call__(self, pred, target, *args, **kwargs): 19 | raise NotImplementedError 20 | 21 | def set_class_name(self, class_names): 22 | self.class_names = class_names 23 | 24 | -------------------------------------------------------------------------------- /mctrans/metrics/hausdorff_distance.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | import mmcv 4 | import numpy as np 5 | import torch 6 | from monai.metrics import get_mask_edges, get_surface_distance 7 | 8 | 9 | def compute_hausdorff_distance( 10 | y_pred: Union[np.ndarray, torch.Tensor], 11 | y: Union[np.ndarray, torch.Tensor], 12 | include_background: bool = False, 13 | distance_metric: str = "euclidean", 14 | percentile: Optional[float] = None, 15 | directed: bool = False, 16 | ): 17 | """ 18 | Compute the Hausdorff distance. 19 | 20 | Args: 21 | y_pred: input data to compute, typical segmentation model output. 22 | It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values 23 | should be binarized. 
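# Illustrative sketch (not repository code) of the one-hot layout expected here: an integer
# label map of shape [N, H, W] becomes a binarized tensor of shape [N, C, H, W].
import torch
import torch.nn.functional as F

labels = torch.randint(0, 3, (2, 32, 32))        # [N, H, W], 3 classes
one_hot = F.one_hot(labels, num_classes=3)       # [N, H, W, C]
one_hot = one_hot.permute(0, 3, 1, 2).float()    # [N, C, H, W], values in {0, 1}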
24 | y: ground truth to compute mean the distance. It must be one-hot format and first dim is batch. 25 | The values should be binarized. 26 | include_background: whether to skip distance computation on the first channel of 27 | the predicted output. Defaults to ``False``. 28 | distance_metric: : [``"euclidean"``, ``"chessboard"``, ``"taxicab"``] 29 | the metric used to compute surface distance. Defaults to ``"euclidean"``. 30 | percentile: an optional float number between 0 and 100. If specified, the corresponding 31 | percentile of the Hausdorff Distance rather than the maximum result will be achieved. 32 | Defaults to ``None``. 33 | directed: whether to calculate directed Hausdorff distance. Defaults to ``False``. 34 | """ 35 | if isinstance(y_pred, str): 36 | y_pred = torch.from_numpy(np.load(y_pred)) 37 | else: 38 | y_pred = torch.from_numpy((y_pred)) 39 | 40 | if isinstance(y, str): 41 | y = torch.from_numpy( 42 | mmcv.imread(y, flag='unchanged', backend='pillow')) 43 | else: 44 | y = torch.from_numpy(y) 45 | 46 | if isinstance(y, torch.Tensor): 47 | y = y.float() 48 | if isinstance(y_pred, torch.Tensor): 49 | y_pred = y_pred.float() 50 | 51 | if y.shape != y_pred.shape: 52 | y_pred = torch.squeeze(y_pred, dim=0) 53 | 54 | batch_size, n_class = y_pred.shape[:2] 55 | hd = np.empty((batch_size, n_class)) 56 | 57 | (edges_pred, edges_gt) = get_mask_edges(y_pred, y) 58 | distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile) 59 | 60 | for b, c in np.ndindex(batch_size, n_class): 61 | (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c]) 62 | distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile) 63 | if directed: 64 | hd[b, c] = distance_1 65 | else: 66 | distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile) 67 | hd[b, c] = max(distance_1, distance_2) 68 | return torch.from_numpy(hd) 69 | 70 | 71 | 72 | def compute_percent_hausdorff_distance( 73 | edges_pred: np.ndarray, 74 | edges_gt: np.ndarray, 75 | distance_metric: str = "euclidean", 76 | percentile: Optional[float] = None, 77 | ): 78 | """ 79 | This function is used to compute the directed Hausdorff distance. 
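# Standalone sketch (not repository code) of a 95th-percentile Hausdorff distance between
# two binary masks, using the same monai helpers as the functions in this file.
import numpy as np
from monai.metrics import get_mask_edges, get_surface_distance

pred = np.zeros((32, 32), dtype=bool)
pred[8:20, 8:20] = True
gt = np.zeros((32, 32), dtype=bool)
gt[10:22, 10:22] = True

edges_pred, edges_gt = get_mask_edges(pred, gt)
surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric="euclidean")
hd95 = np.percentile(surface_distance, 95) if surface_distance.size else np.nan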
80 | """ 81 | 82 | surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric) 83 | 84 | # for both pred and gt do not have foreground 85 | if surface_distance.shape == (0,): 86 | return np.nan 87 | 88 | if not percentile: 89 | return surface_distance.max() 90 | 91 | if 0 <= percentile <= 100: 92 | return np.percentile(surface_distance, percentile) 93 | raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.") 94 | 95 | 96 | def total_hausdorff_distance(results, 97 | gt_seg_maps, 98 | ): 99 | num_imgs = len(results) 100 | assert len(gt_seg_maps) == num_imgs 101 | 102 | for i in range(num_imgs): 103 | x= compute_hausdorff_distance(results[i], gt_seg_maps[i], percentile=95) 104 | return 0 -------------------------------------------------------------------------------- /mctrans/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .centers import * 2 | from .losses import * 3 | from .builder import build_model, build_losses, MODEL, NETWORKS, LOSSES, HEADS, ENCODERS, DECODERS 4 | from .encoders import * 5 | from .decoders import * 6 | from .heads import * 7 | from .segmentors import * -------------------------------------------------------------------------------- /mctrans/models/builder.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils import Registry, build_from_cfg 2 | from torch import nn 3 | 4 | NETWORKS = Registry('network') 5 | LOSSES = Registry('loss') 6 | MODEL = Registry('model') 7 | ENCODERS = Registry('encoder') 8 | DECODERS = Registry('decoder') 9 | CENTERS = Registry('center') 10 | HEADS = Registry('head') 11 | 12 | 13 | def build(cfg, registry, default_args=None): 14 | """Build a module. 15 | 16 | Args: 17 | cfg (dict, list[dict]): The config of modules, is is either a dict 18 | or a list of configs. 19 | registry (:obj:`Registry`): A registry the module belongs to. 20 | default_args (dict, optional): Default arguments to build the module. 21 | Defaults to None. 22 | 23 | Returns: 24 | nn.Module: A built nn module. 
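# Usage sketch (illustrative, not repository code): a config dict's ``type`` field selects a
# class registered in the corresponding Registry, and the remaining keys become constructor
# arguments. The VGG parameters shown are assumptions for illustration only.
from mctrans.models.builder import build_encoder

encoder_cfg = dict(type="VGG", in_channel=3, depth=5, init_channels=32, num_blocks=2)
encoder = build_encoder(encoder_cfg)   # equivalent to VGG(in_channel=3, depth=5, ...)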
25 | """ 26 | 27 | if isinstance(cfg, list): 28 | modules = [ 29 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 30 | ] 31 | return nn.Sequential(*modules) 32 | else: 33 | return build_from_cfg(cfg, registry, default_args) 34 | 35 | 36 | def build_network(cfg): 37 | """Build network.""" 38 | return build(cfg, NETWORKS) 39 | 40 | 41 | def build_losses(cfg): 42 | """Build loss.""" 43 | return [build_from_cfg(_cfg, LOSSES) for _cfg in cfg] 44 | 45 | 46 | def build_model(cfg): 47 | """Build model.""" 48 | return build(cfg, MODEL) 49 | 50 | 51 | def build_encoder(cfg): 52 | """Build Encoder.""" 53 | return build(cfg, ENCODERS) 54 | 55 | 56 | def build_decoder(cfg): 57 | """Build Decoder.""" 58 | return build(cfg, DECODERS) 59 | 60 | 61 | def build_center(cfg): 62 | """Build Center.""" 63 | return build(cfg, CENTERS) 64 | 65 | 66 | def build_head(cfg): 67 | """Build SegHead.""" 68 | return build(cfg, HEADS) 69 | -------------------------------------------------------------------------------- /mctrans/models/centers/__init__.py: -------------------------------------------------------------------------------- 1 | from .mctrans import MCTrans 2 | from .cenet import CEncoder 3 | from .non_local import NonLocal 4 | from .vit import Vit -------------------------------------------------------------------------------- /mctrans/models/centers/cenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import CENTERS 6 | 7 | 8 | class DacBlock(nn.Module): 9 | def __init__(self, channel): 10 | super(DacBlock, self).__init__() 11 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1) 12 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=3, padding=3) 13 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=5, padding=5) 14 | self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0) 15 | self.relu = nn.ReLU(inplace=True) 16 | for m in self.modules(): 17 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 18 | if m.bias is not None: 19 | m.bias.data.zero_() 20 | 21 | def forward(self, x): 22 | dilate1_out = self.relu(self.dilate1(x)) 23 | dilate2_out = self.relu(self.conv1x1(self.dilate2(x))) 24 | dilate3_out = self.relu(self.conv1x1(self.dilate2(self.dilate1(x)))) 25 | dilate4_out = self.relu(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x))))) 26 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out 27 | return out 28 | 29 | 30 | class SppBlock(nn.Module): 31 | def __init__(self, in_channels): 32 | super(SppBlock, self).__init__() 33 | self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=2) 34 | self.pool2 = nn.MaxPool2d(kernel_size=[3, 3], stride=3) 35 | self.pool3 = nn.MaxPool2d(kernel_size=[5, 5], stride=5) 36 | self.pool4 = nn.MaxPool2d(kernel_size=[6, 6], stride=6) 37 | 38 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1, padding=0) 39 | 40 | def forward(self, x): 41 | self.in_channels, h, w = x.size(1), x.size(2), x.size(3) 42 | self.layer1 = F.interpolate(self.conv(self.pool1(x)), size=(h, w), mode='bilinear', align_corners=True) 43 | self.layer2 = F.interpolate(self.conv(self.pool2(x)), size=(h, w), mode='bilinear', align_corners=True) 44 | self.layer3 = F.interpolate(self.conv(self.pool3(x)), size=(h, w), mode='bilinear', align_corners=True) 45 | self.layer4 = F.interpolate(self.conv(self.pool4(x)), size=(h, 
w), mode='bilinear', align_corners=True) 46 | 47 | out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1) 48 | 49 | return out 50 | 51 | 52 | @CENTERS.register_module() 53 | class CEncoder(nn.Module): 54 | def __init__(self, 55 | in_channels=[240], 56 | ): 57 | super().__init__() 58 | in_channels = in_channels[-1] 59 | self.dac = DacBlock(in_channels) 60 | self.spp = SppBlock(in_channels) 61 | 62 | def forward(self, x): 63 | feat = x[-1] 64 | feat = self.dac(feat) 65 | feat = self.spp(feat) 66 | x[-1] = feat 67 | return x 68 | -------------------------------------------------------------------------------- /mctrans/models/centers/mctrans.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import normal_ 4 | 5 | from mmcv.cnn import ConvModule 6 | 7 | from ..builder import CENTERS 8 | from ..ops.modules import MSDeformAttn 9 | from ..trans.transformer import DSALayer, DSA 10 | from ..trans.utils import build_position_encoding, NestedTensor 11 | 12 | 13 | @CENTERS.register_module() 14 | class MCTrans(nn.Module): 15 | def __init__(self, 16 | d_model=240, 17 | nhead=8, 18 | d_ffn=1024, 19 | dropout=0.1, 20 | act="relu", 21 | n_points=4, 22 | n_levels=3, 23 | n_sa_layers=6, 24 | in_channles=[64, 64, 128, 256, 512], 25 | proj_idxs=(2, 3, 4), 26 | 27 | ): 28 | super().__init__() 29 | self.nhead = nhead 30 | self.d_model = d_model 31 | self.n_levels = n_levels 32 | 33 | self.proj_idxs = proj_idxs 34 | self.projs = nn.ModuleList() 35 | for idx in self.proj_idxs: 36 | self.projs.append(ConvModule(in_channles[idx], 37 | d_model, 38 | kernel_size=3, 39 | padding=1, 40 | conv_cfg=dict(type="Conv"), 41 | norm_cfg=dict(type='BN'), 42 | act_cfg=dict(type='ReLU') 43 | )) 44 | 45 | dsa_layer = DSALayer(d_model=d_model, 46 | d_ffn=d_ffn, 47 | dropout=dropout, 48 | activation=act, 49 | n_levels=n_levels, 50 | n_heads=nhead, 51 | n_points=n_points) 52 | 53 | self.dsa = DSA(att_layer=dsa_layer, 54 | n_layers=n_sa_layers) 55 | 56 | self.level_embed = nn.Parameter(torch.Tensor(n_levels, d_model)) 57 | self.position_embedding = build_position_encoding(position_embedding="sine", hidden_dim=d_model) 58 | self._reset_parameters() 59 | 60 | def _reset_parameters(self): 61 | for p in self.parameters(): 62 | if p.dim() > 1: 63 | nn.init.xavier_uniform_(p) 64 | for m in self.modules(): 65 | if isinstance(m, MSDeformAttn): 66 | m._reset_parameters() 67 | normal_(self.level_embed) 68 | 69 | def get_valid_ratio(self, mask): 70 | _, H, W = mask.shape 71 | valid_H = torch.sum(~mask[:, :, 0], 1) 72 | valid_W = torch.sum(~mask[:, 0, :], 1) 73 | valid_ratio_h = valid_H.float() / H 74 | valid_ratio_w = valid_W.float() / W 75 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 76 | return valid_ratio 77 | 78 | def projection(self, feats): 79 | pos = [] 80 | masks = [] 81 | cnn_feats = [] 82 | tran_feats = [] 83 | 84 | for idx, feats in enumerate(feats): 85 | if idx not in self.proj_idxs: 86 | cnn_feats.append(feats) 87 | else: 88 | n, c, h, w = feats.shape 89 | mask = torch.zeros((n, h, w)).to(torch.bool).to(feats.device) 90 | nested_feats = NestedTensor(feats, mask) 91 | masks.append(mask) 92 | pos.append(self.position_embedding(nested_feats).to(nested_feats.tensors.dtype)) 93 | tran_feats.append(feats) 94 | 95 | for idx, proj in enumerate(self.projs): 96 | tran_feats[idx] = proj(tran_feats[idx]) 97 | 98 | return cnn_feats, tran_feats, pos, masks 99 | 100 | def forward(self, x): 101 | # project and 
prepare for the input 102 | cnn_feats, trans_feats, pos_embs, masks = self.projection(x) 103 | # dsa 104 | features_flatten = [] 105 | mask_flatten = [] 106 | lvl_pos_embed_flatten = [] 107 | feature_shapes = [] 108 | spatial_shapes = [] 109 | for lvl, (feature, mask, pos_embed) in enumerate(zip(trans_feats, masks, pos_embs)): 110 | bs, c, h, w = feature.shape 111 | spatial_shapes.append((h, w)) 112 | feature_shapes.append(feature.shape) 113 | 114 | feature = feature.flatten(2).transpose(1, 2) 115 | mask = mask.flatten(1) 116 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 117 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 118 | lvl_pos_embed_flatten.append(lvl_pos_embed) 119 | 120 | features_flatten.append(feature) 121 | mask_flatten.append(mask) 122 | 123 | features_flatten = torch.cat(features_flatten, 1) 124 | mask_flatten = torch.cat(mask_flatten, 1) 125 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 126 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=features_flatten.device) 127 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) 128 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 129 | 130 | # self att 131 | feats = self.dsa(features_flatten, 132 | spatial_shapes, 133 | level_start_index, 134 | valid_ratios, 135 | lvl_pos_embed_flatten, 136 | mask_flatten) 137 | # recover 138 | out = [] 139 | features = feats.split(spatial_shapes.prod(1).tolist(), dim=1) 140 | for idx, (feats, ori_shape) in enumerate(zip(features, spatial_shapes)): 141 | out.append(feats.transpose(1, 2).reshape(feature_shapes[idx])) 142 | 143 | cnn_feats.extend(out) 144 | return cnn_feats 145 | -------------------------------------------------------------------------------- /mctrans/models/centers/non_local.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from mctrans.models.builder import CENTERS 5 | 6 | 7 | class _NonLocalBlockND(nn.Module): 8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 9 | super(_NonLocalBlockND, self).__init__() 10 | 11 | assert dimension in [1, 2, 3] 12 | 13 | self.dimension = dimension 14 | self.sub_sample = sub_sample 15 | 16 | self.in_channels = in_channels 17 | self.inter_channels = inter_channels 18 | 19 | if self.inter_channels is None: 20 | self.inter_channels = in_channels // 2 21 | if self.inter_channels == 0: 22 | self.inter_channels = 1 23 | 24 | if dimension == 3: 25 | conv_nd = nn.Conv3d 26 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 27 | bn = nn.BatchNorm3d 28 | elif dimension == 2: 29 | conv_nd = nn.Conv2d 30 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 31 | bn = nn.BatchNorm2d 32 | else: 33 | conv_nd = nn.Conv1d 34 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 35 | bn = nn.BatchNorm1d 36 | 37 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 38 | kernel_size=1, stride=1, padding=0) 39 | 40 | if bn_layer: 41 | self.W = nn.Sequential( 42 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 43 | kernel_size=1, stride=1, padding=0), 44 | bn(self.in_channels) 45 | ) 46 | nn.init.constant_(self.W[1].weight, 0) 47 | nn.init.constant_(self.W[1].bias, 0) 48 | else: 49 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 50 | kernel_size=1, stride=1, padding=0) 51 | nn.init.constant_(self.W.weight, 0) 
52 | nn.init.constant_(self.W.bias, 0) 53 | 54 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 55 | kernel_size=1, stride=1, padding=0) 56 | 57 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 58 | kernel_size=1, stride=1, padding=0) 59 | 60 | if sub_sample: 61 | self.g = nn.Sequential(self.g, max_pool_layer) 62 | self.phi = nn.Sequential(self.phi, max_pool_layer) 63 | 64 | def forward(self, x): 65 | ''' 66 | :param x: (b, c, t, h, w) 67 | :return: 68 | ''' 69 | 70 | batch_size = x.size(0) 71 | 72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 73 | g_x = g_x.permute(0, 2, 1) 74 | 75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 76 | theta_x = theta_x.permute(0, 2, 1) 77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 78 | f = torch.matmul(theta_x, phi_x) 79 | N = f.size(-1) 80 | f_div_C = f / N 81 | 82 | y = torch.matmul(f_div_C, g_x) 83 | y = y.permute(0, 2, 1).contiguous() 84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 85 | W_y = self.W(y) 86 | z = W_y + x 87 | 88 | return z 89 | 90 | 91 | @CENTERS.register_module() 92 | class NonLocal(_NonLocalBlockND): 93 | def __init__(self, 94 | in_channels=[240], 95 | ): 96 | in_channels = in_channels[-1] 97 | super(NonLocal, self).__init__(in_channels, 98 | inter_channels=None, 99 | dimension=2, 100 | sub_sample=True, 101 | bn_layer=True) 102 | 103 | def forward(self, x): 104 | x[-1] = super().forward(x[-1]) 105 | return x 106 | -------------------------------------------------------------------------------- /mctrans/models/centers/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from mctrans.models.builder import CENTERS 5 | 6 | 7 | class FixedPositionalEncoding(nn.Module): 8 | def __init__(self, embedding_dim, max_length=5000): 9 | super(FixedPositionalEncoding, self).__init__() 10 | 11 | pe = torch.zeros(max_length, embedding_dim) 12 | position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1) 13 | div_term = torch.exp( 14 | torch.arange(0, embedding_dim, 2).float() 15 | * (-torch.log(torch.tensor(10000.0)) / embedding_dim) 16 | ) 17 | pe[:, 0::2] = torch.sin(position * div_term) 18 | pe[:, 1::2] = torch.cos(position * div_term) 19 | pe = pe.unsqueeze(0).transpose(0, 1) 20 | self.register_buffer('pe', pe) 21 | 22 | def forward(self, x): 23 | x = x + self.pe[: x.size(0), :] 24 | return x 25 | 26 | 27 | class LearnedPositionalEncoding(nn.Module): 28 | def __init__(self, max_position_embeddings, embedding_dim, seq_length): 29 | super(LearnedPositionalEncoding, self).__init__() 30 | self.pe = nn.Embedding(max_position_embeddings, embedding_dim) 31 | self.seq_length = seq_length 32 | 33 | self.register_buffer( 34 | "position_ids", 35 | torch.arange(max_position_embeddings).expand((1, -1)), 36 | ) 37 | 38 | def forward(self, x, position_ids=None): 39 | if position_ids is None: 40 | position_ids = self.position_ids[:, : self.seq_length] 41 | 42 | position_embeddings = self.pe(position_ids) 43 | return x + position_embeddings 44 | 45 | 46 | class SelfAttention(nn.Module): 47 | def __init__( 48 | self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0 49 | ): 50 | super().__init__() 51 | self.num_heads = heads 52 | head_dim = dim // heads 53 | self.scale = qk_scale or head_dim ** -0.5 54 | 55 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 56 | self.attn_drop = nn.Dropout(dropout_rate) 57 | 
self.proj = nn.Linear(dim, dim) 58 | self.proj_drop = nn.Dropout(dropout_rate) 59 | 60 | def forward(self, x): 61 | B, N, C = x.shape 62 | qkv = ( 63 | self.qkv(x) 64 | .reshape(B, N, 3, self.num_heads, C // self.num_heads) 65 | .permute(2, 0, 3, 1, 4) 66 | ) 67 | q, k, v = ( 68 | qkv[0], 69 | qkv[1], 70 | qkv[2], 71 | ) # make torchscript happy (cannot use tensor as tuple) 72 | 73 | attn = (q @ k.transpose(-2, -1)) * self.scale 74 | attn = attn.softmax(dim=-1) 75 | attn = self.attn_drop(attn) 76 | 77 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 78 | x = self.proj(x) 79 | x = self.proj_drop(x) 80 | return x 81 | 82 | 83 | class Residual(nn.Module): 84 | def __init__(self, fn): 85 | super().__init__() 86 | self.fn = fn 87 | 88 | def forward(self, x): 89 | return self.fn(x) + x 90 | 91 | 92 | class PreNorm(nn.Module): 93 | def __init__(self, dim, fn): 94 | super().__init__() 95 | self.norm = nn.LayerNorm(dim) 96 | self.fn = fn 97 | 98 | def forward(self, x): 99 | return self.fn(self.norm(x)) 100 | 101 | 102 | class PreNormDrop(nn.Module): 103 | def __init__(self, dim, dropout_rate, fn): 104 | super().__init__() 105 | self.norm = nn.LayerNorm(dim) 106 | self.dropout = nn.Dropout(p=dropout_rate) 107 | self.fn = fn 108 | 109 | def forward(self, x): 110 | return self.dropout(self.fn(self.norm(x))) 111 | 112 | 113 | class FeedForward(nn.Module): 114 | def __init__(self, dim, hidden_dim, dropout_rate): 115 | super().__init__() 116 | self.net = nn.Sequential( 117 | nn.Linear(dim, hidden_dim), 118 | nn.GELU(), 119 | nn.Dropout(p=dropout_rate), 120 | nn.Linear(hidden_dim, dim), 121 | nn.Dropout(p=dropout_rate), 122 | ) 123 | 124 | def forward(self, x): 125 | return self.net(x) 126 | 127 | 128 | class TransformerModel(nn.Module): 129 | def __init__( 130 | self, 131 | dim, 132 | depth, 133 | heads, 134 | mlp_dim, 135 | dropout_rate=0.1, 136 | attn_dropout_rate=0.1, 137 | ): 138 | super().__init__() 139 | layers = [] 140 | for _ in range(depth): 141 | layers.extend( 142 | [ 143 | Residual( 144 | PreNormDrop( 145 | dim, 146 | dropout_rate, 147 | SelfAttention( 148 | dim, heads=heads, dropout_rate=attn_dropout_rate 149 | ), 150 | ) 151 | ), 152 | Residual( 153 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate)) 154 | ), 155 | ] 156 | ) 157 | self.net = nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | return self.net(x) 161 | 162 | 163 | @CENTERS.register_module() 164 | class Vit(nn.Module): 165 | def __init__(self, 166 | input_size=(30, 16, 16), 167 | embedding_dims=None, 168 | num_encoders=6, 169 | positional_encoding_type="fixed"): 170 | super().__init__() 171 | input_channel, input_height, input_width = input_size 172 | num_patches = input_height * input_width 173 | flatten_dims = input_channel 174 | 175 | if embedding_dims is None: 176 | embedding_dims = flatten_dims 177 | 178 | self.linear_encoding = nn.Linear(flatten_dims, embedding_dims) 179 | 180 | if positional_encoding_type == "learnable": 181 | self.position_encoding = LearnedPositionalEncoding(max_position_embeddings=num_patches, 182 | embedding_dim=embedding_dims, 183 | seq_length=num_patches) 184 | elif positional_encoding_type == "fixed": 185 | self.position_encoding = FixedPositionalEncoding(embedding_dim=embedding_dims) 186 | 187 | self.encoders = TransformerModel(dim=embedding_dims, 188 | depth=num_encoders, 189 | heads=8, 190 | mlp_dim=embedding_dims, 191 | dropout_rate=0.1, 192 | attn_dropout_rate=0.0) 193 | 194 | def forward(self, x): 195 | feat = x[-1] 196 | 197 | n, c, h, w = feat.size() 198 | feat = 
feat.view(n, c, h * w).transpose(1, 2) 199 | feat = self.linear_encoding(feat) 200 | feat = self.position_encoding(feat) 201 | feat = self.encoders(feat) 202 | feat = feat.transpose(1, 2) 203 | feat = feat.view(n, c, h, w) 204 | 205 | x[-1] = feat 206 | return x 207 | -------------------------------------------------------------------------------- /mctrans/models/decoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_decoder import UNetDecoder 2 | from .unet_plus_plus_decoder import UNetPlusPlusDecoder 3 | -------------------------------------------------------------------------------- /mctrans/models/decoders/unet_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..builder import DECODERS 5 | from ..utils import conv_bn_relu 6 | 7 | 8 | class AttBlock(nn.Module): 9 | def __init__(self, F_g, F_l, F_int): 10 | super(AttBlock, self).__init__() 11 | self.W_g = nn.Sequential( 12 | nn.Conv2d(in_channels=F_g, 13 | out_channels=F_int, 14 | kernel_size=1, 15 | stride=1, 16 | padding=0, 17 | bias=True), 18 | nn.BatchNorm2d(F_int) 19 | ) 20 | 21 | self.W_x = nn.Sequential( 22 | nn.Conv2d(in_channels=F_l, 23 | out_channels=F_int, 24 | kernel_size=1, 25 | stride=1, 26 | padding=0, 27 | bias=True), 28 | nn.BatchNorm2d(F_int) 29 | ) 30 | 31 | self.psi = nn.Sequential( 32 | nn.Conv2d(in_channels=F_int, 33 | out_channels=1, 34 | kernel_size=1, 35 | stride=1, 36 | padding=0, 37 | bias=True), 38 | nn.BatchNorm2d(1), 39 | nn.Sigmoid() 40 | ) 41 | 42 | self.relu = nn.ReLU(inplace=True) 43 | 44 | def forward(self, g, x): 45 | g1 = self.W_g(g) 46 | x1 = self.W_x(x) 47 | psi = self.relu(g1 + x1) 48 | psi = self.psi(psi) 49 | 50 | return x * psi 51 | 52 | 53 | class DecBlock(nn.Module): 54 | def __init__( 55 | self, 56 | in_channels, 57 | skip_channels, 58 | out_channels, 59 | attention=False 60 | ): 61 | super().__init__() 62 | self.conv1 = conv_bn_relu(in_channels=in_channels + skip_channels, 63 | out_channels=out_channels) 64 | 65 | self.conv2 = conv_bn_relu(in_channels=out_channels, 66 | out_channels=out_channels) 67 | 68 | self.up = nn.Upsample(scale_factor=2, 69 | mode='bilinear', 70 | align_corners=True) 71 | 72 | if attention: 73 | self.att = AttBlock(F_g=in_channels, F_l=skip_channels, F_int=in_channels) 74 | 75 | def forward(self, x, skip=None): 76 | x = self.up(x) 77 | if skip is not None: 78 | if hasattr(self, "att"): 79 | skip = self.att(g=x, x=skip) 80 | x = torch.cat([x, skip], dim=1) 81 | x = self.conv1(x) 82 | x = self.conv2(x) 83 | return x 84 | 85 | 86 | @DECODERS.register_module() 87 | class UNetDecoder(nn.Module): 88 | def __init__( 89 | self, 90 | in_channels, 91 | att=False 92 | ): 93 | super().__init__() 94 | self.decoders = nn.ModuleList() 95 | in_channels = in_channels[::-1] 96 | skip_channels = in_channels[1:] 97 | for in_c, skip_c in zip(in_channels, skip_channels): 98 | self.decoders.append(DecBlock(in_c, skip_c, skip_c, att)) 99 | 100 | def forward(self, features): 101 | features = features[::-1] 102 | x = features[0] 103 | skips = features[1:] 104 | 105 | for i, layer in enumerate(self.decoders): 106 | x = layer(x, skips[i]) 107 | 108 | return x 109 | 110 | def init_weights(self): 111 | pass 112 | 113 | -------------------------------------------------------------------------------- /mctrans/models/decoders/unet_plus_plus_decoder.py: -------------------------------------------------------------------------------- 1 | from 
collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..utils import conv_bn_relu 7 | from ..builder import DECODERS 8 | 9 | 10 | class AttBlock(nn.Module): 11 | def __init__(self, F_g, F_l, F_int): 12 | super(AttBlock, self).__init__() 13 | self.W_g = nn.Sequential( 14 | nn.Conv2d(in_channels=F_g, 15 | out_channels=F_int, 16 | kernel_size=1, 17 | stride=1, 18 | padding=0, 19 | bias=True), 20 | nn.BatchNorm2d(F_int) 21 | ) 22 | 23 | self.W_x = nn.Sequential( 24 | nn.Conv2d(in_channels=F_l, 25 | out_channels=F_int, 26 | kernel_size=1, 27 | stride=1, 28 | padding=0, 29 | bias=True), 30 | nn.BatchNorm2d(F_int) 31 | ) 32 | 33 | self.psi = nn.Sequential( 34 | nn.Conv2d(in_channels=F_int, 35 | out_channels=1, 36 | kernel_size=1, 37 | stride=1, 38 | padding=0, 39 | bias=True), 40 | nn.BatchNorm2d(1), 41 | nn.Sigmoid() 42 | ) 43 | 44 | self.relu = nn.ReLU(inplace=True) 45 | 46 | def forward(self, g, x): 47 | g1 = self.W_g(g) 48 | x1 = self.W_x(x) 49 | psi = self.relu(g1 + x1) 50 | psi = self.psi(psi) 51 | 52 | return x * psi 53 | 54 | 55 | class DecBlock(nn.Module): 56 | def __init__( 57 | self, 58 | in_channels, 59 | skip_channels, 60 | out_channels, 61 | attention=False 62 | ): 63 | super().__init__() 64 | self.conv1 = conv_bn_relu(in_channels=in_channels + skip_channels, 65 | out_channels=out_channels) 66 | 67 | self.conv2 = conv_bn_relu(in_channels=out_channels, 68 | out_channels=out_channels) 69 | 70 | self.up = nn.Upsample(scale_factor=2, 71 | mode='bilinear', 72 | align_corners=True) 73 | 74 | if attention: 75 | self.att = AttBlock(F_g=in_channels, F_l=skip_channels, F_int=in_channels) 76 | 77 | def forward(self, x, skip=None): 78 | x = self.up(x) 79 | if skip is not None: 80 | if hasattr(self, "att"): 81 | skip = self.att(g=x, x=skip) 82 | x = torch.cat([x, skip], dim=1) 83 | x = self.conv1(x) 84 | x = self.conv2(x) 85 | return x 86 | 87 | 88 | @DECODERS.register_module() 89 | class UNetPlusPlusDecoder(nn.Module): 90 | def __init__( 91 | self, 92 | in_channels, 93 | ): 94 | super().__init__() 95 | 96 | self.decoder_layers = nn.ModuleList() 97 | self.in_channels = in_channels 98 | skip_channels = in_channels[:-1] 99 | 100 | blocks = {} 101 | for stage_idx in range(1, len(self.in_channels)): 102 | for lvl_idx in range(len(self.in_channels) - stage_idx): 103 | in_ch = self.in_channels[lvl_idx + 1] 104 | skip_ch = skip_channels[lvl_idx] * (stage_idx) 105 | out_ch = self.in_channels[lvl_idx] 106 | blocks[f'x_{lvl_idx}_{stage_idx}'] = DecBlock(in_ch, skip_ch, out_ch, False) 107 | 108 | self.blocks = nn.ModuleDict(blocks) 109 | 110 | def forward(self, features): 111 | dense_x = OrderedDict() 112 | for idx, item in enumerate(features): 113 | dense_x[f'x_{idx}_{0}'] = features[idx] 114 | 115 | for stage_idx in range(1, len(self.in_channels)): 116 | for lvl_idx in range(len(self.in_channels) - stage_idx): 117 | skip_features = [dense_x[f'x_{lvl_idx}_{idx}'] for idx in range(stage_idx)] 118 | skip_features = torch.cat(skip_features, dim=1) 119 | output = self.blocks[f'x_{lvl_idx}_{stage_idx}'](dense_x[f'x_{lvl_idx + 1}_{stage_idx - 1}'], 120 | skip_features) 121 | dense_x[f'x_{lvl_idx}_{stage_idx}'] = output 122 | 123 | return dense_x[next(reversed(dense_x))] 124 | 125 | def init_weights(self): 126 | pass 127 | 128 | -------------------------------------------------------------------------------- /mctrans/models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import VGG 2 | from .resnet import 
ResNet 3 | 4 | -------------------------------------------------------------------------------- /mctrans/models/encoders/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import kaiming_init, constant_init 4 | from torch.nn.modules.batchnorm import _BatchNorm 5 | 6 | from ..builder import ENCODERS 7 | from ..utils import make_vgg_layer 8 | 9 | 10 | @ENCODERS.register_module() 11 | class VGG(nn.Module): 12 | def __init__(self, in_channel=1, depth=5, init_channels=16, num_blocks=2): 13 | super(VGG, self).__init__() 14 | filters = [(2 ** i) * init_channels for i in range(depth)] 15 | self.out_channels = filters.copy() 16 | 17 | filters.insert(0, in_channel) 18 | self.stages = nn.ModuleList() 19 | 20 | for idx in range(depth): 21 | down_sample = False if idx == 0 else True 22 | self.stages.append(make_vgg_layer(inplanes=filters[idx], 23 | planes=filters[idx + 1], 24 | num_blocks=num_blocks, 25 | with_bn=True, 26 | down_sample=down_sample)) 27 | 28 | def forward(self, x): 29 | 30 | features = [] 31 | for stage in self.stages: 32 | x = stage(x) 33 | features.append(x) 34 | return features 35 | 36 | def init_weights(self, pretrained=None): 37 | pass 38 | # for m in self.modules(): 39 | # if isinstance(m, nn.Conv2d): 40 | # kaiming_init(m) 41 | # elif isinstance(m, (_BatchNorm, nn.GroupNorm)): 42 | # constant_init(m, 1) 43 | -------------------------------------------------------------------------------- /mctrans/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .basic_seg_head import BasicSegHead 2 | from .mctrans_aux_head import MCTransAuxHead 3 | 4 | __all__ = ["BasicSegHead", "MCTransAuxHead"] -------------------------------------------------------------------------------- /mctrans/models/heads/basic_seg_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from mmcv.cnn import normal_init 3 | 4 | from ..builder import HEADS, build_losses 5 | from ...data.transforms import build_transforms 6 | 7 | 8 | @HEADS.register_module() 9 | class BasicSegHead(nn.Module): 10 | def __init__(self, in_channels, num_classes, kernel_size=1, post_trans=None, losses=None): 11 | super(BasicSegHead, self).__init__() 12 | self.head = nn.Conv2d(in_channels=in_channels, out_channels=num_classes, kernel_size=kernel_size) 13 | self.post_trans = build_transforms(post_trans) 14 | self.losses = build_losses(losses) 15 | 16 | def forward_train(self, inputs, seg_label, **kwargs): 17 | logits = self.head(inputs) 18 | losses = dict() 19 | for _loss in self.losses: 20 | losses[_loss.__class__.__name__] = _loss(logits, seg_label) 21 | return losses 22 | 23 | def forward_test(self, inputs, **kwargs): 24 | logits = self.head(inputs) 25 | preds = self.post_trans(logits) 26 | return preds 27 | 28 | def init_weights(self): 29 | pass 30 | # normal_init(self.head, mean=0, std=0.01) 31 | -------------------------------------------------------------------------------- /mctrans/models/heads/mctrans_aux_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from mmcv.cnn import normal_init 4 | 5 | from ..builder import HEADS, build_losses 6 | from ..trans.transformer import CALayer, CA 7 | 8 | 9 | @HEADS.register_module() 10 | class MCTransAuxHead(nn.Module): 11 | def __init__(self, 12 | d_model=128, 13 | d_ffn=1024, 14 | 
dropout=0.1, 15 | act="ReLu", 16 | n_head=8, 17 | n_layers=4, 18 | num_classes=6, 19 | in_channles=[64, 64, 128, 256, 512], 20 | proj_idxs=(2, 3, 4), 21 | losses=None 22 | ): 23 | super(MCTransAuxHead, self).__init__() 24 | self.in_channles = in_channles 25 | self.proj_idxs = proj_idxs 26 | 27 | ca_layer = CALayer(d_model=d_model, 28 | d_ffn=d_ffn, 29 | dropout=dropout, 30 | activation=act, 31 | n_heads=n_head) 32 | 33 | self.ca = CA(att_layer=ca_layer, 34 | n_layers=n_layers, 35 | n_category=num_classes, 36 | d_model=d_model) 37 | 38 | self.head = nn.Sequential(nn.Linear(num_classes*d_model, d_model), 39 | nn.Linear(d_model, num_classes)) 40 | 41 | self.losses = build_losses(losses) 42 | 43 | def forward_train(self, inputs, seg_label, **kwargs): 44 | # flatten 45 | inputs = [inputs[idx] for idx in self.proj_idxs] 46 | inputs_flatten = [item.flatten(2).transpose(1, 2) for item in inputs] 47 | inputs_flatten = torch.cat(inputs_flatten, 1) 48 | # ca 49 | outputs = self.ca(inputs_flatten) 50 | logits = self.head(outputs.flatten(1)) 51 | losses = dict() 52 | for _loss in self.losses: 53 | losses[_loss.__class__.__name__] = _loss(logits, seg_label) 54 | return losses 55 | 56 | def init_weights(self): 57 | normal_init(self.head, mean=0, std=0.01) 58 | -------------------------------------------------------------------------------- /mctrans/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .cross_entropy_loss import MCTransAuxLoss 2 | from .monai import DiceLoss, DiceCELoss, DiceFocalLoss 3 | from .debug_focal import FLoss 4 | __all__ = ["MCTransAuxLoss", "DiceLoss", "DiceCELoss", "DiceFocalLoss"] -------------------------------------------------------------------------------- /mctrans/models/losses/cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import LOSSES 6 | 7 | 8 | def reduce_loss(loss, reduction): 9 | """Reduce loss as specified. 10 | 11 | Args: 12 | loss (Tensor): Elementwise loss tensor. 13 | reduction (str): Options are "none", "mean" and "sum". 14 | 15 | Return: 16 | Tensor: Reduced loss tensor. 17 | """ 18 | reduction_enum = F._Reduction.get_enum(reduction) 19 | # none: 0, elementwise_mean:1, sum: 2 20 | if reduction_enum == 0: 21 | return loss 22 | elif reduction_enum == 1: 23 | return loss.mean() 24 | elif reduction_enum == 2: 25 | return loss.sum() 26 | 27 | 28 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None): 29 | """Apply element-wise weight and reduce loss. 30 | 31 | Args: 32 | loss (Tensor): Element-wise loss. 33 | weight (Tensor): Element-wise weights. 34 | reduction (str): Same as built-in losses of PyTorch. 35 | avg_factor (float): Avarage factor when computing the mean of losses. 36 | 37 | Returns: 38 | Tensor: Processed loss values. 
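# Small numeric sketch (not repository code): element-wise weights are applied first, then
# the loss is either reduced normally or summed and divided by ``avg_factor``.
import torch

loss = torch.tensor([1.0, 2.0, 3.0])
weight = torch.tensor([1.0, 0.0, 1.0])
weighted = loss * weight               # [1., 0., 3.]
mean_reduced = weighted.mean()         # 4/3, plain reduction='mean'
by_avg_factor = weighted.sum() / 2.0   # 2.0, as with avg_factor=2 and reduction='mean'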
39 | """ 40 | # if weight is specified, apply element-wise weight 41 | if weight is not None: 42 | loss = loss * weight 43 | 44 | # if avg_factor is not specified, just reduce the loss 45 | if avg_factor is None: 46 | loss = reduce_loss(loss, reduction) 47 | else: 48 | # if reduction is mean, then average the loss by avg_factor 49 | if reduction == 'mean': 50 | loss = loss.sum() / avg_factor 51 | # if reduction is 'none', then do nothing, otherwise raise an error 52 | elif reduction != 'none': 53 | raise ValueError('avg_factor can not be used with reduction="sum"') 54 | return loss 55 | 56 | 57 | def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None): 58 | """Calculate the CrossEntropy loss. 59 | 60 | Args: 61 | pred (torch.Tensor): The prediction with shape (N, C), C is the number 62 | of classes. 63 | label (torch.Tensor): The gt label of the prediction. 64 | weight (torch.Tensor, optional): Sample-wise loss weight. 65 | reduction (str): The method used to reduce the loss. 66 | avg_factor (int, optional): Average factor that is used to average 67 | the loss. Defaults to None. 68 | 69 | Returns: 70 | torch.Tensor: The calculated loss 71 | """ 72 | # element-wise losses 73 | loss = F.cross_entropy(pred, label, reduction='none') 74 | 75 | # apply weights and do the reduction 76 | if weight is not None: 77 | weight = weight.float() 78 | loss = weight_reduce_loss( 79 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 80 | 81 | return loss 82 | 83 | 84 | def soft_cross_entropy(pred, 85 | label, 86 | weight=None, 87 | reduction='mean', 88 | avg_factor=None): 89 | """Calculate the Soft CrossEntropy loss. The label can be float. 90 | 91 | Args: 92 | pred (torch.Tensor): The prediction with shape (N, C), C is the number 93 | of classes. 94 | label (torch.Tensor): The gt label of the prediction with shape (N, C). 95 | When using "mixup", the label can be float. 96 | weight (torch.Tensor, optional): Sample-wise loss weight. 97 | reduction (str): The method used to reduce the loss. 98 | avg_factor (int, optional): Average factor that is used to average 99 | the loss. Defaults to None. 100 | 101 | Returns: 102 | torch.Tensor: The calculated loss 103 | """ 104 | # element-wise losses 105 | loss = -label * F.log_softmax(pred, dim=-1) 106 | loss = loss.sum(dim=-1) 107 | 108 | # apply weights and do the reduction 109 | if weight is not None: 110 | weight = weight.float() 111 | loss = weight_reduce_loss( 112 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 113 | 114 | return loss 115 | 116 | 117 | def binary_cross_entropy(pred, 118 | label, 119 | weight=None, 120 | reduction='mean', 121 | avg_factor=None): 122 | """Calculate the binary CrossEntropy loss with logits. 123 | 124 | Args: 125 | pred (torch.Tensor): The prediction with shape (N, *). 126 | label (torch.Tensor): The gt label with shape (N, *). 127 | weight (torch.Tensor, optional): Element-wise weight of loss with shape 128 | (N, ). Defaults to None. 129 | reduction (str): The method used to reduce the loss. 130 | Options are "none", "mean" and "sum". If reduction is 'none' , loss 131 | is same shape as pred and label. Defaults to 'mean'. 132 | avg_factor (int, optional): Average factor that is used to average 133 | the loss. Defaults to None. 
134 | 135 | Returns: 136 | torch.Tensor: The calculated loss 137 | """ 138 | assert pred.dim() == label.dim() 139 | 140 | loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none') 141 | 142 | # apply weights and do the reduction 143 | if weight is not None: 144 | assert weight.dim() == 1 145 | weight = weight.float() 146 | if pred.dim() > 1: 147 | weight = weight.reshape(-1, 1) 148 | loss = weight_reduce_loss( 149 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor) 150 | return loss 151 | 152 | @LOSSES.register_module() 153 | class CELoss(nn.Module): 154 | 155 | def __init__(self, reduction='mean', loss_weight=1.0): 156 | super(CELoss, self).__init__() 157 | self.reduction = reduction 158 | self.loss_weight = loss_weight 159 | 160 | self.cls_criterion = cross_entropy 161 | 162 | def forward(self, 163 | cls_score, 164 | label, 165 | weight=None, 166 | avg_factor=None, 167 | reduction_override=None, 168 | **kwargs): 169 | assert reduction_override in (None, 'none', 'mean', 'sum') 170 | reduction = ( 171 | reduction_override if reduction_override else self.reduction) 172 | 173 | n_pred_ch, n_target_ch = cls_score.shape[1], label.shape[1] 174 | if n_pred_ch == n_target_ch: 175 | # target is in the one-hot format, convert to BH[WD] format to calculate ce loss 176 | label = torch.argmax(label, dim=1) 177 | else: 178 | label = torch.squeeze(label, dim=1) 179 | label = label.long() 180 | loss_cls = self.loss_weight * self.cls_criterion( 181 | cls_score, 182 | label, 183 | weight, 184 | reduction=reduction, 185 | avg_factor=avg_factor, 186 | **kwargs) 187 | return loss_cls 188 | 189 | 190 | @LOSSES.register_module() 191 | class CrossEntropyLoss(nn.Module): 192 | """Cross entropy loss. 193 | 194 | Args: 195 | use_sigmoid (bool): Whether the prediction uses sigmoid 196 | of softmax. Defaults to False. 197 | use_soft (bool): Whether to use the soft version of CrossEntropyLoss. 198 | Defaults to False. 199 | reduction (str): The method used to reduce the loss. 200 | Options are "none", "mean" and "sum". Defaults to 'mean'. 201 | loss_weight (float): Weight of the loss. Defaults to 1.0. 
202 | """ 203 | 204 | def __init__(self, 205 | sigmoid=False, 206 | softmax=False, 207 | reduction='mean', 208 | loss_weight=1.0): 209 | super(CrossEntropyLoss, self).__init__() 210 | self.use_sigmoid = sigmoid 211 | self.use_soft = softmax 212 | assert not ( 213 | self.use_soft and self.use_sigmoid 214 | ), 'use_sigmoid and use_soft could not be set simultaneously' 215 | 216 | self.reduction = reduction 217 | self.loss_weight = loss_weight 218 | 219 | if self.use_sigmoid: 220 | self.cls_criterion = binary_cross_entropy 221 | elif self.use_soft: 222 | self.cls_criterion = soft_cross_entropy 223 | else: 224 | self.cls_criterion = cross_entropy 225 | 226 | def forward(self, 227 | cls_score, 228 | label, 229 | weight=None, 230 | avg_factor=None, 231 | reduction_override=None, 232 | **kwargs): 233 | assert reduction_override in (None, 'none', 'mean', 'sum') 234 | reduction = ( 235 | reduction_override if reduction_override else self.reduction) 236 | 237 | n_pred_ch, n_target_ch = cls_score.shape[1], label.shape[1] 238 | if n_pred_ch == n_target_ch: 239 | label = torch.argmax(label, dim=1) 240 | else: 241 | label = torch.squeeze(label, dim=1) 242 | label = label.long() 243 | 244 | loss_cls = self.loss_weight * self.cls_criterion( 245 | cls_score, 246 | label, 247 | weight, 248 | reduction=reduction, 249 | avg_factor=avg_factor, 250 | **kwargs) 251 | return loss_cls 252 | 253 | 254 | @LOSSES.register_module() 255 | class MCTransAuxLoss(CrossEntropyLoss): 256 | def __init__(self,**kwargs): 257 | super(MCTransAuxLoss, self).__init__(**kwargs) 258 | 259 | def forward(self, 260 | cls_score, 261 | label, 262 | weight=None, 263 | avg_factor=None, 264 | reduction_override=None, 265 | **kwargs): 266 | assert reduction_override in (None, 'none', 'mean', 'sum') 267 | #To one hot 268 | num_classes = cls_score.shape[1] 269 | one_hot = [] 270 | for l in label: 271 | one_hot.append(self.one_hot(torch.unique(l), num_classes=num_classes).sum(dim=0)) 272 | label = torch.stack(one_hot) 273 | 274 | reduction = ( 275 | reduction_override if reduction_override else self.reduction) 276 | loss_cls = self.loss_weight * self.cls_criterion( 277 | cls_score, 278 | label, 279 | weight, 280 | reduction=reduction, 281 | avg_factor=avg_factor, 282 | **kwargs) 283 | return loss_cls 284 | 285 | def one_hot(self, input, num_classes, dtype=torch.float): 286 | assert input.dim() > 0, "input should have dim of 1 or more." 287 | 288 | # if 1D, add singelton dim at the end 289 | if input.dim() == 1: 290 | input = input.view(-1, 1) 291 | 292 | sh = list(input.shape) 293 | 294 | assert sh[1] == 1, "labels should have a channel with length equals to one." 295 | sh[1] = num_classes 296 | 297 | o = torch.zeros(size=sh, dtype=dtype, device=input.device) 298 | labels = o.scatter_(dim=1, index=input.long(), value=1) 299 | 300 | return labels -------------------------------------------------------------------------------- /mctrans/models/losses/debug_focal.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 - 2021 MONAI Consortium 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # Unless required by applicable law or agreed to in writing, software 7 | # distributed under the License is distributed on an "AS IS" BASIS, 8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
9 | # See the License for the specific language governing permissions and 10 | # limitations under the License. 11 | 12 | from typing import Optional, Union 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch.nn.modules.loss import _Loss 17 | from torch import Tensor 18 | from monai.utils import LossReduction 19 | 20 | from ..builder import LOSSES 21 | 22 | class _WeightedLoss(_Loss): 23 | def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None: 24 | super(_WeightedLoss, self).__init__(size_average, reduce, reduction) 25 | self.register_buffer('weight', weight) 26 | @LOSSES.register_module() 27 | class FLoss(_WeightedLoss): 28 | """ 29 | Reimplementation of the Focal Loss described in: 30 | 31 | - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 32 | - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy", 33 | Zhu et al., Medical Physics 2018 34 | """ 35 | 36 | def __init__( 37 | self, 38 | gamma: float = 2.0, 39 | weight: Optional[torch.Tensor] = None, 40 | reduction: Union[LossReduction, str] = LossReduction.MEAN, 41 | ) -> None: 42 | """ 43 | Args: 44 | gamma: value of the exponent gamma in the definition of the Focal loss. 45 | weight: weights to apply to the voxels of each class. If None no weights are applied. 46 | This corresponds to the weights `\alpha` in [1]. 47 | reduction: {``"none"``, ``"mean"``, ``"sum"``} 48 | Specifies the reduction to apply to the output. Defaults to ``"mean"``. 49 | 50 | - ``"none"``: no reduction will be applied. 51 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output. 52 | - ``"sum"``: the output will be summed. 53 | 54 | Example: 55 | .. code-block:: python 56 | 57 | import torch 58 | from monai.losses import FocalLoss 59 | 60 | pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) 61 | grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64) 62 | fl = FocalLoss() 63 | fl(pred, grnd) 64 | 65 | """ 66 | super(FLoss, self).__init__(weight=weight, reduction=LossReduction(reduction).value) 67 | self.gamma = gamma 68 | self.weight: Optional[torch.Tensor] = None 69 | 70 | def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 71 | """ 72 | Args: 73 | logits: the shape should be BCH[WD]. 74 | where C (greater than 1) is the number of classes. 75 | Softmax over the logits is integrated in this module for improved numerical stability. 76 | target: the shape should be B1H[WD] or BCH[WD]. 77 | If the target's shape is B1H[WD], the target that this loss expects should be a class index 78 | in the range [0, C-1] where C is the number of classes. 79 | 80 | Raises: 81 | ValueError: When ``target`` ndim differs from ``logits``. 82 | ValueError: When ``target`` channel is not 1 and ``target`` shape differs from ``logits``. 83 | ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. 84 | 85 | """ 86 | i = logits 87 | t = target 88 | 89 | if i.ndimension() != t.ndimension(): 90 | raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.") 91 | 92 | if t.shape[1] != 1 and t.shape[1] != i.shape[1]: 93 | raise ValueError( 94 | "target must have one channel or have the same shape as the logits. " 95 | "If it has one channel, it should be a class index in the range [0, C-1] " 96 | f"where C is the number of classes inferred from 'logits': C={i.shape[1]}. 
" 97 | ) 98 | if i.shape[1] == 1: 99 | raise NotImplementedError("Single-channel predictions not supported.") 100 | 101 | # Change the shape of logits and target to 102 | # num_batch x num_class x num_voxels. 103 | if i.dim() > 2: 104 | i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W 105 | t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W or N,C,H*W 106 | else: # Compatibility with classification. 107 | i = i.unsqueeze(2) # N,C => N,C,1 108 | t = t.unsqueeze(2) # N,1 => N,1,1 or N,C,1 109 | 110 | # Compute the log proba (more stable numerically than softmax). 111 | logpt = F.log_softmax(i, dim=1) # N,C,H*W 112 | # Keep only log proba values of the ground truth class for each voxel. 113 | if target.shape[1] == 1: 114 | logpt = logpt.gather(1, t.long()) # N,C,H*W => N,1,H*W 115 | logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W 116 | 117 | # Get the proba 118 | pt = torch.exp(logpt) # N,H*W or N,C,H*W 119 | 120 | if self.weight is not None: 121 | self.weight = self.weight.to(i) 122 | # Convert the weight to a map in which each voxel 123 | # has the weight associated with the ground-truth label 124 | # associated with this voxel in target. 125 | at = self.weight[None, :, None] # C => 1,C,1 126 | at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W 127 | if target.shape[1] == 1: 128 | at = at.gather(1, t.long()) # selection of the weights => N,1,H*W 129 | at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W 130 | # Multiply the log proba by their weights. 131 | logpt = logpt * at 132 | 133 | # Compute the loss mini-batch. 134 | weight = torch.pow(-pt + 1.0, self.gamma) 135 | if target.shape[1] == 1: 136 | loss = torch.mean(-weight * logpt, dim=1) # N 137 | else: 138 | loss = torch.mean(-weight * t * logpt, dim=-1) # N,C 139 | 140 | if self.reduction == LossReduction.SUM.value: 141 | return loss.sum() 142 | if self.reduction == LossReduction.NONE.value: 143 | return loss 144 | if self.reduction == LossReduction.MEAN.value: 145 | return loss.mean() 146 | raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') 147 | -------------------------------------------------------------------------------- /mctrans/models/losses/monai.py: -------------------------------------------------------------------------------- 1 | from ..builder import LOSSES 2 | from monai.losses import DiceLoss, DiceCELoss, FocalLoss, DiceFocalLoss 3 | 4 | 5 | class DiceLoss(DiceLoss): 6 | def __init__(self, loss_weight=1.0, **kwargs): 7 | self.loss_weight = loss_weight 8 | super(DiceLoss, self).__init__(**kwargs) 9 | 10 | def forward(self, input, target): 11 | loss = self.loss_weight * super().forward(input=input, target=target) 12 | return loss 13 | 14 | 15 | class DiceCELoss(DiceCELoss): 16 | def __init__(self, loss_weight=1.0, **kwargs): 17 | self.loss_weight = loss_weight 18 | super(DiceCELoss, self).__init__(**kwargs) 19 | 20 | def forward(self, input, target): 21 | loss = self.loss_weight * super().forward(input=input, target=target) 22 | return loss 23 | 24 | 25 | LOSSES.register_module(name="DiceLoss", module=DiceLoss) 26 | LOSSES.register_module(name="FocalLoss", module=FocalLoss) 27 | 28 | LOSSES.register_module(name="DiceCELoss", module=DiceCELoss) 29 | LOSSES.register_module(name="DiceFocalLoss", module=DiceFocalLoss) 30 | -------------------------------------------------------------------------------- /mctrans/models/ops/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/mctrans/models/ops/__init__.py -------------------------------------------------------------------------------- /mctrans/models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /mctrans/models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | 
sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() -------------------------------------------------------------------------------- /mctrans/models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /mctrans/models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /mctrans/models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | self._reset_parameters() 61 | 62 | def _reset_parameters(self): 63 | constant_(self.sampling_offsets.weight.data, 0.) 64 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 65 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 66 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 67 | for i in range(self.n_points): 68 | grid_init[:, :, i, :] *= i + 1 69 | with torch.no_grad(): 70 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 71 | constant_(self.attention_weights.weight.data, 0.) 72 | constant_(self.attention_weights.bias.data, 0.) 73 | xavier_uniform_(self.value_proj.weight.data) 74 | constant_(self.value_proj.bias.data, 0.) 75 | xavier_uniform_(self.output_proj.weight.data) 76 | constant_(self.output_proj.bias.data, 0.) 
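# Initialisation summary for the method above: sampling_offsets.weight is zeroed and
# its bias is filled from grid_init, so at the start of training each attention head
# samples along one fixed direction determined by `thetas` (the (cos, sin) pair
# rescaled so that its larger component is +/-1), with the k-th sampling point of
# that head placed (k + 1) steps out along this direction, identical across all
# feature levels. attention_weights has zero weight and bias, so the softmax applied
# in forward() initially spreads uniform weights over the n_levels * n_points
# sampled values.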
77 | 78 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 79 | """ 80 | :param query (N, Length_{query}, C) 81 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 82 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 83 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 84 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 85 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 86 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 87 | 88 | :return output (N, Length_{query}, C) 89 | """ 90 | N, Len_q, _ = query.shape 91 | N, Len_in, _ = input_flatten.shape 92 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 93 | 94 | value = self.value_proj(input_flatten) 95 | if input_padding_mask is not None: 96 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 97 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 98 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 101 | # N, Len_q, n_heads, n_levels, n_points, 2 102 | if reference_points.shape[-1] == 2: 103 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 104 | sampling_locations = reference_points[:, :, None, :, None, :] \ 105 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 106 | elif reference_points.shape[-1] == 4: 107 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 108 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 109 | else: 110 | raise ValueError( 111 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 112 | output = MSDeformAttnFunction.apply( 113 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 114 | output = self.output_proj(output) 115 | return output 116 | -------------------------------------------------------------------------------- /mctrans/models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=0", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /mctrans/models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /mctrans/models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /mctrans/models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | 
const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /mctrans/models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /mctrans/models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /mctrans/models/ops/src/vision.cpp: 
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /mctrans/models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = 
torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /mctrans/models/segmentors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseSegmentor 2 | from .encoder_decoder import EncoderDecoder 3 | 4 | __all__ = ['BaseSegmentor', "EncoderDecoder"] 5 | -------------------------------------------------------------------------------- /mctrans/models/segmentors/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from abc import ABCMeta, abstractmethod 4 | from collections import OrderedDict 5 | 6 | import mmcv 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn as nn 11 | from mmcv.runner import auto_fp16 12 | 13 | 14 | class BaseSegmentor(nn.Module): 15 | """Base class for segmentors.""" 16 | 17 | __metaclass__ = ABCMeta 18 | 19 | def __init__(self): 20 | super(BaseSegmentor, self).__init__() 21 | self.fp16_enabled = False 22 | 23 | @property 24 | def with_center(self): 25 | """bool: whether the segmentor has neck""" 26 | return hasattr(self, 'center') and self.center is not None 27 | 28 | @property 29 | def with_auxiliary_head(self): 30 | """bool: whether the segmentor has auxiliary head""" 31 | return hasattr(self, 32 | 'aux_head') and self.aux_head is not None 33 | 34 | @property 35 | def with_decode_head(self): 36 | """bool: whether the segmentor has decode head""" 37 | return hasattr(self, 'decode_head') and self.decode_head is not None 38 | 39 | @abstractmethod 
40 | def extract_feat(self, imgs): 41 | """Placeholder for extract features from images.""" 42 | pass 43 | 44 | @abstractmethod 45 | def encode_decode(self, img, img_metas): 46 | """Placeholder for encode images with backbone and decode into a 47 | semantic segmentation map of the same size as input.""" 48 | pass 49 | 50 | @abstractmethod 51 | def forward_train(self, imgs, img_metas, **kwargs): 52 | """Placeholder for Forward function for training.""" 53 | pass 54 | 55 | @abstractmethod 56 | def simple_test(self, img, img_meta, **kwargs): 57 | """Placeholder for single image test.""" 58 | pass 59 | 60 | @abstractmethod 61 | def aug_test(self, imgs, img_metas, **kwargs): 62 | """Placeholder for augmentation test.""" 63 | pass 64 | 65 | def init_weights(self, pretrained=None): 66 | """Initialize the weights in segmentor. 67 | 68 | Args: 69 | pretrained (str, optional): Path to pre-trained weights. 70 | Defaults to None. 71 | """ 72 | if pretrained is not None: 73 | logger = logging.getLogger() 74 | logger.info(f'load model from: {pretrained}') 75 | 76 | def forward_test(self, imgs, img_metas, **kwargs): 77 | """ 78 | Args: 79 | imgs (List[Tensor]): the outer list indicates test-time 80 | augmentations and inner Tensor should have a shape NxCxHxW, 81 | which contains all images in the batch. 82 | img_metas (List[List[dict]]): the outer list indicates test-time 83 | augs (multiscale, flip, etc.) and the inner list indicates 84 | images in a batch. 85 | """ 86 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: 87 | if not isinstance(var, list): 88 | raise TypeError(f'{name} must be a list, but got ' 89 | f'{type(var)}') 90 | 91 | num_augs = len(imgs) 92 | if num_augs != len(img_metas): 93 | raise ValueError(f'num of augmentations ({len(imgs)}) != ' 94 | f'num of image meta ({len(img_metas)})') 95 | # all images in the same aug batch all of the same ori_shape and pad 96 | # shape 97 | for img_meta in img_metas: 98 | ori_shapes = [_['ori_shape'] for _ in img_meta] 99 | assert all(shape == ori_shapes[0] for shape in ori_shapes) 100 | img_shapes = [_['img_shape'] for _ in img_meta] 101 | assert all(shape == img_shapes[0] for shape in img_shapes) 102 | pad_shapes = [_['pad_shape'] for _ in img_meta] 103 | assert all(shape == pad_shapes[0] for shape in pad_shapes) 104 | 105 | if num_augs == 1: 106 | return self.simple_test(imgs[0], img_metas[0], **kwargs) 107 | else: 108 | return self.aug_test(imgs, img_metas, **kwargs) 109 | 110 | @auto_fp16(apply_to=('img', )) 111 | def forward(self, img, img_metas, return_loss=True, **kwargs): 112 | """Calls either :func:`forward_train` or :func:`forward_test` depending 113 | on whether ``return_loss`` is ``True``. 114 | 115 | Note this setting will change the expected inputs. When 116 | ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor 117 | and List[dict]), and when ``resturn_loss=False``, img and img_meta 118 | should be double nested (i.e. List[Tensor], List[List[dict]]), with 119 | the outer list indicating test time augmentations. 120 | """ 121 | if return_loss: 122 | return self.forward_train(img, img_metas, **kwargs) 123 | else: 124 | return self.forward_test(img, img_metas, **kwargs) 125 | 126 | def train_step(self, data_batch, optimizer, **kwargs): 127 | """The iteration step during training. 128 | 129 | This method defines an iteration step during training, except for the 130 | back propagation and optimizer updating, which are done in an optimizer 131 | hook. 
Note that in some complicated cases or models, the whole process 132 | including back propagation and optimizer updating is also defined in 133 | this method, such as GAN. 134 | 135 | Args: 136 | data (dict): The output of dataloader. 137 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of 138 | runner is passed to ``train_step()``. This argument is unused 139 | and reserved. 140 | 141 | Returns: 142 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, 143 | ``num_samples``. 144 | ``loss`` is a tensor for back propagation, which can be a 145 | weighted sum of multiple losses. 146 | ``log_vars`` contains all the variables to be sent to the 147 | logger. 148 | ``num_samples`` indicates the batch size (when the model is 149 | DDP, it means the batch size on each GPU), which is used for 150 | averaging the logs. 151 | """ 152 | losses = self(**data_batch) 153 | loss, log_vars = self._parse_losses(losses) 154 | 155 | outputs = dict( 156 | loss=loss, 157 | log_vars=log_vars, 158 | num_samples=len(data_batch['img'].data)) 159 | 160 | return outputs 161 | 162 | def val_step(self, data_batch, **kwargs): 163 | """The iteration step during validation. 164 | 165 | This method shares the same signature as :func:`train_step`, but used 166 | during val epochs. Note that the evaluation after training epochs is 167 | not implemented with this method, but an evaluation hook. 168 | """ 169 | output = self(**data_batch, **kwargs) 170 | return output 171 | 172 | @staticmethod 173 | def _parse_losses(losses): 174 | """Parse the raw outputs (losses) of the network. 175 | 176 | Args: 177 | losses (dict): Raw output of the network, which usually contain 178 | losses and other necessary information. 179 | 180 | Returns: 181 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor 182 | which may be a weighted sum of all losses, log_vars contains 183 | all the variables to be sent to the logger. 184 | """ 185 | 186 | log_vars = OrderedDict() 187 | for loss_name, loss_value in losses.items(): 188 | if isinstance(loss_value, torch.Tensor): 189 | log_vars[loss_name] = loss_value.mean() 190 | elif isinstance(loss_value, list): 191 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) 192 | else: 193 | raise TypeError( 194 | f'{loss_name} is not a tensor or list of tensors') 195 | 196 | loss = sum(_value for _key, _value in log_vars.items() if 'Loss' in _key) 197 | 198 | log_vars['Loss'] = loss 199 | for loss_name, loss_value in log_vars.items(): 200 | # reduce loss when distributed training 201 | if dist.is_available() and dist.is_initialized(): 202 | loss_value = loss_value.data.clone() 203 | dist.all_reduce(loss_value.div_(dist.get_world_size())) 204 | log_vars[loss_name] = loss_value.item() 205 | 206 | return loss, log_vars 207 | 208 | def show_result(self, 209 | img, 210 | result, 211 | palette=None, 212 | win_name='', 213 | show=False, 214 | wait_time=0, 215 | out_file=None): 216 | """Draw `result` over `img`. 217 | 218 | Args: 219 | img (str or Tensor): The image to be displayed. 220 | result (Tensor): The semantic segmentation results to draw over 221 | `img`. 222 | palette (list[list[int]]] | np.ndarray | None): The palette of 223 | segmentation map. If None is given, random palette will be 224 | generated. Default: None 225 | win_name (str): The window name. 226 | wait_time (int): Value of waitKey param. 227 | Default: 0. 228 | show (bool): Whether to show the image. 229 | Default: False. 230 | out_file (str or None): The filename to write the image. 
231 | Default: None. 232 | 233 | Returns: 234 | img (Tensor): Only if not `show` or `out_file` 235 | """ 236 | img = mmcv.imread(img) 237 | img = img.copy() 238 | seg = result[0] 239 | if palette is None: 240 | if self.PALETTE is None: 241 | palette = np.random.randint( 242 | 0, 255, size=(len(self.CLASSES), 3)) 243 | else: 244 | palette = self.PALETTE 245 | palette = np.array(palette) 246 | assert palette.shape[0] == len(self.CLASSES) 247 | assert palette.shape[1] == 3 248 | assert len(palette.shape) == 2 249 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 250 | for label, color in enumerate(palette): 251 | color_seg[seg == label, :] = color 252 | # convert to BGR 253 | color_seg = color_seg[..., ::-1] 254 | 255 | img = img * 0.5 + color_seg * 0.5 256 | img = img.astype(np.uint8) 257 | # if out_file specified, do not show image in window 258 | if out_file is not None: 259 | show = False 260 | 261 | if show: 262 | mmcv.imshow(img, win_name, wait_time) 263 | if out_file is not None: 264 | mmcv.imwrite(img, out_file) 265 | 266 | if not (show or out_file): 267 | warnings.warn('show==False and out_file is not specified, only ' 268 | 'result image will be returned') 269 | return img 270 | -------------------------------------------------------------------------------- /mctrans/models/segmentors/encoder_decoder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from .base import BaseSegmentor 4 | from ..builder import build_network, build_losses, MODEL, build_encoder, build_decoder, build_head, build_center 5 | from ...data.transforms.utils import resize 6 | from ...metrics import build_metrics 7 | 8 | 9 | @MODEL.register_module() 10 | class EncoderDecoder(BaseSegmentor): 11 | def __init__(self, 12 | encoder, 13 | decoder, 14 | seg_head, 15 | center=None, 16 | aux_head=None, 17 | pretrained=None): 18 | super(EncoderDecoder, self).__init__() 19 | self.encoder = build_encoder(encoder) 20 | self.decoder = build_decoder(decoder) 21 | self.seg_head = build_head(seg_head) 22 | 23 | if center is not None: 24 | self.center = build_center(center) 25 | if aux_head is not None: 26 | self.aux_head = build_head(aux_head) 27 | 28 | self.init_weights(pretrained=pretrained) 29 | 30 | def init_weights(self, pretrained=None): 31 | """Initialize the weights in backbone and heads. 32 | Args: 33 | pretrained (str, optional): Path to pre-trained weights. 34 | Defaults to None. 
35 | """ 36 | pass 37 | super(EncoderDecoder, self).init_weights(pretrained) 38 | self.encoder.init_weights(pretrained=pretrained) 39 | self.decoder.init_weights() 40 | self.seg_head.init_weights() 41 | if self.with_auxiliary_head: 42 | if isinstance(self.aux_head, nn.ModuleList): 43 | for aux_head in self.aux_head: 44 | aux_head.init_weights() 45 | else: 46 | self.aux_head.init_weights() 47 | 48 | def extract_feat(self, img): 49 | x = self.encoder(img) 50 | if self.with_center: 51 | x = self.center(x) 52 | return x 53 | 54 | def encode_decode(self, img, img_metas, rescale=True): 55 | """Encode images with backbone and decode into a semantic segmentation 56 | map of the same size as input.""" 57 | x = self.extract_feat(img) 58 | pred = self._decode_head_forward(x, img_metas, return_loss=False) 59 | #TODO shold evaluate on more dataset 60 | if rescale: 61 | if not hasattr(img_metas[0], "height"): 62 | re_size = img_metas[0]["spatial_shape"][:2] 63 | else: 64 | re_size = (img_metas[0]['height'], img_metas[0]['width']) 65 | pred = resize( 66 | pred, 67 | size=re_size, 68 | mode='nearest', 69 | align_corners=None, 70 | warning=False) 71 | return pred 72 | 73 | def _decode_head_forward(self, x, seg_label=None, return_loss=False): 74 | x = self.decoder(x) 75 | if return_loss: 76 | return self.seg_head.forward_train(x, seg_label) 77 | else: 78 | return self.seg_head.forward_test(x) 79 | 80 | def _auxiliary_head_forward(self, x, seg_label, return_loss=True): 81 | return self.aux_head.forward_train(x, seg_label) 82 | 83 | def forward_train(self, img, img_metas, seg_label, **kwargs): 84 | # the img_metas may useful in other framework 85 | x = self.extract_feat(img) 86 | 87 | losses = dict() 88 | loss_decode = self._decode_head_forward(x, seg_label, return_loss=True) 89 | losses.update(loss_decode) 90 | 91 | if self.with_auxiliary_head: 92 | loss_aux = self._auxiliary_head_forward(x, seg_label, return_loss=True) 93 | losses.update(loss_aux) 94 | 95 | return losses 96 | 97 | def forward_test(self, img, img_metas, rescale=True, **kwargs): 98 | # TODO support sliding window evaluator 99 | pred = self.encode_decode(img, img_metas, rescale) 100 | pred = pred.cpu().numpy() 101 | # unravel batch dim 102 | pred = list(pred) 103 | return pred 104 | -------------------------------------------------------------------------------- /mctrans/models/trans/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/mctrans/models/trans/__init__.py -------------------------------------------------------------------------------- /mctrans/models/trans/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from ..ops.modules import MSDeformAttn 4 | from .utils import _get_activation_fn, _get_clones 5 | 6 | 7 | class DSALayer(nn.Module): 8 | def __init__(self, 9 | d_model=256, d_ffn=1024, 10 | dropout=0.1, activation="relu", 11 | n_levels=4, n_heads=8, n_points=4): 12 | super().__init__() 13 | 14 | # self attention 15 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 16 | self.dropout1 = nn.Dropout(dropout) 17 | self.norm1 = nn.LayerNorm(d_model) 18 | 19 | # ffn 20 | self.linear1 = nn.Linear(d_model, d_ffn) 21 | self.activation = _get_activation_fn(activation) 22 | self.dropout2 = nn.Dropout(dropout) 23 | self.linear2 = nn.Linear(d_ffn, d_model) 24 | self.dropout3 = nn.Dropout(dropout) 25 
| self.norm2 = nn.LayerNorm(d_model) 26 | 27 | @staticmethod 28 | def with_pos_embed(tensor, pos): 29 | return tensor if pos is None else tensor + pos 30 | 31 | def forward_ffn(self, src): 32 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 33 | src = src + self.dropout3(src2) 34 | src = self.norm2(src) 35 | return src 36 | 37 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): 38 | # self attention 39 | src2 = self.self_attn(self.with_pos_embed(src, pos), 40 | reference_points, 41 | src, 42 | spatial_shapes, 43 | level_start_index, 44 | padding_mask) 45 | src = src + self.dropout1(src2) 46 | src = self.norm1(src) 47 | # ffn 48 | src = self.forward_ffn(src) 49 | 50 | return src 51 | 52 | 53 | class DSA(nn.Module): 54 | def __init__(self, att_layer, n_layers): 55 | super().__init__() 56 | self.layers = _get_clones(att_layer, n_layers) 57 | 58 | @staticmethod 59 | def get_reference_points(spatial_shapes, valid_ratios, device): 60 | reference_points_list = [] 61 | for lvl, (H_, W_) in enumerate(spatial_shapes): 62 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 63 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 64 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) 65 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) 66 | ref = torch.stack((ref_x, ref_y), -1) 67 | reference_points_list.append(ref) 68 | reference_points = torch.cat(reference_points_list, 1) 69 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 70 | return reference_points 71 | 72 | def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): 73 | output = src 74 | reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) 75 | for _, layer in enumerate(self.layers): 76 | output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) 77 | return output 78 | 79 | 80 | class CALayer(nn.Module): 81 | def __init__(self, d_model=256, d_ffn=1024, 82 | dropout=0.1, activation="relu", n_heads=8): 83 | super().__init__() 84 | self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) 85 | self.dropout1 = nn.Dropout(dropout) 86 | self.norm1 = nn.LayerNorm(d_model) 87 | 88 | # self attention 89 | self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout) 90 | self.dropout2 = nn.Dropout(dropout) 91 | self.norm2 = nn.LayerNorm(d_model) 92 | 93 | # ffn 94 | self.linear1 = nn.Linear(d_model, d_ffn) 95 | self.activation = _get_activation_fn(activation) 96 | self.dropout3 = nn.Dropout(dropout) 97 | self.linear2 = nn.Linear(d_ffn, d_model) 98 | self.dropout4 = nn.Dropout(dropout) 99 | self.norm3 = nn.LayerNorm(d_model) 100 | 101 | @staticmethod 102 | def with_pos_embed(tensor, pos): 103 | return tensor if pos is None else tensor + pos 104 | 105 | def forward_ffn(self, tgt): 106 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) 107 | tgt = tgt + self.dropout4(tgt2) 108 | tgt = self.norm3(tgt) 109 | return tgt 110 | 111 | def forward(self, tgt, src): 112 | # self attention 113 | q = k = tgt 114 | tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) 115 | tgt = tgt + self.dropout2(tgt2) 116 | tgt = self.norm2(tgt) 117 | # cross attention 118 | tgt2 = self.cross_attn(tgt, src, src)[0] 119 | tgt = tgt + self.dropout1(tgt2) 120 
| tgt = self.norm1(tgt) 121 | # ffn 122 | tgt = self.forward_ffn(tgt) 123 | 124 | return tgt 125 | 126 | 127 | class CA(nn.Module): 128 | def __init__(self, att_layer, n_layers, n_category=2, d_model=256): 129 | super().__init__() 130 | self.layers = _get_clones(att_layer, n_layers) 131 | self.proxy_embed = nn.Parameter(torch.zeros(1, n_category, d_model)) 132 | 133 | def forward(self, src): 134 | query = None 135 | B = src.shape[0] 136 | for idx, layer in enumerate(self.layers): 137 | if idx == 0: 138 | query = self.proxy_embed.expand(B, -1, -1) 139 | else: 140 | query += self.proxy_embed.expand(B, -1, -1) 141 | query = layer(query, src) 142 | return query -------------------------------------------------------------------------------- /mctrans/models/trans/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Various positional encodings for the transformer. 12 | """ 13 | import copy 14 | import math 15 | import torch 16 | from typing import Optional 17 | 18 | from torch import nn, Tensor 19 | import torch.nn.functional as F 20 | 21 | 22 | def _get_clones(module, N): 23 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 24 | 25 | 26 | def _get_activation_fn(activation): 27 | """Return an activation function given a string""" 28 | if activation == "relu": 29 | return F.relu 30 | if activation == "gelu": 31 | return F.gelu 32 | if activation == "glu": 33 | return F.glu 34 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 35 | 36 | 37 | def build_position_encoding(position_embedding='sine', hidden_dim=240): 38 | N_steps = hidden_dim // 2 39 | if position_embedding in ('v2', 'sine'): 40 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 41 | elif position_embedding in ('v3', 'learned'): 42 | position_embedding = PositionEmbeddingLearned(N_steps) 43 | else: 44 | raise ValueError(f"not supported {position_embedding}") 45 | 46 | return position_embedding 47 | 48 | 49 | class NestedTensor(object): 50 | def __init__(self, tensors, mask: Optional[Tensor]): 51 | self.tensors = tensors 52 | self.mask = mask 53 | 54 | def to(self, device, non_blocking=False): 55 | # type: (Device) -> NestedTensor # noqa 56 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 57 | mask = self.mask 58 | if mask is not None: 59 | assert mask is not None 60 | cast_mask = mask.to(device, non_blocking=non_blocking) 61 | else: 62 | cast_mask = None 63 | return NestedTensor(cast_tensor, cast_mask) 64 | 65 | def record_stream(self, *args, **kwargs): 66 | self.tensors.record_stream(*args, **kwargs) 67 | if self.mask is not None: 68 | self.mask.record_stream(*args, **kwargs) 69 | 70 | def decompose(self): 71 | return self.tensors, self.mask 72 | 73 | def __repr__(self): 74 | return str(self.tensors) 75 | 76 | 77 | class DevPositionEmbeddingSine(nn.Module): 78 | """ 79 | This is a more standard version of the position embedding, very similar to the one 80 | used by the 
Attention is all you need paper, generalized to work on images. 81 | """ 82 | 83 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 84 | super().__init__() 85 | self.num_pos_feats = num_pos_feats 86 | self.temperature = temperature 87 | self.normalize = normalize 88 | if scale is not None and normalize is False: 89 | raise ValueError("normalize should be True if scale is passed") 90 | if scale is None: 91 | scale = 2 * math.pi 92 | self.scale = scale 93 | 94 | def forward(self, tensor_list: NestedTensor): 95 | x = tensor_list.tensors 96 | mask = tensor_list.mask 97 | assert mask is not None 98 | # (b, D, H, W) 99 | not_mask = ~mask 100 | z_embed = not_mask.cumsum(1, dtype=torch.float32) 101 | y_embed = not_mask.cumsum(2, dtype=torch.float32) 102 | x_embed = not_mask.cumsum(3, dtype=torch.float32) 103 | if self.normalize: 104 | eps = 1e-6 105 | z_embed = (z_embed - 0.5) / (y_embed[:, -1:, :, :] + eps) * self.scale 106 | y_embed = (y_embed - 0.5) / (y_embed[:, :, -1:, :] + eps) * self.scale 107 | x_embed = (x_embed - 0.5) / (x_embed[:, :, :, -1:] + eps) * self.scale 108 | 109 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 110 | dim_t = self.temperature ** (2 * (dim_t // 3) / self.num_pos_feats) 111 | 112 | pos_x = x_embed[:, :, :, :, None] / dim_t 113 | pos_y = y_embed[:, :, :, :, None] / dim_t 114 | pos_z = z_embed[:, :, :, :, None] / dim_t 115 | 116 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 117 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 118 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 119 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3) 120 | return pos 121 | 122 | 123 | class PositionEmbeddingSine(nn.Module): 124 | """ 125 | This is a more standard version of the position embedding, very similar to the one 126 | used by the Attention is all you need paper, generalized to work on images. 
127 | """ 128 | 129 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 130 | super().__init__() 131 | self.num_pos_feats = num_pos_feats 132 | self.temperature = temperature 133 | self.normalize = normalize 134 | if scale is not None and normalize is False: 135 | raise ValueError("normalize should be True if scale is passed") 136 | if scale is None: 137 | scale = 2 * math.pi 138 | self.scale = scale 139 | 140 | def forward(self, tensor_list: NestedTensor): 141 | x = tensor_list.tensors 142 | mask = tensor_list.mask 143 | assert mask is not None 144 | not_mask = ~mask 145 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 146 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 147 | if self.normalize: 148 | eps = 1e-6 149 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 150 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 151 | 152 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 153 | dim_t = self.temperature ** (2 * (dim_t // 3) / self.num_pos_feats) 154 | 155 | pos_x = x_embed[:, :, :, None] / dim_t 156 | pos_y = y_embed[:, :, :, None] / dim_t 157 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 158 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 159 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 160 | return pos 161 | 162 | 163 | class PositionEmbeddingLearned(nn.Module): 164 | """ 165 | Absolute pos embedding, learned. 166 | """ 167 | 168 | def __init__(self, num_pos_feats=256): 169 | super().__init__() 170 | self.row_embed = nn.Embedding(50, num_pos_feats) 171 | self.col_embed = nn.Embedding(50, num_pos_feats) 172 | self.reset_parameters() 173 | 174 | def reset_parameters(self): 175 | nn.init.uniform_(self.row_embed.weight) 176 | nn.init.uniform_(self.col_embed.weight) 177 | 178 | def forward(self, tensor_list: NestedTensor): 179 | x = tensor_list.tensors 180 | h, w = x.shape[-2:] 181 | i = torch.arange(w, device=x.device) 182 | j = torch.arange(h, device=x.device) 183 | x_emb = self.col_embed(i) 184 | y_emb = self.row_embed(j) 185 | pos = torch.cat([ 186 | x_emb.unsqueeze(0).repeat(h, 1, 1), 187 | y_emb.unsqueeze(1).repeat(1, w, 1), 188 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 189 | return pos 190 | 191 | -------------------------------------------------------------------------------- /mctrans/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def conv3x3(in_planes, out_planes, dilation=1): 5 | return nn.Conv2d( 6 | in_planes, 7 | out_planes, 8 | kernel_size=3, 9 | padding=dilation, 10 | dilation=dilation) 11 | 12 | 13 | def conv_bn_relu(in_channels, out_channels, kernel_size=3, padding=1, stride=1): 14 | return nn.Sequential(nn.Conv2d(in_channels, 15 | out_channels, 16 | kernel_size=kernel_size, 17 | padding=padding, 18 | stride=stride), 19 | nn.BatchNorm2d(out_channels), 20 | nn.ReLU(inplace=True)) 21 | 22 | 23 | def deconv_bn_relu(in_channels, out_channels, kernel_size=3, padding=1, stride=2): 24 | return nn.Sequential(nn.ConvTranspose2d(in_channels, 25 | out_channels, 26 | kernel_size=kernel_size, 27 | padding=padding, 28 | stride=stride, 29 | output_padding=1), 30 | nn.BatchNorm2d(out_channels), 31 | nn.ReLU(inplace=True)) 32 | 33 | 34 | def make_vgg_layer(inplanes, 35 | planes, 36 | num_blocks, 37 | dilation=1, 38 | with_bn=False, 39 | 
down_sample=False, 40 | ceil_mode=False): 41 | layers = [] 42 | if down_sample: 43 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode)) 44 | for _ in range(num_blocks): 45 | layers.append(conv3x3(inplanes, planes, dilation)) 46 | if with_bn: 47 | layers.append(nn.BatchNorm2d(planes)) 48 | layers.append(nn.ReLU(inplace=True)) 49 | inplanes = planes 50 | return nn.Sequential(*layers) 51 | -------------------------------------------------------------------------------- /mctrans/pipline/__init__.py: -------------------------------------------------------------------------------- 1 | from .segpipline import SegPipline 2 | -------------------------------------------------------------------------------- /mctrans/pipline/segpipline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 3 | from mmcv.runner import build_runner, build_optimizer 4 | from mmcv.runner import DistSamplerSeedHook, DistEvalHook, EvalHook 5 | 6 | from mctrans.data import build_dataset, build_dataloader 7 | from mctrans.models import build_model 8 | from mctrans.utils import get_root_logger 9 | 10 | 11 | class SegPipline(object): 12 | def __init__(self, cfg, distributed=False, validate=False): 13 | self.cfg = cfg 14 | self.distributed = distributed 15 | self.logger = get_root_logger(cfg.log_level) 16 | self.data_loaders = self._build_data_loader() 17 | self.model = self._build_model() 18 | # need to update 19 | self.optimizer = build_optimizer(self.model, cfg.optimizer) 20 | self.runner = self._build_runner() 21 | if validate: 22 | self._build_eval_hook() 23 | if cfg.resume_from: 24 | self.runner.resume(cfg.resume_from) 25 | elif cfg.load_from: 26 | self.runner.load_checkpoint(cfg.load_from) 27 | else: 28 | pass 29 | 30 | def _build_data_loader(self): 31 | dataset = [build_dataset(self.cfg.data.train)] 32 | if len(self.cfg.workflow) == 2: 33 | dataset.append(build_dataset(self.cfg.data.val)) 34 | data_loaders = [ 35 | build_dataloader( 36 | ds, 37 | self.cfg.data.samples_per_gpu, 38 | self.cfg.data.workers_per_gpu, 39 | len(self.cfg.gpu_ids), 40 | dist=self.distributed, 41 | seed=self.cfg.seed, 42 | drop_last=True) for ds in dataset] 43 | return data_loaders 44 | 45 | def _build_model(self): 46 | model = build_model(self.cfg.model) 47 | if self.distributed: 48 | find_unused_parameters = self.cfg.get('find_unused_parameters', True) 49 | # Sets the `find_unused_parameters` parameter in 50 | # torch.nn.parallel.DistributedDataParallel 51 | model = MMDistributedDataParallel( 52 | model.cuda(), 53 | device_ids=[torch.cuda.current_device()], 54 | broadcast_buffers=False, 55 | find_unused_parameters=find_unused_parameters) 56 | else: 57 | model = MMDataParallel( 58 | model.cuda(self.cfg.gpu_ids[0]), device_ids=self.cfg.gpu_ids) 59 | return model 60 | 61 | def _build_eval_hook(self): 62 | val_dataset = build_dataset(self.cfg.data.val) 63 | val_dataloader = build_dataloader( 64 | val_dataset, 65 | samples_per_gpu=1, 66 | workers_per_gpu=self.cfg.data.workers_per_gpu, 67 | dist=self.distributed, 68 | shuffle=False) 69 | eval_cfg = self.cfg.get('evaluation', {}) 70 | eval_cfg['by_epoch'] = self.cfg.runner['type'] != 'IterBasedRunner' 71 | eval_hook = DistEvalHook if self.distributed else EvalHook 72 | self.runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 73 | 74 | def _build_runner(self): 75 | runner = build_runner( 76 | cfg=self.cfg.runner, 77 | default_args=dict( 78 | 
model=self.model, 79 | batch_processor=None, 80 | optimizer=self.optimizer, 81 | work_dir=self.cfg.work_dir, 82 | logger=self.logger, 83 | meta=None)) 84 | 85 | # register hooks 86 | runner.register_training_hooks(self.cfg.lr_config, self.cfg.optimizer_config, 87 | self.cfg.checkpoint_config, self.cfg.log_config, 88 | self.cfg.get('momentum_config', None)) 89 | 90 | if self.distributed: 91 | runner.register_hook(DistSamplerSeedHook()) 92 | 93 | return runner 94 | 95 | def _report_details(self): 96 | self.logger.info(self.model) 97 | 98 | def run(self): 99 | self.runner.run(self.data_loaders, self.cfg.workflow, self.cfg.max_epochs) 100 | -------------------------------------------------------------------------------- /mctrans/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import get_root_logger 2 | -------------------------------------------------------------------------------- /mctrans/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from mmcv.utils import get_logger 4 | 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO): 7 | """Get the root logger. 8 | 9 | The logger will be initialized if it has not been initialized. By default a 10 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 11 | also be added. The name of the root logger is the top-level package name, 12 | e.g., "mmseg". 13 | 14 | Args: 15 | log_file (str | None): The log filename. If specified, a FileHandler 16 | will be added to the root logger. 17 | log_level (int): The root logger level. Note that only the process of 18 | rank 0 is affected, while other processes will set the level to 19 | "Error" and be silent most of the time. 20 | 21 | Returns: 22 | logging.Logger: The root logger. 
23 | """ 24 | 25 | logger = get_logger(name='MCTrans', log_file=log_file, log_level=log_level) 26 | 27 | return logger 28 | -------------------------------------------------------------------------------- /mctrans/utils/misc.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from torch import Tensor 8 | 9 | def sync_param(input, reduction='mean'): 10 | if isinstance(input, np.ndarray): 11 | sync_input = torch.from_numpy(input).cuda() 12 | elif isinstance(input, torch.Tensor): 13 | sync_input = input.clone() 14 | else: 15 | raise ValueError('input should be torch tensor or ndarray') 16 | dist.all_reduce(sync_input) 17 | if reduction == 'mean': 18 | sync_input.div_(dist.get_world_size()) 19 | return sync_input 20 | 21 | def is_distributed(): 22 | if dist.is_available() and dist.is_initialized(): 23 | return True 24 | else: 25 | return False 26 | 27 | 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | addict==2.4.0 3 | cachetools==4.2.2 4 | certifi==2021.5.30 5 | chardet==4.0.0 6 | future==0.18.2 7 | google-auth==1.32.1 8 | google-auth-oauthlib==0.4.4 9 | grpcio==1.38.1 10 | idna==2.10 11 | importlib-metadata==4.6.1 12 | joblib==1.0.1 13 | Markdown==3.3.4 14 | mmcv==1.3.9 15 | monai==0.5.3 16 | numpy==1.21.0 17 | oauthlib==3.1.1 18 | opencv-python==4.5.3.56 19 | Pillow==8.3.1 20 | prettytable==2.1.0 21 | protobuf==3.17.3 22 | pyasn1==0.4.8 23 | pyasn1-modules==0.2.8 24 | PyYAML==5.4.1 25 | requests==2.25.1 26 | requests-oauthlib==1.3.0 27 | rsa==4.7.2 28 | scikit-learn==0.24.2 29 | scipy==1.7.0 30 | six==1.16.0 31 | sklearn==0.0 32 | tensorboard==2.5.0 33 | tensorboard-data-server==0.6.1 34 | tensorboard-plugin-wit==1.8.0 35 | threadpoolctl==2.2.0 36 | torch==1.7.1+cu110 37 | torchaudio==0.7.2 38 | torchvision==0.8.2+cu110 39 | typing-extensions==3.10.0.0 40 | urllib3==1.26.6 41 | wcwidth==0.2.5 42 | Werkzeug==2.0.1 43 | yapf==0.31.0 44 | zipp==3.5.0 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | setuptools.setup( 5 | name='MCTrans', 6 | version='0.0.1', 7 | author='Yuanfeng Ji, Jiamin Ren', 8 | author_email='u3008013@connect.hku.hk', 9 | description='Simple Framework for Medcial Image Analysis', 10 | url='https://github.com/JiYuanFeng/MCTrans', 11 | packages=setuptools.find_packages(), 12 | license='Internal', 13 | include_package_data=True, 14 | zip_safe=False, 15 | classifiers=[ 16 | 'License :: OSI Approved :: Internal', 17 | 'Programming Language :: Python', 18 | 'Intended Audience :: Developers', 19 | 'Operating System :: OS Independent', 20 | ] 21 | ) 22 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import os.path as osp 5 | import time 6 | import warnings 7 | warnings.filterwarnings("ignore") 8 | 9 | import mmcv 10 | import torch 11 | from mmcv.runner import init_dist, set_random_seed 12 | from mmcv.utils import Config, DictAction 13 | 14 | from mctrans.pipline import SegPipline 15 | from mctrans.utils import 
get_root_logger 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser(description='Train a segmentation model') 20 | parser.add_argument('config', help='train config file path') 21 | parser.add_argument('--work-dir', help='the dir to save logs and models') 22 | parser.add_argument( 23 | '--load-from', help='the checkpoint file to load weights from') 24 | parser.add_argument( 25 | '--resume-from', help='the checkpoint file to resume from') 26 | parser.add_argument( 27 | '--no-validate', 28 | action='store_true', 29 | help='whether not to evaluate the checkpoint during training') 30 | group_gpus = parser.add_mutually_exclusive_group() 31 | group_gpus.add_argument( 32 | '--gpus', 33 | type=int, 34 | help='number of gpus to use ' 35 | '(only applicable to non-distributed training)') 36 | group_gpus.add_argument( 37 | '--gpu-ids', 38 | type=int, 39 | nargs='+', 40 | help='ids of gpus to use ' 41 | '(only applicable to non-distributed training)') 42 | parser.add_argument('--seed', type=int, default=0, help='random seed') 43 | parser.add_argument( 44 | '--deterministic', 45 | action='store_true', 46 | help='whether to set deterministic options for CUDNN backend.') 47 | parser.add_argument( 48 | '--options', nargs='+', action=DictAction, help='custom options') 49 | parser.add_argument( 50 | '--launcher', 51 | choices=['none', 'pytorch', 'slurm', 'mpi'], 52 | default='none', 53 | help='job launcher') 54 | parser.add_argument('--local_rank', type=int, default=0) 55 | parser.add_argument('--port', type=int, default=29500, 56 | help='port only works when launcher=="slurm"') 57 | args = parser.parse_args() 58 | if 'LOCAL_RANK' not in os.environ: 59 | os.environ['LOCAL_RANK'] = str(args.local_rank) 60 | 61 | 62 | return args 63 | 64 | 65 | def main(): 66 | args = parse_args() 67 | 68 | cfg = Config.fromfile(args.config) 69 | if args.options is not None: 70 | cfg.merge_from_dict(args.options) 71 | # set cudnn_benchmark 72 | if cfg.get('cudnn_benchmark', False): 73 | torch.backends.cudnn.benchmark = True 74 | 75 | # work_dir is determined in this priority: CLI > config file > config filename 76 | if args.work_dir is not None: 77 | # update configs according to CLI args if args.work_dir is not None 78 | cfg.work_dir = args.work_dir 79 | elif cfg.get('work_dir', None) is None: 80 | # use config filename as default work_dir if cfg.work_dir is None 81 | cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0]) 82 | 83 | if args.load_from is not None: 84 | cfg.load_from = args.load_from 85 | if args.resume_from is not None: 86 | cfg.resume_from = args.resume_from 87 | if args.gpu_ids is not None: 88 | cfg.gpu_ids = args.gpu_ids 89 | else: 90 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 91 | 92 | # init distributed env first, since logger depends on the dist info.
93 | if args.launcher == 'none': 94 | distributed = False 95 | else: 96 | distributed = True 97 | init_dist(args.launcher, **cfg.dist_params) 98 | 99 | # create work_dir 100 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 101 | # dump config 102 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 103 | # init the logger before other steps 104 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 105 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 106 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 107 | # init the meta dict to record some important information such as 108 | # environment info and seed, which will be logged 109 | meta = dict() 110 | 111 | # log some basic info 112 | logger.info(f'Distributed training: {distributed}') 113 | logger.info(f'Config:\n{cfg.pretty_text}') 114 | 115 | # set random seeds 116 | if args.seed is not None: 117 | logger.info(f'Set random seed to {args.seed}, deterministic: ' 118 | f'{args.deterministic}') 119 | set_random_seed(args.seed, deterministic=args.deterministic) 120 | cfg.seed = args.seed 121 | meta['seed'] = args.seed 122 | meta['exp_name'] = osp.basename(args.config) 123 | 124 | SegPipline(cfg, distributed, not args.no_validate).run() 125 | 126 | if __name__ == '__main__': 127 | main() 128 | -------------------------------------------------------------------------------- /tools/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CFG="configs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke.py" 3 | WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/ 4 | python3 tools/train.py ${CFG} --work-dir ${WORK_DIR} --gpus 1 5 | 6 | --------------------------------------------------------------------------------
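Usage note (added for orientation; the commands below are a sketch, not part of the repository files above): tools/train.sh is the intended single-GPU entry point. Assuming the pinned packages in requirements.txt are installed, the MSDeformAttn extension under mctrans/models/ops is built, and the package itself is installed (for example with pip install -e . via setup.py), training should launch roughly as follows:

    # single GPU, mirroring tools/train.sh
    python3 tools/train.py configs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke.py \
        --work-dir work_dirs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke/ --gpus 1

    # multi-GPU (an assumption, not documented in the repository): tools/train.py accepts
    # --launcher pytorch and reads LOCAL_RANK, and SegPipline wraps the model in
    # MMDistributedDataParallel, so the usual mmcv-style launch is expected to work, e.g.:
    python3 -m torch.distributed.launch --nproc_per_node=4 tools/train.py \
        configs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke.py --launcher pytorch

Checkpoints, the dumped config, and the timestamped log file are written to the chosen --work-dir, as set up in main() of tools/train.py.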