├── LICENSE
├── README.md
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   └── pannuke.py
│   │   ├── default_runtime.py
│   │   ├── models
│   │   │   ├── attunet_res34_d5.py
│   │   │   ├── attunet_vgg32_d5.py
│   │   │   ├── cenet_res34_d5.py
│   │   │   ├── cenet_vgg32_d5.py
│   │   │   ├── mctrans_res34_d5.py
│   │   │   ├── mctrans_vgg32_d5.py
│   │   │   ├── nonlocal_res34_d5.py
│   │   │   ├── nonlocal_vgg32_d5.py
│   │   │   ├── transunet_res34_d5.py
│   │   │   ├── transunet_vgg32_d5.py
│   │   │   ├── unet++_res34_d5.py
│   │   │   ├── unet++_vgg32_d5.py
│   │   │   ├── unet_res34_d5.py
│   │   │   └── unet_vgg32_d5.py
│   │   └── schedules
│   │       └── pannuke_bs32_ep400.py
│   └── pannuke-vgg32
│       ├── attunet_vgg32_d5_256x256_400ep_pannuke.py
│       ├── cenet_vgg32_d5_256x256_400ep_pannuke.py
│       ├── mctrans_vgg32_d5_256x256_400ep_pannuke.py
│       ├── nonlocal_vgg32_d5_256x256_400ep_pannuke.py
│       ├── transunet_vgg32_d5_256x256_400ep_pannuke.py
│       ├── unet++_vgg32_d5_256x256_400ep_pannuke.py
│       └── unet_vgg32_d5_256x256_400ep_pannuke.py
├── docs
│   └── guidance.md
├── imgs
│   ├── logo.png
│   └── overview.png
├── mctrans
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── pannuke.py
│   │   │   └── polyp.py
│   │   └── transforms
│   │       ├── __init__.py
│   │       ├── base.py
│   │       ├── monai.py
│   │       └── utils.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── builder.py
│   │   └── hausdorff_distance.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── centers
│   │   │   ├── __init__.py
│   │   │   ├── cenet.py
│   │   │   ├── mctrans.py
│   │   │   ├── non_local.py
│   │   │   └── vit.py
│   │   ├── decoders
│   │   │   ├── __init__.py
│   │   │   ├── unet_decoder.py
│   │   │   └── unet_plus_plus_decoder.py
│   │   ├── encoders
│   │   │   ├── __init__.py
│   │   │   ├── resnet.py
│   │   │   └── vgg.py
│   │   ├── heads
│   │   │   ├── __init__.py
│   │   │   ├── basic_seg_head.py
│   │   │   └── mctrans_aux_head.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── cross_entropy_loss.py
│   │   │   ├── debug_focal.py
│   │   │   └── monai.py
│   │   ├── ops
│   │   │   ├── __init__.py
│   │   │   ├── functions
│   │   │   │   ├── __init__.py
│   │   │   │   └── ms_deform_attn_func.py
│   │   │   ├── make.sh
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   └── ms_deform_attn.py
│   │   │   ├── setup.py
│   │   │   ├── src
│   │   │   │   ├── cpu
│   │   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   │   └── ms_deform_attn_cpu.h
│   │   │   │   ├── cuda
│   │   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   │   ├── ms_deform_attn.h
│   │   │   │   └── vision.cpp
│   │   │   └── test.py
│   │   ├── segmentors
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   └── encoder_decoder.py
│   │   ├── trans
│   │   │   ├── __init__.py
│   │   │   ├── transformer.py
│   │   │   └── utils.py
│   │   └── utils.py
│   ├── pipline
│   │   ├── __init__.py
│   │   └── segpipline.py
│   └── utils
│       ├── __init__.py
│       ├── logger.py
│       └── misc.py
├── requirements.txt
├── setup.py
└── tools
    ├── train.py
    └── train.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 SenseTime. All Rights Reserved.
2 |
3 | Apache License
4 | Version 2.0, January 2004
5 | http://www.apache.org/licenses/
6 |
7 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8 |
9 | 1. Definitions.
10 |
11 | "License" shall mean the terms and conditions for use, reproduction,
12 | and distribution as defined by Sections 1 through 9 of this document.
13 |
14 | "Licensor" shall mean the copyright owner or entity authorized by
15 | the copyright owner that is granting the License.
16 |
17 | "Legal Entity" shall mean the union of the acting entity and all
18 | other entities that control, are controlled by, or are under common
19 | control with that entity. For the purposes of this definition,
20 | "control" means (i) the power, direct or indirect, to cause the
21 | direction or management of such entity, whether by contract or
22 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
23 | outstanding shares, or (iii) beneficial ownership of such entity.
24 |
25 | "You" (or "Your") shall mean an individual or Legal Entity
26 | exercising permissions granted by this License.
27 |
28 | "Source" form shall mean the preferred form for making modifications,
29 | including but not limited to software source code, documentation
30 | source, and configuration files.
31 |
32 | "Object" form shall mean any form resulting from mechanical
33 | transformation or translation of a Source form, including but
34 | not limited to compiled object code, generated documentation,
35 | and conversions to other media types.
36 |
37 | "Work" shall mean the work of authorship, whether in Source or
38 | Object form, made available under the License, as indicated by a
39 | copyright notice that is included in or attached to the work
40 | (an example is provided in the Appendix below).
41 |
42 | "Derivative Works" shall mean any work, whether in Source or Object
43 | form, that is based on (or derived from) the Work and for which the
44 | editorial revisions, annotations, elaborations, or other modifications
45 | represent, as a whole, an original work of authorship. For the purposes
46 | of this License, Derivative Works shall not include works that remain
47 | separable from, or merely link (or bind by name) to the interfaces of,
48 | the Work and Derivative Works thereof.
49 |
50 | "Contribution" shall mean any work of authorship, including
51 | the original version of the Work and any modifications or additions
52 | to that Work or Derivative Works thereof, that is intentionally
53 | submitted to Licensor for inclusion in the Work by the copyright owner
54 | or by an individual or Legal Entity authorized to submit on behalf of
55 | the copyright owner. For the purposes of this definition, "submitted"
56 | means any form of electronic, verbal, or written communication sent
57 | to the Licensor or its representatives, including but not limited to
58 | communication on electronic mailing lists, source code control systems,
59 | and issue tracking systems that are managed by, or on behalf of, the
60 | Licensor for the purpose of discussing and improving the Work, but
61 | excluding communication that is conspicuously marked or otherwise
62 | designated in writing by the copyright owner as "Not a Contribution."
63 |
64 | "Contributor" shall mean Licensor and any individual or Legal Entity
65 | on behalf of whom a Contribution has been received by Licensor and
66 | subsequently incorporated within the Work.
67 |
68 | 2. Grant of Copyright License. Subject to the terms and conditions of
69 | this License, each Contributor hereby grants to You a perpetual,
70 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71 | copyright license to reproduce, prepare Derivative Works of,
72 | publicly display, publicly perform, sublicense, and distribute the
73 | Work and such Derivative Works in Source or Object form.
74 |
75 | 3. Grant of Patent License. Subject to the terms and conditions of
76 | this License, each Contributor hereby grants to You a perpetual,
77 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78 | (except as stated in this section) patent license to make, have made,
79 | use, offer to sell, sell, import, and otherwise transfer the Work,
80 | where such license applies only to those patent claims licensable
81 | by such Contributor that are necessarily infringed by their
82 | Contribution(s) alone or by combination of their Contribution(s)
83 | with the Work to which such Contribution(s) was submitted. If You
84 | institute patent litigation against any entity (including a
85 | cross-claim or counterclaim in a lawsuit) alleging that the Work
86 | or a Contribution incorporated within the Work constitutes direct
87 | or contributory patent infringement, then any patent licenses
88 | granted to You under this License for that Work shall terminate
89 | as of the date such litigation is filed.
90 |
91 | 4. Redistribution. You may reproduce and distribute copies of the
92 | Work or Derivative Works thereof in any medium, with or without
93 | modifications, and in Source or Object form, provided that You
94 | meet the following conditions:
95 |
96 | (a) You must give any other recipients of the Work or
97 | Derivative Works a copy of this License; and
98 |
99 | (b) You must cause any modified files to carry prominent notices
100 | stating that You changed the files; and
101 |
102 | (c) You must retain, in the Source form of any Derivative Works
103 | that You distribute, all copyright, patent, trademark, and
104 | attribution notices from the Source form of the Work,
105 | excluding those notices that do not pertain to any part of
106 | the Derivative Works; and
107 |
108 | (d) If the Work includes a "NOTICE" text file as part of its
109 | distribution, then any Derivative Works that You distribute must
110 | include a readable copy of the attribution notices contained
111 | within such NOTICE file, excluding those notices that do not
112 | pertain to any part of the Derivative Works, in at least one
113 | of the following places: within a NOTICE text file distributed
114 | as part of the Derivative Works; within the Source form or
115 | documentation, if provided along with the Derivative Works; or,
116 | within a display generated by the Derivative Works, if and
117 | wherever such third-party notices normally appear. The contents
118 | of the NOTICE file are for informational purposes only and
119 | do not modify the License. You may add Your own attribution
120 | notices within Derivative Works that You distribute, alongside
121 | or as an addendum to the NOTICE text from the Work, provided
122 | that such additional attribution notices cannot be construed
123 | as modifying the License.
124 |
125 | You may add Your own copyright statement to Your modifications and
126 | may provide additional or different license terms and conditions
127 | for use, reproduction, or distribution of Your modifications, or
128 | for any such Derivative Works as a whole, provided Your use,
129 | reproduction, and distribution of the Work otherwise complies with
130 | the conditions stated in this License.
131 |
132 | 5. Submission of Contributions. Unless You explicitly state otherwise,
133 | any Contribution intentionally submitted for inclusion in the Work
134 | by You to the Licensor shall be under the terms and conditions of
135 | this License, without any additional terms or conditions.
136 | Notwithstanding the above, nothing herein shall supersede or modify
137 | the terms of any separate license agreement you may have executed
138 | with Licensor regarding such Contributions.
139 |
140 | 6. Trademarks. This License does not grant permission to use the trade
141 | names, trademarks, service marks, or product names of the Licensor,
142 | except as required for reasonable and customary use in describing the
143 | origin of the Work and reproducing the content of the NOTICE file.
144 |
145 | 7. Disclaimer of Warranty. Unless required by applicable law or
146 | agreed to in writing, Licensor provides the Work (and each
147 | Contributor provides its Contributions) on an "AS IS" BASIS,
148 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149 | implied, including, without limitation, any warranties or conditions
150 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151 | PARTICULAR PURPOSE. You are solely responsible for determining the
152 | appropriateness of using or redistributing the Work and assume any
153 | risks associated with Your exercise of permissions under this License.
154 |
155 | 8. Limitation of Liability. In no event and under no legal theory,
156 | whether in tort (including negligence), contract, or otherwise,
157 | unless required by applicable law (such as deliberate and grossly
158 | negligent acts) or agreed to in writing, shall any Contributor be
159 | liable to You for damages, including any direct, indirect, special,
160 | incidental, or consequential damages of any character arising as a
161 | result of this License or out of the use or inability to use the
162 | Work (including but not limited to damages for loss of goodwill,
163 | work stoppage, computer failure or malfunction, or any and all
164 | other commercial damages or losses), even if such Contributor
165 | has been advised of the possibility of such damages.
166 |
167 | 9. Accepting Warranty or Additional Liability. While redistributing
168 | the Work or Derivative Works thereof, You may choose to offer,
169 | and charge a fee for, acceptance of support, warranty, indemnity,
170 | or other liability obligations and/or rights consistent with this
171 | License. However, in accepting such obligations, You may act only
172 | on Your own behalf and on Your sole responsibility, not on behalf
173 | of any other Contributor, and only if You agree to indemnify,
174 | defend, and hold each Contributor harmless for any liability
175 | incurred by, or claims asserted against, such Contributor by reason
176 | of your accepting any such warranty or additional liability.
177 |
178 | END OF TERMS AND CONDITIONS
179 |
180 | APPENDIX: How to apply the Apache License to your work.
181 |
182 | To apply the Apache License to your work, attach the following
183 | boilerplate notice, with the fields enclosed by brackets "[]"
184 | replaced with your own identifying information. (Don't include
185 | the brackets!) The text should be enclosed in the appropriate
186 | comment syntax for the file format. We also recommend that a
187 | file or class name and description of purpose be included on the
188 | same "printed page" as the copyright notice for easier
189 | identification within third-party archives.
190 |
191 | Copyright 2020 SenseTime
192 |
193 | Licensed under the Apache License, Version 2.0 (the "License");
194 | you may not use this file except in compliance with the License.
195 | You may obtain a copy of the License at
196 |
197 | http://www.apache.org/licenses/LICENSE-2.0
198 |
199 | Unless required by applicable law or agreed to in writing, software
200 | distributed under the License is distributed on an "AS IS" BASIS,
201 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202 | See the License for the specific language governing permissions and
203 | limitations under the License.
204 |
205 |
206 | DETR
207 |
208 | Copyright 2020 - present, Facebook, Inc
209 |
210 | Licensed under the Apache License, Version 2.0 (the "License");
211 | you may not use this file except in compliance with the License.
212 | You may obtain a copy of the License at
213 |
214 | http://www.apache.org/licenses/LICENSE-2.0
215 |
216 | Unless required by applicable law or agreed to in writing, software
217 | distributed under the License is distributed on an "AS IS" BASIS,
218 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
219 | See the License for the specific language governing permissions and
220 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |

3 |
4 |
5 |
6 | ## News
 7 | The code of MCTrans has been released. If you are interested in contributing to the standardization of the medical image analysis community, please feel free to contact me.
8 |
9 | ## Introduction
10 |
11 | - This repository provides code for "**Multi-Compound Transformer for Accurate Biomedical Image Segmentation**" [[paper](https://arxiv.org/pdf/2106.14385.pdf)].
12 |
13 | - The MCTrans repository builds heavily on [MMSegmentation](https://github.com/open-mmlab/mmsegmentation), [MMCV](https://github.com/open-mmlab/mmcv), and [MONAI](https://monai.io/). We thank them for their contributions.
14 |
15 |
16 |
17 | ## Highlights
18 |
19 | - A comprehensive toolbox for medical image segmentation, including flexible data loading and processing, modular network construction, and more.
20 |
21 | - Supports representative and popular medical image segmentation methods, e.g., UNet, UNet++, CENet, and AttentionUNet.
22 |
23 |
24 |
25 | ## Changelog
26 | The first version was released on 2021.7.16.
27 |
28 | ## Model Zoo
29 |
30 | Supported backbones:
31 |
32 | - [x] VGG
33 | - [x] ResNet
34 |
35 | Supported methods:
36 |
37 | - [x] UNet
38 | - [x] UNet++
39 | - [x] AttentionUNet
40 | - [x] CENet
41 | - [x] TransUNet
42 | - [x] NonLocalUNet
43 |
44 | ## Installation and Usage
45 |
46 | Please see the [guidance.md](docs/guidance.md).
47 |
48 |
49 |
50 | ## Citation
51 |
52 | If you find this project useful in your research, please consider citing:
53 |
54 | ```latex
55 | @article{ji2021multi,
56 | title={Multi-Compound Transformer for Accurate Biomedical Image Segmentation},
57 | author={Ji, Yuanfeng and Zhang, Ruimao and Wang, Huijie and Li, Zhen and Wu, Lingyun and Zhang, Shaoting and Luo, Ping},
58 | journal={arXiv preprint arXiv:2106.14385},
59 | year={2021}
60 | }
61 | ```
62 |
63 |
64 |
65 | ## Contribution
66 |
67 | I don't have a lot of time to improve the code base at this stage, so if you have some free time and are interested in contributing to the standardization of the medical image analysis community, please feel free to contact me (jyuanfeng8@gmail.com).
68 |
69 |
70 |
71 | ## License
72 |
73 | This project is released under the [Apache 2.0 license](LICENSE).
74 |
75 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/pannuke.py:
--------------------------------------------------------------------------------
1 | dataset_type = 'PanNukeDataset'
2 | patch_size = [256, 256]
3 |
4 | keyword = ["img", "seg_label"]
5 |
6 | train_transforms = [
7 | dict(type="LoadImage",
8 | keys=keyword,
9 | meta_key_postfix="metas"),
10 | dict(type="AsChannelFirst",
11 | keys=keyword[0]),
12 | dict(type="AddChannel",
13 | keys=keyword[1]),
14 | dict(type='Resize',
15 | keys=keyword,
16 | spatial_size=patch_size,
17 | mode=("bilinear", "nearest")),
18 | dict(type="ScaleIntensity",
19 | keys=keyword[0]),
20 | dict(type='RandMirror'),
21 | dict(type='ToTensor',
22 | keys=keyword)
23 | ]
24 |
25 | test_transforms = [
26 | dict(type="LoadImage",
27 | keys=keyword,
28 | meta_key_postfix="metas"),
29 | dict(type="AsChannelFirst",
30 | keys=keyword[0]),
31 | dict(type="AddChannel",
32 | keys=keyword[1]),
33 | dict(type='Resize',
34 | keys=keyword,
35 | spatial_size=patch_size,
36 | mode=("bilinear", "nearest")),
37 | dict(type="ScaleIntensity",
38 | keys=keyword[0]),
39 | dict(type='ToTensor',
40 | keys=keyword)
41 | ]
42 | data = dict(
43 | samples_per_gpu=32,
44 | workers_per_gpu=8,
45 | train=dict(
46 | type=dataset_type,
47 | img_dirs=["data/medical/pannuke/split-images-npy/0",
48 | "data/medical/pannuke/split-images-npy/1"],
49 | img_suffix=".npy",
50 | label_dirs=["data/medical/pannuke/split-masks-npy/0",
51 | "data/medical/pannuke/split-masks-npy/1"],
52 | label_suffix=".npy",
53 | phase="train",
54 | transforms=train_transforms,
55 | ),
56 | val=dict(
57 | type=dataset_type,
58 | img_dirs=["data/medical/pannuke/split-images-npy/2"],
59 | img_suffix=".npy",
60 | label_dirs=["data/medical/pannuke/split-masks-npy/2"],
61 | label_suffix=".npy",
62 | phase="val",
63 | transforms=test_transforms
64 | ),
65 | )
66 | # Alternative: the same dataset also supports data in other formats (e.g., .png images and masks):
67 | #
68 | # data = dict(
69 | # samples_per_gpu=32,
70 | # workers_per_gpu=8,
71 | # train=dict(
72 | # type=dataset_type,
73 | # img_dirs=["data/medical/pannuke/split-images/0",
74 | # "data/medical/pannuke/split-images/1"],
75 | # img_suffix=".png",
76 | # label_dirs=["data/medical/pannuke/split-masks/0",
77 | # "data/medical/pannuke/split-masks/1"],
78 | # label_suffix=".png",
79 | # phase="train",
80 | # transforms=train_transforms,
81 | # ),
82 | # val=dict(
83 | # type=dataset_type,
84 | # img_dirs=["data/medical/pannuke/split-images/2"],
85 | # img_suffix=".png",
86 | # label_dirs=["data/medical/pannuke/split-masks/2"],
87 | # label_suffix=".png",
88 | # phase="val",
89 | # transforms=test_transforms
90 | # ),
91 | # )
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | # yapf:disable
2 | log_config = dict(
3 | interval=10,
4 | hooks=[
5 | dict(type='TextLoggerHook', by_epoch=True),
6 | dict(type='TensorboardLoggerHook')
7 | ])
8 | # yapf:enable
9 | dist_params = dict(backend='nccl')
10 | log_level = 'INFO'
11 | load_from = None
12 | resume_from = None
13 | workflow = [('train', 1)]
14 | cudnn_benchmark = True
--------------------------------------------------------------------------------
/configs/_base_/models/attunet_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3),
8 | decoder=dict(
9 | type="UNetDecoder",
10 | in_channels=[64, 64, 128, 256, 512],
11 | att=True
12 | ),
13 | seg_head=dict(
14 | type="BasicSegHead",
15 | in_channels=64,
16 | num_classes=6,
17 | post_trans=[dict(type="Activations", softmax=True),
18 | dict(type="AsDiscrete", argmax=True)],
19 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
20 | dict(type="FocalLoss", to_onehot_y=True)])
21 | )
22 |
--------------------------------------------------------------------------------
/configs/_base_/models/attunet_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | decoder=dict(
10 | type="UNetDecoder",
11 | in_channels=[32, 64, 128, 256, 512],
12 | att=True
13 | ),
14 | seg_head=dict(
15 | type="BasicSegHead",
16 | in_channels=32,
17 | num_classes=6,
18 | post_trans=[dict(type="Activations", softmax=True),
19 | dict(type="AsDiscrete", argmax=True)],
20 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
21 | dict(type="CrossEntropyLoss"),
22 | dict(type="FLoss")]
23 | )
24 | )
25 |
--------------------------------------------------------------------------------
/configs/_base_/models/cenet_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3),
8 | center=dict(
9 | type="CEncoder",
10 | in_channels=[64, 64, 128, 256, 512]),
11 | decoder=dict(
12 | type="UNetDecoder",
13 | in_channels=[64, 64, 128, 256, 512 + 4]),
14 | seg_head=dict(
15 | type="BasicSegHead",
16 | in_channels=64,
17 | num_classes=6,
18 | post_trans=[dict(type="Activations", softmax=True),
19 | dict(type="AsDiscrete", argmax=True)],
20 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
21 | dict(type="FocalLoss", to_onehot_y=True)])
22 | )
23 |
--------------------------------------------------------------------------------
/configs/_base_/models/cenet_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | center=dict(
10 | type="CEncoder",
11 | in_channels=[32, 64, 128, 256, 512]),
12 | decoder=dict(
13 | type="UNetDecoder",
14 | in_channels=[32, 64, 128, 256, 512+4],
15 | att=True),
16 | seg_head=dict(
17 | type="BasicSegHead",
18 | in_channels=32,
19 | num_classes=6,
20 | post_trans=[dict(type="Activations", softmax=True),
21 | dict(type="AsDiscrete", argmax=True)],
22 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
23 | dict(type="CrossEntropyLoss"),
24 | dict(type="FLoss")]
25 | )
26 | )
27 |
--------------------------------------------------------------------------------
/configs/_base_/models/mctrans_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3),
8 | center=dict(
9 | type="MCTrans",
10 | d_model=128,
11 | nhead=8,
12 | d_ffn=512,
13 | dropout=0.1,
14 | act="relu",
15 | n_levels=3,
16 | n_points=4,
17 | n_sa_layers=6),
18 | decoder=dict(
19 | type="UNetDecoder",
20 | in_channels=[64, 64, 128, 128, 128],
21 | ),
22 | seg_head=dict(
23 | type="BasicSegHead",
24 | in_channels=64,
25 | num_classes=6,
26 | post_trans=[dict(type="Activations", softmax=True),
27 | dict(type="AsDiscrete", argmax=True)],
28 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
29 | dict(type="FocalLoss", to_onehot_y=True)]),
30 | aux_head=dict(
31 | type="MCTransAuxHead",
32 | d_model=128,
33 | d_ffn=512,
34 | act="relu",
35 | num_classes=6,
36 | in_channles=[64, 64, 128, 128, 128],
37 | losses=[dict(type="MCTransAuxLoss", sigmoid=True, loss_weight=0.1)]),
38 | )
39 |
--------------------------------------------------------------------------------
/configs/_base_/models/mctrans_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | center=dict(
10 | type="MCTrans",
11 | d_model=128,
12 | nhead=8,
13 | d_ffn=512,
14 | dropout=0.1,
15 | act="relu",
16 | n_levels=3,
17 | n_points=4,
18 | n_sa_layers=6),
19 | decoder=dict(
20 | type="UNetDecoder",
21 | in_channels=[32, 64, 128, 128, 128],
22 | ),
23 | seg_head=dict(
24 | type="BasicSegHead",
25 | in_channels=32,
26 | num_classes=6,
27 | post_trans=[dict(type="Activations", softmax=True),
28 | dict(type="AsDiscrete", argmax=True)],
29 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
30 | dict(type="CrossEntropyLoss"),
31 | dict(type="FLoss")]),
32 | aux_head=dict(
33 | type="MCTransAuxHead",
34 | d_model=128,
35 | d_ffn=512,
36 | act="relu",
37 | num_classes=6,
38 | in_channles=[32, 64, 128, 128, 128],
39 | losses=[dict(type="MCTransAuxLoss", sigmoid=True, loss_weight=0.1)]),
40 | )
41 |
--------------------------------------------------------------------------------
/configs/_base_/models/nonlocal_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3),
8 | center=dict(
9 | type="NonLocal",
10 | in_channels=[64, 64, 128, 256, 512]),
11 | decoder=dict(
12 | type="UNetDecoder",
13 | in_channels=[64, 64, 128, 256, 512]),
14 | seg_head=dict(
15 | type="BasicSegHead",
16 | in_channels=64,
17 | num_classes=6,
18 | post_trans=[dict(type="Activations", softmax=True),
19 | dict(type="AsDiscrete", argmax=True)],
20 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
21 | dict(type="FocalLoss", to_onehot_y=True)])
22 | )
23 |
--------------------------------------------------------------------------------
/configs/_base_/models/nonlocal_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | center=dict(
10 | type="NonLocal",
11 | in_channels=[32, 64, 128, 256, 512]),
12 | decoder=dict(
13 | type="UNetDecoder",
14 | in_channels=[32, 64, 128, 256, 512]),
15 | seg_head=dict(
16 | type="BasicSegHead",
17 | in_channels=32,
18 | num_classes=6,
19 | post_trans=[dict(type="Activations", softmax=True),
20 | dict(type="AsDiscrete", argmax=True)],
21 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
22 | dict(type="CrossEntropyLoss"),
23 | dict(type="FLoss")])
24 | )
25 |
--------------------------------------------------------------------------------
/configs/_base_/models/transunet_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3),
8 | center=dict(
9 | type="Vit",
10 | input_size=(512, 16, 16)),
11 | decoder=dict(
12 | type="UNetDecoder",
13 | in_channels=[64, 64, 128, 256, 512]),
14 | seg_head=dict(
15 | type="BasicSegHead",
16 | in_channels=64,
17 | num_classes=6,
18 | post_trans=[dict(type="Activations", softmax=True),
19 | dict(type="AsDiscrete", argmax=True)],
20 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
21 | dict(type="FocalLoss", to_onehot_y=True)])
22 | )
23 |
--------------------------------------------------------------------------------
/configs/_base_/models/transunet_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | center=dict(
10 | type="Vit",
11 | input_size=(512, 16, 16)),
12 | decoder=dict(
13 | type="UNetDecoder",
14 | in_channels=[32, 64, 128, 256, 512]),
15 | seg_head=dict(
16 | type="BasicSegHead",
17 | in_channels=32,
18 | num_classes=6,
19 | post_trans=[dict(type="Activations", softmax=True),
20 | dict(type="AsDiscrete", argmax=True)],
21 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
22 | dict(type="CrossEntropyLoss"),
23 | dict(type="FLoss")])
24 | )
25 |
26 |
--------------------------------------------------------------------------------
/configs/_base_/models/unet++_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3,
8 | out_indices=(0, 1, 2, 3, 4)),
9 | decoder=dict(
10 | type="UNetPlusPlusDecoder",
11 | in_channels=[64, 64, 128, 256, 512],
12 | ),
13 | seg_head=dict(
14 | type="BasicSegHead",
15 | in_channels=64,
16 | num_classes=6,
17 | post_trans=[dict(type="Activations", softmax=True),
18 | dict(type="AsDiscrete", argmax=True)],
19 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
20 | dict(type="FocalLoss", to_onehot_y=True)])
21 | )
22 |
--------------------------------------------------------------------------------
/configs/_base_/models/unet++_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | decoder=dict(
10 | type="UNetPlusPlusDecoder",
11 | in_channels=[32, 64, 128, 256, 512],
12 | ),
13 | seg_head=dict(
14 | type="BasicSegHead",
15 | in_channels=32,
16 | num_classes=6,
17 | post_trans=[dict(type="Activations", softmax=True),
18 | dict(type="AsDiscrete", argmax=True)],
19 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
20 | dict(type="CrossEntropyLoss"),
21 | dict(type="FLoss")])
22 | )
23 |
--------------------------------------------------------------------------------
/configs/_base_/models/unet_res34_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="ResNet",
6 | depth=34,
7 | in_channels=3),
8 | decoder=dict(
9 | type="UNetDecoder",
10 | in_channels=[64, 64, 128, 256, 512],
11 | ),
12 | seg_head=dict(
13 | type="BasicSegHead",
14 | in_channels=64,
15 | num_classes=6,
16 | post_trans=[dict(type="Activations", softmax=True),
17 | dict(type="AsDiscrete", argmax=True)],
18 | losses=[dict(type="DiceCELoss", softmax=True, to_onehot_y=True),
19 | dict(type="FocalLoss", to_onehot_y=True)])
20 | )
21 |
--------------------------------------------------------------------------------
/configs/_base_/models/unet_vgg32_d5.py:
--------------------------------------------------------------------------------
1 | model = dict(
2 | type='EncoderDecoder',
3 | pretrained=None,
4 | encoder=dict(
5 | type="VGG",
6 | in_channel=3,
7 | init_channels=32,
8 | num_blocks=2),
9 | decoder=dict(
10 | type="UNetDecoder",
11 | in_channels=[32, 64, 128, 256, 512],
12 | ),
13 | seg_head=dict(
14 | type="BasicSegHead",
15 | in_channels=32,
16 | num_classes=6,
17 | post_trans=[dict(type="Activations", softmax=True),
18 | dict(type="AsDiscrete", argmax=True)],
19 | losses=[dict(type="DiceLoss", softmax=True, to_onehot_y=True),
20 | dict(type="CrossEntropyLoss"),
21 | dict(type="FLoss")]
22 | )
23 | )
24 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/pannuke_bs32_ep400.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='Adam', lr=0.0003, betas=[0.9, 0.99])
3 | optimizer_config = dict()
4 | # learning policy
5 | lr_config = dict(policy='CosineAnnealing', min_lr=0.0)
6 | # runtime settings
7 | max_epochs = 400
8 | runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
9 | checkpoint_config = dict(by_epoch=True, interval=5)
10 | evaluation = dict(interval=1, save_best="mDice")
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/attunet_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/attunet_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/cenet_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/cenet_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/mctrans_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/nonlocal_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/nonlocal_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/transunet_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/transunet_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/unet++_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/unet++_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
/configs/pannuke-vgg32/unet_vgg32_d5_256x256_400ep_pannuke.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/unet_vgg32_d5.py', '../_base_/datasets/pannuke.py',
3 | '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
4 | ]
5 |
--------------------------------------------------------------------------------
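Each experiment config under `configs/pannuke-vgg32/` only composes the base model, dataset, runtime, and schedule files via MMCV's `_base_` inheritance; any top-level key declared in the derived file is merged over the inherited values by `Config.fromfile`. As a minimal sketch (the file name and the overridden values below are hypothetical, not files from this repo), a custom variant could look like:

```python
# configs/pannuke-vgg32/my_mctrans_vgg32_d5_256x256_pannuke.py  (hypothetical example)
_base_ = [
    '../_base_/models/mctrans_vgg32_d5.py', '../_base_/datasets/pannuke.py',
    '../_base_/default_runtime.py', '../_base_/schedules/pannuke_bs32_ep400.py'
]

# Keys set here are merged over the inherited config; nested dicts are merged key by key,
# so only the listed fields change while everything else keeps its base value.
optimizer = dict(lr=1e-4)            # lower the learning rate from 0.0003
data = dict(samples_per_gpu=16)      # halve the per-GPU batch size
```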
/docs/guidance.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | a. Create a conda virtual environment and install required packages.
4 |
5 | ```shell
6 | conda create -n mctrans pip python=3.7
7 | conda activate mctrans
8 | git clone https://github.com/JiYuanFeng/MCTrans.git
9 | cd MCTrans
10 | python setup.py develop
11 | pip install -r requirements.txt
12 | ```
13 |
14 | b. Compile the additional CUDA operators, such as [MultiScaleDeformableAttention](https://github.com/fundamentalvision/Deformable-DETR).
15 |
16 | ```shell
17 | cd mctrans/models/ops/
18 | bash make.sh
19 | ```
20 |
21 | c. Create a data folder under the MCTrans root and link the actual dataset path ($DATA_ROOT) into it.
22 |
23 | ```shell
24 | mkdir data
25 | ln -s $DATA_ROOT data
26 | ```
27 |
28 |
29 |
30 | ## Datasets Preparation
31 |
32 | - It is recommended that you convert your dataset (especially the labels) to a standard format. For example, a binary segmentation label should contain only `0,1` or `0,255` (see the conversion sketch below).
33 |
34 | - If your folder structure is different, you may need to change the corresponding paths in config files.
35 |
36 | - We have uploaded some preprocessed datasets to [drive](https://drive.google.com/file/d/1mcD7Grx2bUQhAL9ClTrCtKv6FyX03Ehd/view?usp=sharing); you can download and unpack them under the data folder.
37 |
38 | ```none
39 | MCTrans
40 | ├── mctrans
41 | ├── data
42 | │ ├── pannuke
43 | │ │ ├── split-images
44 | │ │ ├── split-masks
45 | │ │ ├── split-images-npy
46 | │ │ ├── split-masks-npy
47 | │ ├── cvc-clinic
48 | │ │ ├── images
49 | │ │ ├── masks
50 | │ ├── cvc-colondb
51 | │ │ ├── images
52 | │ │ ├── masks
53 | │ ├── kvasir
54 | │ │ ├── images
55 | │ │ ├── masks
56 | ```
57 |
58 | ## Single GPU Training
59 | ```shell
60 | bash tools/train.sh
61 | ```
62 |
63 | ## Multi GPU Training
64 |
65 | ```none
66 | TO DO
67 | ```
68 |
69 |
--------------------------------------------------------------------------------
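As noted in the dataset-preparation section of `docs/guidance.md` above, labels are expected in a standard form, e.g. a binary mask containing only `0,1` or `0,255`. Below is a minimal, hypothetical conversion sketch (the folder names and the `0/1` convention are assumptions, not part of the repo); for the PanNuke `*-npy` layout, `np.save` can be used in place of `Image.save`.

```python
import os
import os.path as osp

import numpy as np
from PIL import Image

src_dir = "raw_masks"           # hypothetical folder with arbitrary-valued masks
dst_dir = "data/mydata/masks"   # hypothetical output folder referenced by a config
os.makedirs(dst_dir, exist_ok=True)

for name in os.listdir(src_dir):
    if not name.endswith(".png"):
        continue
    mask = np.array(Image.open(osp.join(src_dir, name)))
    binary = (mask > 0).astype(np.uint8)   # collapse every foreground value to 1
    Image.fromarray(binary).save(osp.join(dst_dir, name))
```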
/imgs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/imgs/logo.png
--------------------------------------------------------------------------------
/imgs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/imgs/overview.png
--------------------------------------------------------------------------------
/mctrans/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/mctrans/__init__.py
--------------------------------------------------------------------------------
/mctrans/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .builder import build_dataset, build_dataloader, DATASETS, TRANSFORMS
3 |
--------------------------------------------------------------------------------
/mctrans/data/builder.py:
--------------------------------------------------------------------------------
1 | import platform
2 | import random
3 | from collections.abc import Sequence, Mapping
4 | from functools import partial
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 | from mmcv.parallel import collate, DataContainer
10 | from mmcv.runner import get_dist_info
11 | from mmcv.utils import Registry, build_from_cfg
12 | from mmcv.utils.parrots_wrapper import DataLoader, PoolDataLoader
13 | from torch.utils.data import DistributedSampler
14 |
15 | if platform.system() != 'Windows':
16 | # https://github.com/pytorch/pytorch/issues/973
17 | import resource
18 |
19 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
20 | hard_limit = rlimit[1]
21 | soft_limit = min(4096, hard_limit)
22 |
23 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
24 |
25 | DATASETS = Registry('dataset')
26 | TRANSFORMS = Registry('transform')
27 |
28 |
29 | def collate_fn(batch, samples_per_gpu=1):
30 | """Puts each data field into a tensor/DataContainer with outer dimension
31 | batch size.
32 |
33 | Extend default_collate to add support for
34 | :type:`~mmcv.parallel.DataContainer`. There are 3 cases.
35 |
36 | 1. cpu_only = True, e.g., meta data
37 | 2. cpu_only = False, stack = True, e.g., images tensors
38 | 3. cpu_only = False, stack = False, e.g., gt bboxes
39 | """
40 | if not isinstance(batch, Sequence):
41 | raise TypeError(f'{batch.dtype} is not supported.')
42 |
43 | if isinstance(batch[0], list):
44 | batch = [item for _ in batch for item in _]
45 |
46 | if isinstance(batch[0], DataContainer):
47 | assert len(batch) % samples_per_gpu == 0
48 | stacked = []
49 | if batch[0].cpu_only:
50 | for i in range(0, len(batch), samples_per_gpu):
51 | stacked.append(
52 | [sample.data for sample in batch[i:i + samples_per_gpu]])
53 | return DataContainer(
54 | stacked, batch[0].stack, batch[0].padding_value, cpu_only=True)
55 | elif batch[0].stack:
56 | for i in range(0, len(batch), samples_per_gpu):
57 | assert isinstance(batch[i].data, torch.Tensor)
58 |
59 | if batch[i].pad_dims is not None:
60 | ndim = batch[i].dim()
61 | assert ndim > batch[i].pad_dims
62 | max_shape = [0 for _ in range(batch[i].pad_dims)]
63 | for dim in range(1, batch[i].pad_dims + 1):
64 | max_shape[dim - 1] = batch[i].size(-dim)
65 | for sample in batch[i:i + samples_per_gpu]:
66 | for dim in range(0, ndim - batch[i].pad_dims):
67 | assert batch[i].size(dim) == sample.size(dim)
68 | for dim in range(1, batch[i].pad_dims + 1):
69 | max_shape[dim - 1] = max(max_shape[dim - 1],
70 | sample.size(-dim))
71 | padded_samples = []
72 | for sample in batch[i:i + samples_per_gpu]:
73 | pad = [0 for _ in range(batch[i].pad_dims * 2)]
74 | for dim in range(1, batch[i].pad_dims + 1):
75 | pad[2 * dim -
76 | 1] = max_shape[dim - 1] - sample.size(-dim)
77 | padded_samples.append(
78 | F.pad(
79 | sample.data, pad, value=sample.padding_value))
80 | stacked.append(collate(padded_samples))
81 | elif batch[i].pad_dims is None:
82 | stacked.append(
83 | collate([
84 | sample.data
85 | for sample in batch[i:i + samples_per_gpu]
86 | ]))
87 | else:
88 | raise ValueError(
89 | 'pad_dims should be either None or integers (1-3)')
90 |
91 | else:
92 | for i in range(0, len(batch), samples_per_gpu):
93 | stacked.append(
94 | [sample.data for sample in batch[i:i + samples_per_gpu]])
95 | return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
96 | elif isinstance(batch[0], Sequence):
97 | transposed = zip(*batch)
98 | return [collate(samples, samples_per_gpu) for samples in transposed]
99 |
100 | elif isinstance(batch[0], Mapping):
101 | res = dict()
102 | for key in batch[0]:
103 | if isinstance(batch[0][key], torch.Tensor):
104 | res.update({key: collate([d[key] for d in batch], samples_per_gpu)})
105 | else:
106 | res.update({key: [d[key] for d in batch]})
107 |
108 | return res
109 | # return {
110 | # key: collate([d[key] for d in batch], samples_per_gpu)
111 | # for key in batch[0]
112 | # }
113 | else:
114 | return collate(batch)
115 |
116 |
117 | def worker_init_fn(worker_id, num_workers, rank, seed):
118 | """Worker init func for dataloader.
119 |
120 | The seed of each worker equals num_workers * rank + worker_id + user_seed.
121 |
122 | Args:
123 | worker_id (int): Worker id.
124 | num_workers (int): Number of workers.
125 | rank (int): The rank of current process.
126 | seed (int): The random seed to use.
127 | """
128 |
129 | worker_seed = num_workers * rank + worker_id + seed
130 | np.random.seed(worker_seed)
131 | random.seed(worker_seed)
132 |
133 |
134 | def build_dataset(cfg, default_args=None):
135 | """Build datasets."""
136 | dataset = build_from_cfg(cfg, DATASETS, default_args)
137 | return dataset
138 |
139 |
140 | def build_dataloader(dataset,
141 | samples_per_gpu,
142 | workers_per_gpu,
143 | num_gpus=1,
144 | dist=True,
145 | shuffle=True,
146 | seed=None,
147 | drop_last=False,
148 | pin_memory=True,
149 | dataloader_type='PoolDataLoader',
150 | **kwargs):
151 | """Build PyTorch DataLoader.
152 |
153 | In distributed training, each GPU/process has a dataloader.
154 | In non-distributed training, there is only one dataloader for all GPUs.
155 |
156 | Args:
157 | dataset (Dataset): A PyTorch dataset.
158 | samples_per_gpu (int): Number of training samples on each GPU, i.e.,
159 | batch size of each GPU.
160 | workers_per_gpu (int): How many subprocesses to use for data loading
161 | for each GPU.
162 | num_gpus (int): Number of GPUs. Only used in non-distributed training.
163 | dist (bool): Distributed training/test or not. Default: True.
164 | shuffle (bool): Whether to shuffle the data at every epoch.
165 | Default: True.
166 | seed (int | None): Seed to be used. Default: None.
167 | drop_last (bool): Whether to drop the last incomplete batch in epoch.
168 | Default: False
169 | pin_memory (bool): Whether to use pin_memory in DataLoader.
170 | Default: True
171 | dataloader_type (str): Type of dataloader. Default: 'PoolDataLoader'
172 | kwargs: any keyword argument to be used to initialize DataLoader
173 |
174 | Returns:
175 | DataLoader: A PyTorch dataloader.
176 | """
177 | rank, world_size = get_dist_info()
178 | if dist:
179 | sampler = DistributedSampler(
180 | dataset, world_size, rank, shuffle=shuffle)
181 | shuffle = False
182 | batch_size = samples_per_gpu
183 | num_workers = workers_per_gpu
184 | else:
185 | sampler = None
186 | batch_size = num_gpus * samples_per_gpu
187 | num_workers = num_gpus * workers_per_gpu
188 |
189 | init_fn = partial(
190 | worker_init_fn, num_workers=num_workers, rank=rank,
191 | seed=seed) if seed is not None else None
192 |
193 | assert dataloader_type in (
194 | 'DataLoader',
195 | 'PoolDataLoader'), f'unsupported dataloader {dataloader_type}'
196 |
197 | if dataloader_type == 'PoolDataLoader':
198 | dataloader = PoolDataLoader
199 | elif dataloader_type == 'DataLoader':
200 | dataloader = DataLoader
201 |
202 | data_loader = dataloader(
203 | dataset,
204 | batch_size=batch_size,
205 | sampler=sampler,
206 | num_workers=num_workers,
207 | collate_fn=partial(collate_fn, samples_per_gpu=samples_per_gpu),
208 | pin_memory=pin_memory,
209 | shuffle=shuffle,
210 | worker_init_fn=init_fn,
211 | drop_last=drop_last,
212 | **kwargs)
213 |
214 | return data_loader
215 |
--------------------------------------------------------------------------------
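For orientation, here is a hypothetical usage sketch of the two builders defined above; the config path is one of the files listed earlier, while the overall call pattern mirrors the MMCV/MMSegmentation convention and is an assumption rather than code taken from this repo.

```python
# Hypothetical sketch: turning the config dicts into a dataset and a dataloader.
from mmcv import Config

from mctrans.data import build_dataloader, build_dataset

cfg = Config.fromfile('configs/pannuke-vgg32/unet_vgg32_d5_256x256_400ep_pannuke.py')

# build_from_cfg looks up cfg.data.train['type'] ("PanNukeDataset") in the DATASETS registry.
train_set = build_dataset(cfg.data.train)

# Non-distributed, single-GPU loader; collate_fn above handles DataContainer batching.
train_loader = build_dataloader(
    train_set,
    samples_per_gpu=cfg.data.samples_per_gpu,
    workers_per_gpu=cfg.data.workers_per_gpu,
    num_gpus=1,
    dist=False,
    shuffle=True,
    seed=0)
```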
/mctrans/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseDataset
2 | from .polyp import LesionDataset
3 | from .pannuke import PanNukeDataset
4 |
5 | __all__ = ["BaseDataset", "LesionDataset", "PanNukeDataset"]
6 |
--------------------------------------------------------------------------------
/mctrans/data/datasets/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | from functools import reduce
4 | from collections import OrderedDict
5 |
6 | import mmcv
7 | import monai
8 | from mmcv import print_log
9 | import numpy as np
10 |
11 | from prettytable import PrettyTable
12 | from sklearn.model_selection import KFold
13 | from torch.utils.data import Dataset
14 |
15 | from ..builder import DATASETS
16 | from ..transforms import build_transforms
17 | from ...metrics.base import eval_metrics
18 | from ...utils import get_root_logger
19 |
20 |
21 | @DATASETS.register_module()
22 | class BaseDataset(Dataset):
23 | """ Custom Dataset for medical image segmentation"""
24 | CLASSES = None
25 | PALETTE = None
26 |
27 | def __init__(self,
28 | transforms,
29 | img_dirs,
30 | img_suffix=".jpg",
31 | label_dirs=None,
32 | label_suffix=".png",
33 | phase=None,
34 | cross_valid=False,
35 | fold_idx=0,
36 | fold_nums=5,
37 | data_root=None,
38 | ignore_index=None,
39 | binnary_label=False,
40 | exclude_backgroud=True,
41 | label_map=None
42 | ):
43 | self.transforms = build_transforms(transforms)
44 |
45 | self.img_dirs = img_dirs if isinstance(img_dirs, (list, tuple)) else [img_dirs]
46 | self.label_dirs = label_dirs if isinstance(label_dirs, (list, tuple)) else [label_dirs]
47 | self.img_suffix = img_suffix
48 | self.label_suffix = label_suffix
49 | self.phase = phase
50 | self.cross_valid = cross_valid
51 | self.fold_idx = fold_idx
52 | self.fold_nums = fold_nums
53 | self.data_root = data_root
54 | self.label_map = label_map
55 | self.binary_label = binnary_label
56 | self.ignore_index = ignore_index
57 | self.exclude_backgroud = exclude_backgroud
58 | self.data_list = self.generate_data_list(self.img_dirs, self.img_suffix, self.label_dirs, self.label_suffix,
59 | self.cross_valid, self.phase, self.fold_idx, self.fold_nums)
60 |
61 | def generate_data_list(self,
62 | img_dirs,
63 | img_suffix,
64 | label_dirs,
65 | label_suffix,
66 | cross_valid=False,
67 | phase="Train",
68 | fold_idx=0,
69 | fold_nums=5,
70 | img_key="img",
71 | label_key="seg_label"):
72 |
73 | if label_dirs is not None:
74 | assert len(img_dirs) == len(label_dirs)
75 |
76 | data_list = []
77 | for idx, img_dir in enumerate(img_dirs):
78 | for img in mmcv.scandir(img_dir, img_suffix, recursive=True):
79 | data_info = {}
80 | data_info[img_key] = osp.join(img_dir, img)
81 | if label_dirs is not None:
82 | label = img.replace(img_suffix, label_suffix)
83 | data_info[label_key] = osp.join(label_dirs[idx], label)
84 | data_list.append(data_info)
85 |
86 | if cross_valid:
87 | assert isinstance(fold_idx, int) and isinstance(fold_nums, int)
88 | splits = []
89 | kfold = KFold(n_splits=fold_nums, shuffle=True)
90 |
91 | for tr_idx, te_idx in kfold.split(data_list):
92 | splits.append(dict())
93 | splits[-1]['train'] = [item for idx, item in enumerate(data_list) if idx in tr_idx]
94 | splits[-1]['val'] = [item for idx, item in enumerate(data_list) if idx in te_idx]
95 | data_list = splits[fold_idx][phase]
96 |
97 | print_log("Phase {} : Loaded {} images".format(phase, len(data_list)), logger=get_root_logger())
98 |
99 | return data_list
100 |
101 | def get_gt_seg_maps(self):
102 | """Get ground truth segmentation maps for evaluation."""
103 | gt_seg_maps = []
104 | reader = monai.transforms.LoadImage(image_only=True)
105 | for img_info in self.data_list:
106 | seg_map = osp.join(img_info['seg_label'])
107 | # gt_seg_map = mmcv.imread(
108 | # seg_map, flag='unchanged', backend='pillow')
109 | gt_seg_map = reader(seg_map)
110 | # binarize the mask if needed
111 | if self.binary_label:
112 | gt_seg_map[gt_seg_map > 0] = 255
113 | # remap label ids if custom classes are used
114 | if self.label_map is not None:
115 | for old_id, new_id in self.label_map.items():
116 | gt_seg_map[gt_seg_map == old_id] = new_id
117 | gt_seg_maps.append(gt_seg_map)
118 |
119 | return gt_seg_maps
120 |
121 | def evaluate(self,
122 | results,
123 | metric=['mDice', "mIoU", "mFscore", "mHd95"],
124 | logger=None,
125 | **kwargs):
126 | """Evaluate the dataset.
127 | Args:
128 | results (list): Testing results of the dataset.
129 | metric (str | list[str]): Metrics to be evaluated. 'mIoU',
130 | 'mDice', 'mFscore' and 'mHd95' are supported.
131 | logger (logging.Logger | None | str): Logger used for printing
132 | related information during evaluation. Default: None.
133 | Returns:
134 | dict[str, float]: Default metrics.
135 | """
136 |
137 | if isinstance(metric, str):
138 | metric = [metric]
139 | allowed_metrics = ['mIoU', 'mDice', 'mFscore', "mHd95"]
140 | if not set(metric).issubset(set(allowed_metrics)):
141 | raise KeyError('metric {} is not supported'.format(metric))
142 | eval_results = {}
143 | gt_seg_maps = self.get_gt_seg_maps()
144 |
145 | if self.CLASSES is None:
146 | num_classes = len(
147 | reduce(np.union1d, [np.unique(_) for _ in gt_seg_maps]))
148 | else:
149 | num_classes = len(self.CLASSES)
150 |
151 | ret_metrics = eval_metrics(
152 | results,
153 | gt_seg_maps,
154 | num_classes,
155 | self.ignore_index,
156 | metric)
157 |
158 | if self.CLASSES is None:
159 | class_names = tuple(range(num_classes))
160 | else:
161 | class_names = self.CLASSES
162 |
163 | # each class table
164 | ret_metrics.pop('aAcc', None)
165 | ret_metrics_class = OrderedDict({
166 | ret_metric: np.round(ret_metric_value * 100, 2)
167 | for ret_metric, ret_metric_value in ret_metrics.items()
168 | })
169 | ret_metrics_class.update({'Class': class_names})
170 | ret_metrics_class.move_to_end('Class', last=False)
171 |
172 | # summary table
173 | # exclude some ignore idx
174 | if self.exclude_backgroud:
175 | for ret_metric, ret_metric_value in ret_metrics.items():
176 | ret_metrics[ret_metric] = ret_metric_value[1:]
177 |
178 | ret_metrics_summary = OrderedDict({
179 | ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)
180 | for ret_metric, ret_metric_value in ret_metrics.items()
181 | })
182 |
183 | # for logger
184 | class_table_data = PrettyTable()
185 | for key, val in ret_metrics_class.items():
186 | class_table_data.add_column(key, val)
187 |
188 | summary_table_data = PrettyTable()
189 | for key, val in ret_metrics_summary.items():
190 | if key == 'aAcc':
191 | summary_table_data.add_column(key, [val])
192 | else:
193 | summary_table_data.add_column('m' + key, [val])
194 |
195 | print_log('per class results:', logger)
196 | print_log('\n' + class_table_data.get_string(), logger=logger)
197 | print_log('Summary:', logger)
198 | print_log('\n' + summary_table_data.get_string(), logger=logger)
199 |
200 | # each metric dict
201 | for key, value in ret_metrics_summary.items():
202 | if key == 'aAcc':
203 | eval_results[key] = value / 100.0
204 | else:
205 | eval_results['m' + key] = value / 100.0
206 |
207 | ret_metrics_class.pop('Class', None)
208 | for key, value in ret_metrics_class.items():
209 | eval_results.update({
210 | key + '.' + str(name): value[idx] / 100.0
211 | for idx, name in enumerate(class_names)
212 | })
213 |
214 | if mmcv.is_list_of(results, str):
215 | for file_name in results:
216 | os.remove(file_name)
217 | return eval_results
218 |
219 | def __len__(self):
220 | """Total number of samples of data."""
221 | return len(self.data_list)
222 |
223 | def __getitem__(self, idx):
224 | data = self.data_list[idx]
225 | data = self.transforms(data)
226 | return data
227 |
--------------------------------------------------------------------------------
/mctrans/data/datasets/pannuke.py:
--------------------------------------------------------------------------------
1 | from .base import BaseDataset
2 | from ..builder import DATASETS
3 |
4 | @DATASETS.register_module()
5 | class PanNukeDataset(BaseDataset):
6 | CLASSES = ('Background', 'Neoplastic', "Inflammatory", "Connective", "Dead", "Non-Neoplastic Epithelial")
7 | def __init__(self, **kwargs):
8 | super(PanNukeDataset, self).__init__(**kwargs)
9 |
--------------------------------------------------------------------------------
/mctrans/data/datasets/polyp.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from .base import BaseDataset
3 | from ..builder import DATASETS
4 |
5 |
6 | @DATASETS.register_module()
7 | class LesionDataset(BaseDataset):
8 | CLASSES = ('background', 'lesion')
9 | PALETTE = [[120, 120, 120], [180, 120, 120]]
10 |
11 | def __init__(self, **kwargs):
12 | super(LesionDataset, self).__init__(
13 | label_map={255: 1}, **kwargs)
14 | assert all(osp.exists(d) for d in self.img_dirs) and self.phase is not None
15 |
--------------------------------------------------------------------------------
/mctrans/data/transforms/__init__.py:
--------------------------------------------------------------------------------
1 | from .monai import *
2 | from .base import BinrayLabel, build_transforms
3 |
--------------------------------------------------------------------------------
/mctrans/data/transforms/base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Mapping, Hashable, Dict
3 |
4 | from mmcv.utils import build_from_cfg
5 | from monai.config import KeysCollection
6 | from monai.transforms import apply_transform, MapTransform
7 |
8 | from ..builder import TRANSFORMS
9 |
10 |
11 | class Compose(object):
12 | """Composes several transforms together.
13 | Args:
14 | transforms (list of ``Transform`` objects): list of transforms to compose.
15 | Example:
16 | >>> transforms.Compose([
17 | >>> transforms.CenterCrop(10),
18 | >>> transforms.ToTensor(),
19 | >>> ])
20 | """
21 |
22 | def __init__(self, transforms):
23 | self.transforms = transforms
24 |
25 | def __call__(self, data):
26 | for t in self.transforms:
27 | data = apply_transform(t, data)
28 | return data
29 |
30 | def __repr__(self):
31 | format_string = self.__class__.__name__ + '('
32 | for t in self.transforms:
33 | format_string += '\n'
34 | format_string += ' {0}'.format(t)
35 | format_string += '\n)'
36 | return format_string
37 |
38 |
39 | def build_transforms(cfg):
40 |     """Build data transforms.
41 | 
42 |     Args:
43 |         cfg (dict, list[dict]): The config of transforms; it is either a dict
44 |             or a list of configs.
45 |     Returns:
46 |         Compose | object: The built transform, or a ``Compose`` of the built transforms.
47 | """
48 | if isinstance(cfg, list):
49 | transforms = [
50 | build_from_cfg(cfg_, TRANSFORMS) for cfg_ in cfg
51 | ]
52 | return Compose(transforms)
53 | else:
54 | return build_from_cfg(cfg, TRANSFORMS)
55 |
56 |
57 | @TRANSFORMS.register_module()
58 | class BinrayLabel(MapTransform):
59 | def __init__(self,
60 | keys: KeysCollection,
61 | allow_missing_keys: bool = False,
62 | ) -> None:
63 | super().__init__(keys, allow_missing_keys)
64 |
65 | def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
66 | d = dict(data)
67 | for key in self.key_iterator(d):
68 | d[key][d[key] > 0] = 1
69 | return d
70 |
--------------------------------------------------------------------------------
/mctrans/data/transforms/monai.py:
--------------------------------------------------------------------------------
1 | from monai.transforms import RandSpatialCropd, SpatialPadd, CropForegroundd, RandCropByPosNegLabeld, Transform
2 | from monai.transforms import Resized, Spacingd, Orientationd, RandRotated, RandZoomd, RandAxisFlipd, RandAffined
3 | from monai.transforms import NormalizeIntensityd, ScaleIntensityRanged, ScaleIntensityd
4 | from monai.transforms import AddChanneld, AsChannelFirstd, MapLabelValued, Lambdad, ToTensord
5 | from monai.transforms import LoadImaged, RemoveRepeatedChanneld
6 | from monai.transforms import Activations, AsDiscrete
7 |
8 | import numpy as np
9 | from ..builder import TRANSFORMS
10 |
11 | # cropped
12 | TRANSFORMS.register_module(name="RandCrop", module=RandSpatialCropd)
13 | TRANSFORMS.register_module(name="SpatialPad", module=SpatialPadd)
14 | TRANSFORMS.register_module(name="CropForeground", module=CropForegroundd)
15 | TRANSFORMS.register_module(name="RandCropByPosNegLabel", module=RandCropByPosNegLabeld)
16 | # spatial
17 | TRANSFORMS.register_module(name='Resize', module=Resized)
18 | TRANSFORMS.register_module(name="Spacing", module=Spacingd)
19 | TRANSFORMS.register_module(name="Orientation", module=Orientationd)
20 | TRANSFORMS.register_module(name="RandRotate", module=RandRotated)
21 | TRANSFORMS.register_module(name="RandZoom", module=RandZoomd)
22 | TRANSFORMS.register_module(name="RandAxisFlip", module=RandAxisFlipd)
23 | TRANSFORMS.register_module(name="RandAffine", module=RandAffined)
24 | # intensity
25 | TRANSFORMS.register_module(name='NormalizeIntensity', module=NormalizeIntensityd)
26 | TRANSFORMS.register_module(name='ScaleIntensityRange', module=ScaleIntensityRanged)
27 | TRANSFORMS.register_module(name='ScaleIntensity', module=ScaleIntensityd)
28 | # utility
29 | TRANSFORMS.register_module(name="AddChannel", module=AddChanneld)
30 | TRANSFORMS.register_module(name="AsChannelFirst", module=AsChannelFirstd)
31 | TRANSFORMS.register_module(name="MapLabelValue", module=MapLabelValued)
32 | TRANSFORMS.register_module(name="Lambda", module=Lambdad)
33 | TRANSFORMS.register_module(name="ToTensor", module=ToTensord)
34 | # io
35 | TRANSFORMS.register_module(name="LoadImage", module=LoadImaged)
36 | TRANSFORMS.register_module(name="RemoveRepeatedChannel", module=RemoveRepeatedChanneld)
37 | # post-process
38 | TRANSFORMS.register_module(name="Activations", module=Activations)
39 | TRANSFORMS.register_module(name="AsDiscrete", module=AsDiscrete)
40 |
41 |
42 | @TRANSFORMS.register_module()
43 | class RandMirror(Transform):
44 | """ Mirror the data randomly along each specified axis according to the probability.
45 |
46 | Args:
47 | axis(None or int or tuple of ints): Along which axis to flip.
48 | prob(float): Probability of flipping.
49 | """
50 |
51 | def __init__(self, axis=(0, 1, 2), prob=0.5, image_key='img', label_key='seg_label'):
52 | self.axis = axis
53 | self.prob = prob
54 | self.label_key = label_key
55 | self.image_key = image_key
56 |
57 | def __call__(self, data):
58 | data = dict(data)
59 | image = data[self.image_key]
60 | seg_label = data[self.label_key] if self.label_key else None
61 | image, seg_label = self.augment_mirroring(image, seg_label)
62 |
63 | data[self.image_key] = image
64 |
65 | if self.label_key is not None:
66 | data[self.label_key] = seg_label
67 |
68 | return data
69 |
70 | def augment_mirroring(self, image, seg_label=None):
71 | if (len(image.shape) != 3) and (len(image.shape) != 4):
72 | raise Exception(
73 | "Invalid dimension for sample_data and sample_seg. sample_data and sample_seg should be either "
74 | "[channels, x, y] or [channels, x, y, z]")
75 | if 0 in self.axis and np.random.uniform() < self.prob:
76 | image[:, :] = image[:, ::-1]
77 | if seg_label is not None:
78 | seg_label[:, :] = seg_label[:, ::-1]
79 | if 1 in self.axis and np.random.uniform() < self.prob:
80 | image[:, :, :] = image[:, :, ::-1]
81 | if seg_label is not None:
82 | seg_label[:, :, :] = seg_label[:, :, ::-1]
83 | if 2 in self.axis and len(image.shape) == 4:
84 | if np.random.uniform() < self.prob:
85 | image[:, :, :, :] = image[:, :, :, ::-1]
86 | if seg_label is not None:
87 | seg_label[:, :, :, :] = seg_label[:, :, :, ::-1]
88 | return image, seg_label
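
A minimal usage sketch combining the names registered above with build_transforms from base.py; the dict keys and file paths below are illustrative and not taken from the repo configs:

from mctrans.data.transforms import build_transforms

# each dict's 'type' selects a registered transform; the remaining keys become its kwargs
pipeline_cfg = [
    dict(type='LoadImage', keys=['img', 'seg_label']),
    dict(type='AddChannel', keys=['seg_label']),
    dict(type='RandAxisFlip', keys=['img', 'seg_label'], prob=0.5),
    dict(type='ToTensor', keys=['img', 'seg_label']),
]
pipeline = build_transforms(pipeline_cfg)   # a Compose over the four MONAI transforms
sample = pipeline({'img': 'images/0001.png', 'seg_label': 'masks/0001.png'})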
--------------------------------------------------------------------------------
/mctrans/data/transforms/utils.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | import numpy
7 |
8 |
9 | def resize(input,
10 | size=None,
11 | scale_factor=None,
12 | mode='nearest',
13 | align_corners=None,
14 | warning=True):
15 | if warning:
16 | if size is not None and align_corners:
17 | input_h, input_w = tuple(int(x) for x in input.shape[2:])
18 | output_h, output_w = tuple(int(x) for x in size)
19 |             if output_h > input_h or output_w > input_w:
20 | if ((output_h > 1 and output_w > 1 and input_h > 1
21 | and input_w > 1) and (output_h - 1) % (input_h - 1)
22 | and (output_w - 1) % (input_w - 1)):
23 | warnings.warn(
24 | f'When align_corners={align_corners}, '
25 |                         'the output would be more aligned if '
26 | f'input size {(input_h, input_w)} is `x+1` and '
27 | f'out size {(output_h, output_w)} is `nx+1`')
28 | if isinstance(size, torch.Size):
29 | size = tuple(int(x) for x in size)
30 | if isinstance(size, (numpy.ndarray, numpy.generic) ):
31 | size = tuple(int(x) for x in size)
32 | return F.interpolate(input, size, scale_factor, mode, align_corners)
--------------------------------------------------------------------------------
/mctrans/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_metrics, METRICS, Metric
--------------------------------------------------------------------------------
/mctrans/metrics/base.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import mmcv
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def f_score(precision, recall, beta=1):
9 |     """Calculate the f-score value.
10 | Args:
11 | precision (float | torch.Tensor): The precision value.
12 | recall (float | torch.Tensor): The recall value.
13 | beta (int): Determines the weight of recall in the combined score.
14 |             Default: 1.
15 | Returns:
16 | [torch.tensor]: The f-score value.
17 | """
18 | score = (1 + beta ** 2) * (precision * recall) / (
19 | (beta ** 2 * precision) + recall)
20 | return score
21 |
22 |
23 | def intersect_and_union(pred_label,
24 | label,
25 | num_classes,
26 | ignore_index,
27 | label_map=dict(),
28 | reduce_zero_label=False):
29 | """Calculate intersection and Union.
30 | Args:
31 | pred_label (ndarray | str): Prediction segmentation map
32 | or predict result filename.
33 | label (ndarray | str): Ground truth segmentation map
34 | or label filename.
35 | num_classes (int): Number of categories.
36 | ignore_index (int): Index that will be ignored in evaluation.
37 | label_map (dict): Mapping old labels to new labels. The parameter will
38 | work only when label is str. Default: dict().
39 |         reduce_zero_label (bool): Whether to ignore the zero label. The parameter will
40 | work only when label is str. Default: False.
41 | Returns:
42 | torch.Tensor: The intersection of prediction and ground truth
43 | histogram on all classes.
44 | torch.Tensor: The union of prediction and ground truth histogram on
45 | all classes.
46 | torch.Tensor: The prediction histogram on all classes.
47 | torch.Tensor: The ground truth histogram on all classes.
48 | """
49 |
50 | if isinstance(pred_label, str):
51 | pred_label = torch.from_numpy(np.load(pred_label))
52 | else:
53 |         pred_label = torch.from_numpy(pred_label)
54 |
55 | if isinstance(label, str):
56 | label = torch.from_numpy(
57 | mmcv.imread(label, flag='unchanged', backend='pillow'))
58 | else:
59 | label = torch.from_numpy(label)
60 |
61 | if label_map is not None:
62 | for old_id, new_id in label_map.items():
63 | label[label == old_id] = new_id
64 | if reduce_zero_label:
65 | label[label == 0] = 255
66 | label = label - 1
67 | label[label == 254] = 255
68 |
69 | if len(label.shape) != len(pred_label.shape):
70 | pred_label = torch.squeeze(pred_label, dim=0)
71 | mask = (label != ignore_index)
72 | pred_label = pred_label[mask]
73 | label = label[mask]
74 |
75 | intersect = pred_label[pred_label == label]
76 | area_intersect = torch.histc(
77 | intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
78 | area_pred_label = torch.histc(
79 | pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
80 | area_label = torch.histc(
81 | label.float(), bins=(num_classes), min=0, max=num_classes - 1)
82 | area_union = area_pred_label + area_label - area_intersect
83 | return area_intersect, area_union, area_pred_label, area_label
84 |
85 |
86 | def total_intersect_and_union(results,
87 | gt_seg_maps,
88 | num_classes,
89 | ignore_index,
90 | label_map=dict(),
91 | reduce_zero_label=False):
92 | """Calculate Total Intersection and Union.
93 | Args:
94 | results (list[ndarray] | list[str]): List of prediction segmentation
95 | maps or list of prediction result filenames.
96 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth
97 | segmentation maps or list of label filenames.
98 | num_classes (int): Number of categories.
99 | ignore_index (int): Index that will be ignored in evaluation.
100 | label_map (dict): Mapping old labels to new labels. Default: dict().
101 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
102 | Returns:
103 | ndarray: The intersection of prediction and ground truth histogram
104 | on all classes.
105 | ndarray: The union of prediction and ground truth histogram on all
106 | classes.
107 | ndarray: The prediction histogram on all classes.
108 | ndarray: The ground truth histogram on all classes.
109 | """
110 | num_imgs = len(results)
111 | assert len(gt_seg_maps) == num_imgs
112 | total_area_intersect = torch.zeros((num_classes,), dtype=torch.float64)
113 | total_area_union = torch.zeros((num_classes,), dtype=torch.float64)
114 | total_area_pred_label = torch.zeros((num_classes,), dtype=torch.float64)
115 | total_area_label = torch.zeros((num_classes,), dtype=torch.float64)
116 | for i in range(num_imgs):
117 | area_intersect, area_union, area_pred_label, area_label = \
118 | intersect_and_union(
119 | results[i], gt_seg_maps[i], num_classes, ignore_index,
120 | label_map, reduce_zero_label)
121 | total_area_intersect += area_intersect
122 | total_area_union += area_union
123 | total_area_pred_label += area_pred_label
124 | total_area_label += area_label
125 | return total_area_intersect, total_area_union, total_area_pred_label, \
126 | total_area_label
127 |
128 |
129 | def mean_iou(results,
130 | gt_seg_maps,
131 | num_classes,
132 | ignore_index,
133 | nan_to_num=None,
134 | label_map=dict(),
135 | reduce_zero_label=False):
136 | """Calculate Mean Intersection and Union (mIoU)
137 | Args:
138 | results (list[ndarray] | list[str]): List of prediction segmentation
139 | maps or list of prediction result filenames.
140 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth
141 | segmentation maps or list of label filenames.
142 | num_classes (int): Number of categories.
143 | ignore_index (int): Index that will be ignored in evaluation.
144 | nan_to_num (int, optional): If specified, NaN values will be replaced
145 | by the numbers defined by the user. Default: None.
146 | label_map (dict): Mapping old labels to new labels. Default: dict().
147 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
148 | Returns:
149 | dict[str, float | ndarray]:
150 | float: Overall accuracy on all images.
151 | ndarray: Per category accuracy, shape (num_classes, ).
152 | ndarray: Per category IoU, shape (num_classes, ).
153 | """
154 | iou_result = eval_metrics(
155 | results=results,
156 | gt_seg_maps=gt_seg_maps,
157 | num_classes=num_classes,
158 | ignore_index=ignore_index,
159 | metrics=['mIoU'],
160 | nan_to_num=nan_to_num,
161 | label_map=label_map,
162 | reduce_zero_label=reduce_zero_label)
163 | return iou_result
164 |
165 |
166 | def mean_dice(results,
167 | gt_seg_maps,
168 | num_classes,
169 | ignore_index,
170 | nan_to_num=None,
171 | label_map=dict(),
172 | reduce_zero_label=False):
173 | """Calculate Mean Dice (mDice)
174 | Args:
175 | results (list[ndarray] | list[str]): List of prediction segmentation
176 | maps or list of prediction result filenames.
177 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth
178 | segmentation maps or list of label filenames.
179 | num_classes (int): Number of categories.
180 | ignore_index (int): Index that will be ignored in evaluation.
181 | nan_to_num (int, optional): If specified, NaN values will be replaced
182 | by the numbers defined by the user. Default: None.
183 | label_map (dict): Mapping old labels to new labels. Default: dict().
184 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
185 | Returns:
186 | dict[str, float | ndarray]: Default metrics.
187 | float: Overall accuracy on all images.
188 | ndarray: Per category accuracy, shape (num_classes, ).
189 | ndarray: Per category dice, shape (num_classes, ).
190 | """
191 |
192 | dice_result = eval_metrics(
193 | results=results,
194 | gt_seg_maps=gt_seg_maps,
195 | num_classes=num_classes,
196 | ignore_index=ignore_index,
197 | metrics=['mDice'],
198 | nan_to_num=nan_to_num,
199 | label_map=label_map,
200 | reduce_zero_label=reduce_zero_label)
201 | return dice_result
202 |
203 |
204 | def mean_fscore(results,
205 | gt_seg_maps,
206 | num_classes,
207 | ignore_index,
208 | nan_to_num=None,
209 | label_map=dict(),
210 | reduce_zero_label=False,
211 | beta=1):
212 |     """Calculate Mean F-score (mFscore)
213 | Args:
214 | results (list[ndarray] | list[str]): List of prediction segmentation
215 | maps or list of prediction result filenames.
216 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth
217 | segmentation maps or list of label filenames.
218 | num_classes (int): Number of categories.
219 | ignore_index (int): Index that will be ignored in evaluation.
220 | nan_to_num (int, optional): If specified, NaN values will be replaced
221 | by the numbers defined by the user. Default: None.
222 | label_map (dict): Mapping old labels to new labels. Default: dict().
223 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
224 | beta (int): Determines the weight of recall in the combined score.
225 |             Default: 1.
226 | Returns:
227 | dict[str, float | ndarray]: Default metrics.
228 | float: Overall accuracy on all images.
229 | ndarray: Per category recall, shape (num_classes, ).
230 | ndarray: Per category precision, shape (num_classes, ).
231 | ndarray: Per category f-score, shape (num_classes, ).
232 | """
233 | fscore_result = eval_metrics(
234 | results=results,
235 | gt_seg_maps=gt_seg_maps,
236 | num_classes=num_classes,
237 | ignore_index=ignore_index,
238 | metrics=['mFscore'],
239 | nan_to_num=nan_to_num,
240 | label_map=label_map,
241 | reduce_zero_label=reduce_zero_label,
242 | beta=beta)
243 | return fscore_result
244 |
245 |
246 | def eval_metrics(results,
247 | gt_seg_maps,
248 | num_classes,
249 | ignore_index,
250 | metrics=['mIoU'],
251 | nan_to_num=None,
252 | label_map=dict(),
253 | reduce_zero_label=False,
254 | beta=1):
255 | """Calculate evaluation metrics
256 | Args:
257 | results (list[ndarray] | list[str]): List of prediction segmentation
258 | maps or list of prediction result filenames.
259 | gt_seg_maps (list[ndarray] | list[str]): list of ground truth
260 | segmentation maps or list of label filenames.
261 | num_classes (int): Number of categories.
262 | ignore_index (int): Index that will be ignored in evaluation.
263 |         metrics (list[str] | str): Metrics to be evaluated: 'mIoU', 'mDice' and 'mFscore'.
264 | nan_to_num (int, optional): If specified, NaN values will be replaced
265 | by the numbers defined by the user. Default: None.
266 | label_map (dict): Mapping old labels to new labels. Default: dict().
267 |         reduce_zero_label (bool): Whether to ignore the zero label. Default: False.
268 | Returns:
269 | float: Overall accuracy on all images.
270 | ndarray: Per category accuracy, shape (num_classes, ).
271 | ndarray: Per category evaluation metrics, shape (num_classes, ).
272 | """
273 | if isinstance(metrics, str):
274 | metrics = [metrics]
275 | allowed_metrics = ['mIoU', 'mDice', 'mFscore', "mHd95"]
276 | if not set(metrics).issubset(set(allowed_metrics)):
277 | raise KeyError('metrics {} is not supported'.format(metrics))
278 |
279 | total_area_intersect, total_area_union, total_area_pred_label, \
280 | total_area_label = total_intersect_and_union(
281 | results, gt_seg_maps, num_classes, ignore_index, label_map,
282 | reduce_zero_label)
283 | all_acc = total_area_intersect.sum() / total_area_label.sum()
284 | ret_metrics = OrderedDict({'aAcc': all_acc})
285 | for metric in metrics:
286 | if metric == 'mIoU':
287 | iou = total_area_intersect / total_area_union
288 | acc = total_area_intersect / total_area_label
289 | ret_metrics['IoU'] = iou
290 | ret_metrics['Acc'] = acc
291 | elif metric == 'mDice':
292 | dice = 2 * total_area_intersect / (
293 | total_area_pred_label + total_area_label)
294 | acc = total_area_intersect / total_area_label
295 | ret_metrics['Dice'] = dice
296 | ret_metrics['Acc'] = acc
297 | elif metric == 'mFscore':
298 | precision = total_area_intersect / total_area_pred_label
299 | recall = total_area_intersect / total_area_label
300 | f_value = torch.tensor(
301 | [f_score(x[0], x[1], beta) for x in zip(precision, recall)])
302 | ret_metrics['Fscore'] = f_value
303 | ret_metrics['Precision'] = precision
304 | ret_metrics['Recall'] = recall
305 | elif metric == "mHd95":
306 |             pass  # mHd95 is accepted but not aggregated here; see mctrans/metrics/hausdorff_distance.py
307 |
308 | ret_metrics = {
309 | metric: value.numpy()
310 | for metric, value in ret_metrics.items()
311 | }
312 | if nan_to_num is not None:
313 | ret_metrics = OrderedDict({
314 | metric: np.nan_to_num(metric_value, nan=nan_to_num)
315 | for metric, metric_value in ret_metrics.items()
316 | })
317 | return ret_metrics
318 |
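
A minimal sketch of eval_metrics on illustrative 2x2 label maps (integer class ids, not one-hot); the expected values are worked out by hand:

import numpy as np
from mctrans.metrics.base import eval_metrics

pred = [np.array([[0, 1], [1, 1]])]   # one predicted segmentation map
gt   = [np.array([[0, 1], [0, 1]])]   # the matching ground truth
res = eval_metrics(pred, gt, num_classes=2, ignore_index=255, metrics=['mIoU', 'mDice'])
# res['aAcc'] == 0.75, res['IoU'] ~= [0.50, 0.67], res['Dice'] ~= [0.67, 0.80]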
--------------------------------------------------------------------------------
/mctrans/metrics/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import Registry, build_from_cfg
2 | from abc import ABC, abstractmethod
3 |
4 | METRICS = Registry('metric')
5 |
6 | def build_metrics(cfg):
7 | """Build metric."""
8 | class_names = cfg.class_names
9 | cfg_types = cfg.metric_types
10 | metrics = [build_from_cfg(_cfg, METRICS) for _cfg in cfg_types]
11 | for metric in metrics:
12 | metric.set_class_name(class_names)
13 | return metrics
14 |
15 | class Metric(ABC):
16 |
17 | @abstractmethod
18 | def __call__(self, pred, target, *args, **kwargs):
19 | raise NotImplementedError
20 |
21 | def set_class_name(self, class_names):
22 | self.class_names = class_names
23 |
24 |
--------------------------------------------------------------------------------
/mctrans/metrics/hausdorff_distance.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 |
3 | import mmcv
4 | import numpy as np
5 | import torch
6 | from monai.metrics import get_mask_edges, get_surface_distance
7 |
8 |
9 | def compute_hausdorff_distance(
10 | y_pred: Union[np.ndarray, torch.Tensor],
11 | y: Union[np.ndarray, torch.Tensor],
12 | include_background: bool = False,
13 | distance_metric: str = "euclidean",
14 | percentile: Optional[float] = None,
15 | directed: bool = False,
16 | ):
17 | """
18 | Compute the Hausdorff distance.
19 |
20 | Args:
21 |         y_pred: input data to compute, typically the segmentation model output.
22 |             It must be in one-hot format with batch as the first dim, example shape: [16, 3, 32, 32].
23 |             The values should be binarized.
24 |         y: ground truth to compute the distance against. It must be in one-hot format with batch
25 |             as the first dim. The values should be binarized.
26 | include_background: whether to skip distance computation on the first channel of
27 | the predicted output. Defaults to ``False``.
28 |         distance_metric: [``"euclidean"``, ``"chessboard"``, ``"taxicab"``]
29 | the metric used to compute surface distance. Defaults to ``"euclidean"``.
30 | percentile: an optional float number between 0 and 100. If specified, the corresponding
31 | percentile of the Hausdorff Distance rather than the maximum result will be achieved.
32 | Defaults to ``None``.
33 | directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.
34 | """
35 | if isinstance(y_pred, str):
36 | y_pred = torch.from_numpy(np.load(y_pred))
37 | else:
38 |         y_pred = torch.from_numpy(y_pred)
39 |
40 | if isinstance(y, str):
41 | y = torch.from_numpy(
42 | mmcv.imread(y, flag='unchanged', backend='pillow'))
43 | else:
44 | y = torch.from_numpy(y)
45 |
46 | if isinstance(y, torch.Tensor):
47 | y = y.float()
48 | if isinstance(y_pred, torch.Tensor):
49 | y_pred = y_pred.float()
50 |
51 | if y.shape != y_pred.shape:
52 | y_pred = torch.squeeze(y_pred, dim=0)
53 |
54 | batch_size, n_class = y_pred.shape[:2]
55 | hd = np.empty((batch_size, n_class))
56 |
57 |     # edges and the distance are computed per (batch, class) slice in the loop below,
58 |     # so no whole-tensor pass is needed here.
59 |
60 | for b, c in np.ndindex(batch_size, n_class):
61 | (edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
62 | distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile)
63 | if directed:
64 | hd[b, c] = distance_1
65 | else:
66 | distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile)
67 | hd[b, c] = max(distance_1, distance_2)
68 | return torch.from_numpy(hd)
69 |
70 |
71 |
72 | def compute_percent_hausdorff_distance(
73 | edges_pred: np.ndarray,
74 | edges_gt: np.ndarray,
75 | distance_metric: str = "euclidean",
76 | percentile: Optional[float] = None,
77 | ):
78 | """
79 | This function is used to compute the directed Hausdorff distance.
80 | """
81 |
82 | surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
83 |
84 | # for both pred and gt do not have foreground
85 | if surface_distance.shape == (0,):
86 | return np.nan
87 |
88 | if not percentile:
89 | return surface_distance.max()
90 |
91 | if 0 <= percentile <= 100:
92 | return np.percentile(surface_distance, percentile)
93 | raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
94 |
95 |
96 | def total_hausdorff_distance(results,
97 | gt_seg_maps,
98 | ):
99 | num_imgs = len(results)
100 | assert len(gt_seg_maps) == num_imgs
101 |     hd95 = []
102 |     for i in range(num_imgs):
103 |         hd95.append(compute_hausdorff_distance(results[i], gt_seg_maps[i], percentile=95))
104 |     return torch.cat(hd95, dim=0)
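
A minimal sketch of compute_hausdorff_distance on illustrative one-hot, binarized numpy arrays (batch-first, as the docstring requires):

import numpy as np
from mctrans.metrics.hausdorff_distance import compute_hausdorff_distance

pred = np.zeros((1, 2, 32, 32), dtype=np.uint8)
gt = np.zeros((1, 2, 32, 32), dtype=np.uint8)
pred[0, 1, 8:20, 8:20] = 1        # predicted foreground square
gt[0, 1, 10:22, 10:22] = 1        # ground-truth square shifted by 2 px
pred[0, 0] = 1 - pred[0, 1]       # background channel completes the one-hot encoding
gt[0, 0] = 1 - gt[0, 1]
hd95 = compute_hausdorff_distance(pred, gt, percentile=95)   # tensor of shape (1, 2)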
--------------------------------------------------------------------------------
/mctrans/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .centers import *
2 | from .losses import *
3 | from .builder import build_model, build_losses, MODEL, NETWORKS, LOSSES, HEADS, ENCODERS, DECODERS
4 | from .encoders import *
5 | from .decoders import *
6 | from .heads import *
7 | from .segmentors import *
--------------------------------------------------------------------------------
/mctrans/models/builder.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils import Registry, build_from_cfg
2 | from torch import nn
3 |
4 | NETWORKS = Registry('network')
5 | LOSSES = Registry('loss')
6 | MODEL = Registry('model')
7 | ENCODERS = Registry('encoder')
8 | DECODERS = Registry('decoder')
9 | CENTERS = Registry('center')
10 | HEADS = Registry('head')
11 |
12 |
13 | def build(cfg, registry, default_args=None):
14 | """Build a module.
15 |
16 | Args:
17 |         cfg (dict, list[dict]): The config of modules; it is either a dict
18 | or a list of configs.
19 | registry (:obj:`Registry`): A registry the module belongs to.
20 | default_args (dict, optional): Default arguments to build the module.
21 | Defaults to None.
22 |
23 | Returns:
24 | nn.Module: A built nn module.
25 | """
26 |
27 | if isinstance(cfg, list):
28 | modules = [
29 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
30 | ]
31 | return nn.Sequential(*modules)
32 | else:
33 | return build_from_cfg(cfg, registry, default_args)
34 |
35 |
36 | def build_network(cfg):
37 | """Build network."""
38 | return build(cfg, NETWORKS)
39 |
40 |
41 | def build_losses(cfg):
42 | """Build loss."""
43 | return [build_from_cfg(_cfg, LOSSES) for _cfg in cfg]
44 |
45 |
46 | def build_model(cfg):
47 | """Build model."""
48 | return build(cfg, MODEL)
49 |
50 |
51 | def build_encoder(cfg):
52 | """Build Encoder."""
53 | return build(cfg, ENCODERS)
54 |
55 |
56 | def build_decoder(cfg):
57 | """Build Decoder."""
58 | return build(cfg, DECODERS)
59 |
60 |
61 | def build_center(cfg):
62 | """Build Center."""
63 | return build(cfg, CENTERS)
64 |
65 |
66 | def build_head(cfg):
67 | """Build SegHead."""
68 | return build(cfg, HEADS)
69 |
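
A minimal sketch of the registry pattern these helpers rely on; a standalone demo registry and a dummy class are used so the sketch runs without the rest of the package:

from mmcv.utils import Registry, build_from_cfg
from torch import nn

ENCODERS_DEMO = Registry('encoder_demo')

@ENCODERS_DEMO.register_module()
class DummyEncoder(nn.Module):
    def __init__(self, depth=5):
        super().__init__()
        self.depth = depth

# the 'type' key selects the registered class; the remaining keys become constructor kwargs
encoder = build_from_cfg(dict(type='DummyEncoder', depth=4), ENCODERS_DEMO)
assert encoder.depth == 4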
--------------------------------------------------------------------------------
/mctrans/models/centers/__init__.py:
--------------------------------------------------------------------------------
1 | from .mctrans import MCTrans
2 | from .cenet import CEncoder
3 | from .non_local import NonLocal
4 | from .vit import Vit
--------------------------------------------------------------------------------
/mctrans/models/centers/cenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..builder import CENTERS
6 |
7 |
8 | class DacBlock(nn.Module):
9 | def __init__(self, channel):
10 | super(DacBlock, self).__init__()
11 | self.dilate1 = nn.Conv2d(channel, channel, kernel_size=3, dilation=1, padding=1)
12 | self.dilate2 = nn.Conv2d(channel, channel, kernel_size=3, dilation=3, padding=3)
13 | self.dilate3 = nn.Conv2d(channel, channel, kernel_size=3, dilation=5, padding=5)
14 | self.conv1x1 = nn.Conv2d(channel, channel, kernel_size=1, dilation=1, padding=0)
15 | self.relu = nn.ReLU(inplace=True)
16 | for m in self.modules():
17 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
18 | if m.bias is not None:
19 | m.bias.data.zero_()
20 |
21 | def forward(self, x):
22 | dilate1_out = self.relu(self.dilate1(x))
23 | dilate2_out = self.relu(self.conv1x1(self.dilate2(x)))
24 | dilate3_out = self.relu(self.conv1x1(self.dilate2(self.dilate1(x))))
25 | dilate4_out = self.relu(self.conv1x1(self.dilate3(self.dilate2(self.dilate1(x)))))
26 | out = x + dilate1_out + dilate2_out + dilate3_out + dilate4_out
27 | return out
28 |
29 |
30 | class SppBlock(nn.Module):
31 | def __init__(self, in_channels):
32 | super(SppBlock, self).__init__()
33 | self.pool1 = nn.MaxPool2d(kernel_size=[2, 2], stride=2)
34 | self.pool2 = nn.MaxPool2d(kernel_size=[3, 3], stride=3)
35 | self.pool3 = nn.MaxPool2d(kernel_size=[5, 5], stride=5)
36 | self.pool4 = nn.MaxPool2d(kernel_size=[6, 6], stride=6)
37 |
38 | self.conv = nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=1, padding=0)
39 |
40 | def forward(self, x):
41 | self.in_channels, h, w = x.size(1), x.size(2), x.size(3)
42 | self.layer1 = F.interpolate(self.conv(self.pool1(x)), size=(h, w), mode='bilinear', align_corners=True)
43 | self.layer2 = F.interpolate(self.conv(self.pool2(x)), size=(h, w), mode='bilinear', align_corners=True)
44 | self.layer3 = F.interpolate(self.conv(self.pool3(x)), size=(h, w), mode='bilinear', align_corners=True)
45 | self.layer4 = F.interpolate(self.conv(self.pool4(x)), size=(h, w), mode='bilinear', align_corners=True)
46 |
47 | out = torch.cat([self.layer1, self.layer2, self.layer3, self.layer4, x], 1)
48 |
49 | return out
50 |
51 |
52 | @CENTERS.register_module()
53 | class CEncoder(nn.Module):
54 | def __init__(self,
55 | in_channels=[240],
56 | ):
57 | super().__init__()
58 | in_channels = in_channels[-1]
59 | self.dac = DacBlock(in_channels)
60 | self.spp = SppBlock(in_channels)
61 |
62 | def forward(self, x):
63 | feat = x[-1]
64 | feat = self.dac(feat)
65 | feat = self.spp(feat)
66 | x[-1] = feat
67 | return x
68 |
--------------------------------------------------------------------------------
/mctrans/models/centers/mctrans.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.init import normal_
4 |
5 | from mmcv.cnn import ConvModule
6 |
7 | from ..builder import CENTERS
8 | from ..ops.modules import MSDeformAttn
9 | from ..trans.transformer import DSALayer, DSA
10 | from ..trans.utils import build_position_encoding, NestedTensor
11 |
12 |
13 | @CENTERS.register_module()
14 | class MCTrans(nn.Module):
15 | def __init__(self,
16 | d_model=240,
17 | nhead=8,
18 | d_ffn=1024,
19 | dropout=0.1,
20 | act="relu",
21 | n_points=4,
22 | n_levels=3,
23 | n_sa_layers=6,
24 | in_channles=[64, 64, 128, 256, 512],
25 | proj_idxs=(2, 3, 4),
26 |
27 | ):
28 | super().__init__()
29 | self.nhead = nhead
30 | self.d_model = d_model
31 | self.n_levels = n_levels
32 |
33 | self.proj_idxs = proj_idxs
34 | self.projs = nn.ModuleList()
35 | for idx in self.proj_idxs:
36 | self.projs.append(ConvModule(in_channles[idx],
37 | d_model,
38 | kernel_size=3,
39 | padding=1,
40 | conv_cfg=dict(type="Conv"),
41 | norm_cfg=dict(type='BN'),
42 | act_cfg=dict(type='ReLU')
43 | ))
44 |
45 | dsa_layer = DSALayer(d_model=d_model,
46 | d_ffn=d_ffn,
47 | dropout=dropout,
48 | activation=act,
49 | n_levels=n_levels,
50 | n_heads=nhead,
51 | n_points=n_points)
52 |
53 | self.dsa = DSA(att_layer=dsa_layer,
54 | n_layers=n_sa_layers)
55 |
56 | self.level_embed = nn.Parameter(torch.Tensor(n_levels, d_model))
57 | self.position_embedding = build_position_encoding(position_embedding="sine", hidden_dim=d_model)
58 | self._reset_parameters()
59 |
60 | def _reset_parameters(self):
61 | for p in self.parameters():
62 | if p.dim() > 1:
63 | nn.init.xavier_uniform_(p)
64 | for m in self.modules():
65 | if isinstance(m, MSDeformAttn):
66 | m._reset_parameters()
67 | normal_(self.level_embed)
68 |
69 | def get_valid_ratio(self, mask):
70 | _, H, W = mask.shape
71 | valid_H = torch.sum(~mask[:, :, 0], 1)
72 | valid_W = torch.sum(~mask[:, 0, :], 1)
73 | valid_ratio_h = valid_H.float() / H
74 | valid_ratio_w = valid_W.float() / W
75 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
76 | return valid_ratio
77 |
78 | def projection(self, feats):
79 | pos = []
80 | masks = []
81 | cnn_feats = []
82 | tran_feats = []
83 |
84 |         for idx, feat in enumerate(feats):
85 |             if idx not in self.proj_idxs:
86 |                 cnn_feats.append(feat)
87 |             else:
88 |                 n, c, h, w = feat.shape
89 |                 mask = torch.zeros((n, h, w)).to(torch.bool).to(feat.device)
90 |                 nested_feats = NestedTensor(feat, mask)
91 |                 masks.append(mask)
92 |                 pos.append(self.position_embedding(nested_feats).to(nested_feats.tensors.dtype))
93 |                 tran_feats.append(feat)
94 |
95 | for idx, proj in enumerate(self.projs):
96 | tran_feats[idx] = proj(tran_feats[idx])
97 |
98 | return cnn_feats, tran_feats, pos, masks
99 |
100 | def forward(self, x):
101 | # project and prepare for the input
102 | cnn_feats, trans_feats, pos_embs, masks = self.projection(x)
103 | # dsa
104 | features_flatten = []
105 | mask_flatten = []
106 | lvl_pos_embed_flatten = []
107 | feature_shapes = []
108 | spatial_shapes = []
109 | for lvl, (feature, mask, pos_embed) in enumerate(zip(trans_feats, masks, pos_embs)):
110 | bs, c, h, w = feature.shape
111 | spatial_shapes.append((h, w))
112 | feature_shapes.append(feature.shape)
113 |
114 | feature = feature.flatten(2).transpose(1, 2)
115 | mask = mask.flatten(1)
116 | pos_embed = pos_embed.flatten(2).transpose(1, 2)
117 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
118 | lvl_pos_embed_flatten.append(lvl_pos_embed)
119 |
120 | features_flatten.append(feature)
121 | mask_flatten.append(mask)
122 |
123 | features_flatten = torch.cat(features_flatten, 1)
124 | mask_flatten = torch.cat(mask_flatten, 1)
125 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
126 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=features_flatten.device)
127 | level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
128 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
129 |
130 | # self att
131 | feats = self.dsa(features_flatten,
132 | spatial_shapes,
133 | level_start_index,
134 | valid_ratios,
135 | lvl_pos_embed_flatten,
136 | mask_flatten)
137 | # recover
138 | out = []
139 | features = feats.split(spatial_shapes.prod(1).tolist(), dim=1)
140 | for idx, (feats, ori_shape) in enumerate(zip(features, spatial_shapes)):
141 | out.append(feats.transpose(1, 2).reshape(feature_shapes[idx]))
142 |
143 | cnn_feats.extend(out)
144 | return cnn_feats
145 |
--------------------------------------------------------------------------------
/mctrans/models/centers/non_local.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from mctrans.models.builder import CENTERS
5 |
6 |
7 | class _NonLocalBlockND(nn.Module):
8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
9 | super(_NonLocalBlockND, self).__init__()
10 |
11 | assert dimension in [1, 2, 3]
12 |
13 | self.dimension = dimension
14 | self.sub_sample = sub_sample
15 |
16 | self.in_channels = in_channels
17 | self.inter_channels = inter_channels
18 |
19 | if self.inter_channels is None:
20 | self.inter_channels = in_channels // 2
21 | if self.inter_channels == 0:
22 | self.inter_channels = 1
23 |
24 | if dimension == 3:
25 | conv_nd = nn.Conv3d
26 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
27 | bn = nn.BatchNorm3d
28 | elif dimension == 2:
29 | conv_nd = nn.Conv2d
30 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
31 | bn = nn.BatchNorm2d
32 | else:
33 | conv_nd = nn.Conv1d
34 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
35 | bn = nn.BatchNorm1d
36 |
37 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
38 | kernel_size=1, stride=1, padding=0)
39 |
40 | if bn_layer:
41 | self.W = nn.Sequential(
42 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
43 | kernel_size=1, stride=1, padding=0),
44 | bn(self.in_channels)
45 | )
46 | nn.init.constant_(self.W[1].weight, 0)
47 | nn.init.constant_(self.W[1].bias, 0)
48 | else:
49 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
50 | kernel_size=1, stride=1, padding=0)
51 | nn.init.constant_(self.W.weight, 0)
52 | nn.init.constant_(self.W.bias, 0)
53 |
54 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
55 | kernel_size=1, stride=1, padding=0)
56 |
57 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
58 | kernel_size=1, stride=1, padding=0)
59 |
60 | if sub_sample:
61 | self.g = nn.Sequential(self.g, max_pool_layer)
62 | self.phi = nn.Sequential(self.phi, max_pool_layer)
63 |
64 | def forward(self, x):
65 | '''
66 | :param x: (b, c, t, h, w)
67 | :return:
68 | '''
69 |
70 | batch_size = x.size(0)
71 |
72 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
73 | g_x = g_x.permute(0, 2, 1)
74 |
75 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
76 | theta_x = theta_x.permute(0, 2, 1)
77 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
78 | f = torch.matmul(theta_x, phi_x)
79 | N = f.size(-1)
80 | f_div_C = f / N
81 |
82 | y = torch.matmul(f_div_C, g_x)
83 | y = y.permute(0, 2, 1).contiguous()
84 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
85 | W_y = self.W(y)
86 | z = W_y + x
87 |
88 | return z
89 |
90 |
91 | @CENTERS.register_module()
92 | class NonLocal(_NonLocalBlockND):
93 | def __init__(self,
94 | in_channels=[240],
95 | ):
96 | in_channels = in_channels[-1]
97 | super(NonLocal, self).__init__(in_channels,
98 | inter_channels=None,
99 | dimension=2,
100 | sub_sample=True,
101 | bn_layer=True)
102 |
103 | def forward(self, x):
104 | x[-1] = super().forward(x[-1])
105 | return x
106 |
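
A minimal sketch of the shared "center" contract: the module receives the encoder's feature pyramid as a list and returns it with the deepest map refined (shapes are illustrative; importing mctrans.models assumes the package, including its deformable-attention ops, has been built):

import torch
from mctrans.models.centers import NonLocal

channels = [16, 32, 64, 128, 256]
feats = [torch.randn(2, c, 256 // 2 ** i, 256 // 2 ** i) for i, c in enumerate(channels)]
center = NonLocal(in_channels=[channels[-1]])   # only the deepest channel count is used
feats = center(feats)                           # feats[-1] is replaced by the non-local output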
--------------------------------------------------------------------------------
/mctrans/models/centers/vit.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | from mctrans.models.builder import CENTERS
5 |
6 |
7 | class FixedPositionalEncoding(nn.Module):
8 | def __init__(self, embedding_dim, max_length=5000):
9 | super(FixedPositionalEncoding, self).__init__()
10 |
11 | pe = torch.zeros(max_length, embedding_dim)
12 | position = torch.arange(0, max_length, dtype=torch.float).unsqueeze(1)
13 | div_term = torch.exp(
14 | torch.arange(0, embedding_dim, 2).float()
15 | * (-torch.log(torch.tensor(10000.0)) / embedding_dim)
16 | )
17 | pe[:, 0::2] = torch.sin(position * div_term)
18 | pe[:, 1::2] = torch.cos(position * div_term)
19 | pe = pe.unsqueeze(0).transpose(0, 1)
20 | self.register_buffer('pe', pe)
21 |
22 | def forward(self, x):
23 | x = x + self.pe[: x.size(0), :]
24 | return x
25 |
26 |
27 | class LearnedPositionalEncoding(nn.Module):
28 | def __init__(self, max_position_embeddings, embedding_dim, seq_length):
29 | super(LearnedPositionalEncoding, self).__init__()
30 | self.pe = nn.Embedding(max_position_embeddings, embedding_dim)
31 | self.seq_length = seq_length
32 |
33 | self.register_buffer(
34 | "position_ids",
35 | torch.arange(max_position_embeddings).expand((1, -1)),
36 | )
37 |
38 | def forward(self, x, position_ids=None):
39 | if position_ids is None:
40 | position_ids = self.position_ids[:, : self.seq_length]
41 |
42 | position_embeddings = self.pe(position_ids)
43 | return x + position_embeddings
44 |
45 |
46 | class SelfAttention(nn.Module):
47 | def __init__(
48 | self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
49 | ):
50 | super().__init__()
51 | self.num_heads = heads
52 | head_dim = dim // heads
53 | self.scale = qk_scale or head_dim ** -0.5
54 |
55 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
56 | self.attn_drop = nn.Dropout(dropout_rate)
57 | self.proj = nn.Linear(dim, dim)
58 | self.proj_drop = nn.Dropout(dropout_rate)
59 |
60 | def forward(self, x):
61 | B, N, C = x.shape
62 | qkv = (
63 | self.qkv(x)
64 | .reshape(B, N, 3, self.num_heads, C // self.num_heads)
65 | .permute(2, 0, 3, 1, 4)
66 | )
67 | q, k, v = (
68 | qkv[0],
69 | qkv[1],
70 | qkv[2],
71 | ) # make torchscript happy (cannot use tensor as tuple)
72 |
73 | attn = (q @ k.transpose(-2, -1)) * self.scale
74 | attn = attn.softmax(dim=-1)
75 | attn = self.attn_drop(attn)
76 |
77 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
78 | x = self.proj(x)
79 | x = self.proj_drop(x)
80 | return x
81 |
82 |
83 | class Residual(nn.Module):
84 | def __init__(self, fn):
85 | super().__init__()
86 | self.fn = fn
87 |
88 | def forward(self, x):
89 | return self.fn(x) + x
90 |
91 |
92 | class PreNorm(nn.Module):
93 | def __init__(self, dim, fn):
94 | super().__init__()
95 | self.norm = nn.LayerNorm(dim)
96 | self.fn = fn
97 |
98 | def forward(self, x):
99 | return self.fn(self.norm(x))
100 |
101 |
102 | class PreNormDrop(nn.Module):
103 | def __init__(self, dim, dropout_rate, fn):
104 | super().__init__()
105 | self.norm = nn.LayerNorm(dim)
106 | self.dropout = nn.Dropout(p=dropout_rate)
107 | self.fn = fn
108 |
109 | def forward(self, x):
110 | return self.dropout(self.fn(self.norm(x)))
111 |
112 |
113 | class FeedForward(nn.Module):
114 | def __init__(self, dim, hidden_dim, dropout_rate):
115 | super().__init__()
116 | self.net = nn.Sequential(
117 | nn.Linear(dim, hidden_dim),
118 | nn.GELU(),
119 | nn.Dropout(p=dropout_rate),
120 | nn.Linear(hidden_dim, dim),
121 | nn.Dropout(p=dropout_rate),
122 | )
123 |
124 | def forward(self, x):
125 | return self.net(x)
126 |
127 |
128 | class TransformerModel(nn.Module):
129 | def __init__(
130 | self,
131 | dim,
132 | depth,
133 | heads,
134 | mlp_dim,
135 | dropout_rate=0.1,
136 | attn_dropout_rate=0.1,
137 | ):
138 | super().__init__()
139 | layers = []
140 | for _ in range(depth):
141 | layers.extend(
142 | [
143 | Residual(
144 | PreNormDrop(
145 | dim,
146 | dropout_rate,
147 | SelfAttention(
148 | dim, heads=heads, dropout_rate=attn_dropout_rate
149 | ),
150 | )
151 | ),
152 | Residual(
153 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
154 | ),
155 | ]
156 | )
157 | self.net = nn.Sequential(*layers)
158 |
159 | def forward(self, x):
160 | return self.net(x)
161 |
162 |
163 | @CENTERS.register_module()
164 | class Vit(nn.Module):
165 | def __init__(self,
166 | input_size=(30, 16, 16),
167 | embedding_dims=None,
168 | num_encoders=6,
169 | positional_encoding_type="fixed"):
170 | super().__init__()
171 | input_channel, input_height, input_width = input_size
172 | num_patches = input_height * input_width
173 | flatten_dims = input_channel
174 |
175 | if embedding_dims is None:
176 | embedding_dims = flatten_dims
177 |
178 | self.linear_encoding = nn.Linear(flatten_dims, embedding_dims)
179 |
180 | if positional_encoding_type == "learnable":
181 | self.position_encoding = LearnedPositionalEncoding(max_position_embeddings=num_patches,
182 | embedding_dim=embedding_dims,
183 | seq_length=num_patches)
184 | elif positional_encoding_type == "fixed":
185 | self.position_encoding = FixedPositionalEncoding(embedding_dim=embedding_dims)
186 |
187 | self.encoders = TransformerModel(dim=embedding_dims,
188 | depth=num_encoders,
189 | heads=8,
190 | mlp_dim=embedding_dims,
191 | dropout_rate=0.1,
192 | attn_dropout_rate=0.0)
193 |
194 | def forward(self, x):
195 | feat = x[-1]
196 |
197 | n, c, h, w = feat.size()
198 | feat = feat.view(n, c, h * w).transpose(1, 2)
199 | feat = self.linear_encoding(feat)
200 | feat = self.position_encoding(feat)
201 | feat = self.encoders(feat)
202 | feat = feat.transpose(1, 2)
203 | feat = feat.view(n, c, h, w)
204 |
205 | x[-1] = feat
206 | return x
207 |
--------------------------------------------------------------------------------
/mctrans/models/decoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_decoder import UNetDecoder
2 | from .unet_plus_plus_decoder import UNetPlusPlusDecoder
3 |
--------------------------------------------------------------------------------
/mctrans/models/decoders/unet_decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from ..builder import DECODERS
5 | from ..utils import conv_bn_relu
6 |
7 |
8 | class AttBlock(nn.Module):
9 | def __init__(self, F_g, F_l, F_int):
10 | super(AttBlock, self).__init__()
11 | self.W_g = nn.Sequential(
12 | nn.Conv2d(in_channels=F_g,
13 | out_channels=F_int,
14 | kernel_size=1,
15 | stride=1,
16 | padding=0,
17 | bias=True),
18 | nn.BatchNorm2d(F_int)
19 | )
20 |
21 | self.W_x = nn.Sequential(
22 | nn.Conv2d(in_channels=F_l,
23 | out_channels=F_int,
24 | kernel_size=1,
25 | stride=1,
26 | padding=0,
27 | bias=True),
28 | nn.BatchNorm2d(F_int)
29 | )
30 |
31 | self.psi = nn.Sequential(
32 | nn.Conv2d(in_channels=F_int,
33 | out_channels=1,
34 | kernel_size=1,
35 | stride=1,
36 | padding=0,
37 | bias=True),
38 | nn.BatchNorm2d(1),
39 | nn.Sigmoid()
40 | )
41 |
42 | self.relu = nn.ReLU(inplace=True)
43 |
44 | def forward(self, g, x):
45 | g1 = self.W_g(g)
46 | x1 = self.W_x(x)
47 | psi = self.relu(g1 + x1)
48 | psi = self.psi(psi)
49 |
50 | return x * psi
51 |
52 |
53 | class DecBlock(nn.Module):
54 | def __init__(
55 | self,
56 | in_channels,
57 | skip_channels,
58 | out_channels,
59 | attention=False
60 | ):
61 | super().__init__()
62 | self.conv1 = conv_bn_relu(in_channels=in_channels + skip_channels,
63 | out_channels=out_channels)
64 |
65 | self.conv2 = conv_bn_relu(in_channels=out_channels,
66 | out_channels=out_channels)
67 |
68 | self.up = nn.Upsample(scale_factor=2,
69 | mode='bilinear',
70 | align_corners=True)
71 |
72 | if attention:
73 | self.att = AttBlock(F_g=in_channels, F_l=skip_channels, F_int=in_channels)
74 |
75 | def forward(self, x, skip=None):
76 | x = self.up(x)
77 | if skip is not None:
78 | if hasattr(self, "att"):
79 | skip = self.att(g=x, x=skip)
80 | x = torch.cat([x, skip], dim=1)
81 | x = self.conv1(x)
82 | x = self.conv2(x)
83 | return x
84 |
85 |
86 | @DECODERS.register_module()
87 | class UNetDecoder(nn.Module):
88 | def __init__(
89 | self,
90 | in_channels,
91 | att=False
92 | ):
93 | super().__init__()
94 | self.decoders = nn.ModuleList()
95 | in_channels = in_channels[::-1]
96 | skip_channels = in_channels[1:]
97 | for in_c, skip_c in zip(in_channels, skip_channels):
98 | self.decoders.append(DecBlock(in_c, skip_c, skip_c, att))
99 |
100 | def forward(self, features):
101 | features = features[::-1]
102 | x = features[0]
103 | skips = features[1:]
104 |
105 | for i, layer in enumerate(self.decoders):
106 | x = layer(x, skips[i])
107 |
108 | return x
109 |
110 | def init_weights(self):
111 | pass
112 |
113 |
--------------------------------------------------------------------------------
/mctrans/models/decoders/unet_plus_plus_decoder.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from ..utils import conv_bn_relu
7 | from ..builder import DECODERS
8 |
9 |
10 | class AttBlock(nn.Module):
11 | def __init__(self, F_g, F_l, F_int):
12 | super(AttBlock, self).__init__()
13 | self.W_g = nn.Sequential(
14 | nn.Conv2d(in_channels=F_g,
15 | out_channels=F_int,
16 | kernel_size=1,
17 | stride=1,
18 | padding=0,
19 | bias=True),
20 | nn.BatchNorm2d(F_int)
21 | )
22 |
23 | self.W_x = nn.Sequential(
24 | nn.Conv2d(in_channels=F_l,
25 | out_channels=F_int,
26 | kernel_size=1,
27 | stride=1,
28 | padding=0,
29 | bias=True),
30 | nn.BatchNorm2d(F_int)
31 | )
32 |
33 | self.psi = nn.Sequential(
34 | nn.Conv2d(in_channels=F_int,
35 | out_channels=1,
36 | kernel_size=1,
37 | stride=1,
38 | padding=0,
39 | bias=True),
40 | nn.BatchNorm2d(1),
41 | nn.Sigmoid()
42 | )
43 |
44 | self.relu = nn.ReLU(inplace=True)
45 |
46 | def forward(self, g, x):
47 | g1 = self.W_g(g)
48 | x1 = self.W_x(x)
49 | psi = self.relu(g1 + x1)
50 | psi = self.psi(psi)
51 |
52 | return x * psi
53 |
54 |
55 | class DecBlock(nn.Module):
56 | def __init__(
57 | self,
58 | in_channels,
59 | skip_channels,
60 | out_channels,
61 | attention=False
62 | ):
63 | super().__init__()
64 | self.conv1 = conv_bn_relu(in_channels=in_channels + skip_channels,
65 | out_channels=out_channels)
66 |
67 | self.conv2 = conv_bn_relu(in_channels=out_channels,
68 | out_channels=out_channels)
69 |
70 | self.up = nn.Upsample(scale_factor=2,
71 | mode='bilinear',
72 | align_corners=True)
73 |
74 | if attention:
75 | self.att = AttBlock(F_g=in_channels, F_l=skip_channels, F_int=in_channels)
76 |
77 | def forward(self, x, skip=None):
78 | x = self.up(x)
79 | if skip is not None:
80 | if hasattr(self, "att"):
81 | skip = self.att(g=x, x=skip)
82 | x = torch.cat([x, skip], dim=1)
83 | x = self.conv1(x)
84 | x = self.conv2(x)
85 | return x
86 |
87 |
88 | @DECODERS.register_module()
89 | class UNetPlusPlusDecoder(nn.Module):
90 | def __init__(
91 | self,
92 | in_channels,
93 | ):
94 | super().__init__()
95 |
96 | self.decoder_layers = nn.ModuleList()
97 | self.in_channels = in_channels
98 | skip_channels = in_channels[:-1]
99 |
100 | blocks = {}
101 | for stage_idx in range(1, len(self.in_channels)):
102 | for lvl_idx in range(len(self.in_channels) - stage_idx):
103 | in_ch = self.in_channels[lvl_idx + 1]
104 | skip_ch = skip_channels[lvl_idx] * (stage_idx)
105 | out_ch = self.in_channels[lvl_idx]
106 | blocks[f'x_{lvl_idx}_{stage_idx}'] = DecBlock(in_ch, skip_ch, out_ch, False)
107 |
108 | self.blocks = nn.ModuleDict(blocks)
109 |
110 | def forward(self, features):
111 | dense_x = OrderedDict()
112 | for idx, item in enumerate(features):
113 | dense_x[f'x_{idx}_{0}'] = features[idx]
114 |
115 | for stage_idx in range(1, len(self.in_channels)):
116 | for lvl_idx in range(len(self.in_channels) - stage_idx):
117 | skip_features = [dense_x[f'x_{lvl_idx}_{idx}'] for idx in range(stage_idx)]
118 | skip_features = torch.cat(skip_features, dim=1)
119 | output = self.blocks[f'x_{lvl_idx}_{stage_idx}'](dense_x[f'x_{lvl_idx + 1}_{stage_idx - 1}'],
120 | skip_features)
121 | dense_x[f'x_{lvl_idx}_{stage_idx}'] = output
122 |
123 | return dense_x[next(reversed(dense_x))]
124 |
125 | def init_weights(self):
126 | pass
127 |
128 |
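
A shape walk-through of the dense skip connections above, with illustrative channel widths (importing mctrans.models assumes the package, including its compiled ops, has been built):

import torch
from mctrans.models.decoders import UNetPlusPlusDecoder

channels = [16, 32, 64, 128, 256]                     # encoder outputs, shallow -> deep
feats = [torch.randn(2, c, 256 // 2 ** i, 256 // 2 ** i) for i, c in enumerate(channels)]
decoder = UNetPlusPlusDecoder(in_channels=channels)
out = decoder(feats)
print(out.shape)   # torch.Size([2, 16, 256, 256]): the fully fused finest-level features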
--------------------------------------------------------------------------------
/mctrans/models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from .vgg import VGG
2 | from .resnet import ResNet
3 |
4 |
--------------------------------------------------------------------------------
/mctrans/models/encoders/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from mmcv.cnn import kaiming_init, constant_init
4 | from torch.nn.modules.batchnorm import _BatchNorm
5 |
6 | from ..builder import ENCODERS
7 | from ..utils import make_vgg_layer
8 |
9 |
10 | @ENCODERS.register_module()
11 | class VGG(nn.Module):
12 | def __init__(self, in_channel=1, depth=5, init_channels=16, num_blocks=2):
13 | super(VGG, self).__init__()
14 | filters = [(2 ** i) * init_channels for i in range(depth)]
15 | self.out_channels = filters.copy()
16 |
17 | filters.insert(0, in_channel)
18 | self.stages = nn.ModuleList()
19 |
20 | for idx in range(depth):
21 | down_sample = False if idx == 0 else True
22 | self.stages.append(make_vgg_layer(inplanes=filters[idx],
23 | planes=filters[idx + 1],
24 | num_blocks=num_blocks,
25 | with_bn=True,
26 | down_sample=down_sample))
27 |
28 | def forward(self, x):
29 |
30 | features = []
31 | for stage in self.stages:
32 | x = stage(x)
33 | features.append(x)
34 | return features
35 |
36 | def init_weights(self, pretrained=None):
37 | pass
38 | # for m in self.modules():
39 | # if isinstance(m, nn.Conv2d):
40 | # kaiming_init(m)
41 | # elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
42 | # constant_init(m, 1)
43 |
--------------------------------------------------------------------------------
/mctrans/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic_seg_head import BasicSegHead
2 | from .mctrans_aux_head import MCTransAuxHead
3 |
4 | __all__ = ["BasicSegHead", "MCTransAuxHead"]
--------------------------------------------------------------------------------
/mctrans/models/heads/basic_seg_head.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from mmcv.cnn import normal_init
3 |
4 | from ..builder import HEADS, build_losses
5 | from ...data.transforms import build_transforms
6 |
7 |
8 | @HEADS.register_module()
9 | class BasicSegHead(nn.Module):
10 | def __init__(self, in_channels, num_classes, kernel_size=1, post_trans=None, losses=None):
11 | super(BasicSegHead, self).__init__()
12 | self.head = nn.Conv2d(in_channels=in_channels, out_channels=num_classes, kernel_size=kernel_size)
13 | self.post_trans = build_transforms(post_trans)
14 | self.losses = build_losses(losses)
15 |
16 | def forward_train(self, inputs, seg_label, **kwargs):
17 | logits = self.head(inputs)
18 | losses = dict()
19 | for _loss in self.losses:
20 | losses[_loss.__class__.__name__] = _loss(logits, seg_label)
21 | return losses
22 |
23 | def forward_test(self, inputs, **kwargs):
24 | logits = self.head(inputs)
25 | preds = self.post_trans(logits)
26 | return preds
27 |
28 | def init_weights(self):
29 | pass
30 | # normal_init(self.head, mean=0, std=0.01)
31 |
--------------------------------------------------------------------------------
/mctrans/models/heads/mctrans_aux_head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from mmcv.cnn import normal_init
4 |
5 | from ..builder import HEADS, build_losses
6 | from ..trans.transformer import CALayer, CA
7 |
8 |
9 | @HEADS.register_module()
10 | class MCTransAuxHead(nn.Module):
11 | def __init__(self,
12 | d_model=128,
13 | d_ffn=1024,
14 | dropout=0.1,
15 |                  act="relu",
16 | n_head=8,
17 | n_layers=4,
18 | num_classes=6,
19 | in_channles=[64, 64, 128, 256, 512],
20 | proj_idxs=(2, 3, 4),
21 | losses=None
22 | ):
23 | super(MCTransAuxHead, self).__init__()
24 | self.in_channles = in_channles
25 | self.proj_idxs = proj_idxs
26 |
27 | ca_layer = CALayer(d_model=d_model,
28 | d_ffn=d_ffn,
29 | dropout=dropout,
30 | activation=act,
31 | n_heads=n_head)
32 |
33 | self.ca = CA(att_layer=ca_layer,
34 | n_layers=n_layers,
35 | n_category=num_classes,
36 | d_model=d_model)
37 |
38 | self.head = nn.Sequential(nn.Linear(num_classes*d_model, d_model),
39 | nn.Linear(d_model, num_classes))
40 |
41 | self.losses = build_losses(losses)
42 |
43 | def forward_train(self, inputs, seg_label, **kwargs):
44 | # flatten
45 | inputs = [inputs[idx] for idx in self.proj_idxs]
46 | inputs_flatten = [item.flatten(2).transpose(1, 2) for item in inputs]
47 | inputs_flatten = torch.cat(inputs_flatten, 1)
48 | # ca
49 | outputs = self.ca(inputs_flatten)
50 | logits = self.head(outputs.flatten(1))
51 | losses = dict()
52 | for _loss in self.losses:
53 | losses[_loss.__class__.__name__] = _loss(logits, seg_label)
54 | return losses
55 |
56 | def init_weights(self):
57 | normal_init(self.head, mean=0, std=0.01)
58 |
--------------------------------------------------------------------------------
/mctrans/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .cross_entropy_loss import MCTransAuxLoss
2 | from .monai import DiceLoss, DiceCELoss, DiceFocalLoss
3 | from .debug_focal import FLoss
4 | __all__ = ["MCTransAuxLoss", "DiceLoss", "DiceCELoss", "DiceFocalLoss"]
--------------------------------------------------------------------------------
/mctrans/models/losses/cross_entropy_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..builder import LOSSES
6 |
7 |
8 | def reduce_loss(loss, reduction):
9 | """Reduce loss as specified.
10 |
11 | Args:
12 | loss (Tensor): Elementwise loss tensor.
13 | reduction (str): Options are "none", "mean" and "sum".
14 |
15 | Return:
16 | Tensor: Reduced loss tensor.
17 | """
18 | reduction_enum = F._Reduction.get_enum(reduction)
19 | # none: 0, elementwise_mean:1, sum: 2
20 | if reduction_enum == 0:
21 | return loss
22 | elif reduction_enum == 1:
23 | return loss.mean()
24 | elif reduction_enum == 2:
25 | return loss.sum()
26 |
27 |
28 | def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
29 | """Apply element-wise weight and reduce loss.
30 |
31 | Args:
32 | loss (Tensor): Element-wise loss.
33 | weight (Tensor): Element-wise weights.
34 | reduction (str): Same as built-in losses of PyTorch.
35 |         avg_factor (float): Average factor when computing the mean of losses.
36 |
37 | Returns:
38 | Tensor: Processed loss values.
39 | """
40 | # if weight is specified, apply element-wise weight
41 | if weight is not None:
42 | loss = loss * weight
43 |
44 | # if avg_factor is not specified, just reduce the loss
45 | if avg_factor is None:
46 | loss = reduce_loss(loss, reduction)
47 | else:
48 | # if reduction is mean, then average the loss by avg_factor
49 | if reduction == 'mean':
50 | loss = loss.sum() / avg_factor
51 | # if reduction is 'none', then do nothing, otherwise raise an error
52 | elif reduction != 'none':
53 | raise ValueError('avg_factor can not be used with reduction="sum"')
54 | return loss
55 |
56 |
57 | def cross_entropy(pred, label, weight=None, reduction='mean', avg_factor=None):
58 | """Calculate the CrossEntropy loss.
59 |
60 | Args:
61 | pred (torch.Tensor): The prediction with shape (N, C), C is the number
62 | of classes.
63 | label (torch.Tensor): The gt label of the prediction.
64 | weight (torch.Tensor, optional): Sample-wise loss weight.
65 | reduction (str): The method used to reduce the loss.
66 | avg_factor (int, optional): Average factor that is used to average
67 | the loss. Defaults to None.
68 |
69 | Returns:
70 | torch.Tensor: The calculated loss
71 | """
72 | # element-wise losses
73 | loss = F.cross_entropy(pred, label, reduction='none')
74 |
75 | # apply weights and do the reduction
76 | if weight is not None:
77 | weight = weight.float()
78 | loss = weight_reduce_loss(
79 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
80 |
81 | return loss
82 |
83 |
84 | def soft_cross_entropy(pred,
85 | label,
86 | weight=None,
87 | reduction='mean',
88 | avg_factor=None):
89 | """Calculate the Soft CrossEntropy loss. The label can be float.
90 |
91 | Args:
92 | pred (torch.Tensor): The prediction with shape (N, C), C is the number
93 | of classes.
94 | label (torch.Tensor): The gt label of the prediction with shape (N, C).
95 | When using "mixup", the label can be float.
96 | weight (torch.Tensor, optional): Sample-wise loss weight.
97 | reduction (str): The method used to reduce the loss.
98 | avg_factor (int, optional): Average factor that is used to average
99 | the loss. Defaults to None.
100 |
101 | Returns:
102 | torch.Tensor: The calculated loss
103 | """
104 | # element-wise losses
105 | loss = -label * F.log_softmax(pred, dim=-1)
106 | loss = loss.sum(dim=-1)
107 |
108 | # apply weights and do the reduction
109 | if weight is not None:
110 | weight = weight.float()
111 | loss = weight_reduce_loss(
112 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
113 |
114 | return loss
115 |
116 |
117 | def binary_cross_entropy(pred,
118 | label,
119 | weight=None,
120 | reduction='mean',
121 | avg_factor=None):
122 | """Calculate the binary CrossEntropy loss with logits.
123 |
124 | Args:
125 | pred (torch.Tensor): The prediction with shape (N, *).
126 | label (torch.Tensor): The gt label with shape (N, *).
127 | weight (torch.Tensor, optional): Element-wise weight of loss with shape
128 | (N, ). Defaults to None.
129 | reduction (str): The method used to reduce the loss.
130 |             Options are "none", "mean" and "sum". If reduction is 'none', the loss
131 |             has the same shape as pred and label. Defaults to 'mean'.
132 | avg_factor (int, optional): Average factor that is used to average
133 | the loss. Defaults to None.
134 |
135 | Returns:
136 | torch.Tensor: The calculated loss
137 | """
138 | assert pred.dim() == label.dim()
139 |
140 | loss = F.binary_cross_entropy_with_logits(pred, label, reduction='none')
141 |
142 | # apply weights and do the reduction
143 | if weight is not None:
144 | assert weight.dim() == 1
145 | weight = weight.float()
146 | if pred.dim() > 1:
147 | weight = weight.reshape(-1, 1)
148 | loss = weight_reduce_loss(
149 | loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
150 | return loss
151 |
152 | @LOSSES.register_module()
153 | class CELoss(nn.Module):
154 |
155 | def __init__(self, reduction='mean', loss_weight=1.0):
156 | super(CELoss, self).__init__()
157 | self.reduction = reduction
158 | self.loss_weight = loss_weight
159 |
160 | self.cls_criterion = cross_entropy
161 |
162 | def forward(self,
163 | cls_score,
164 | label,
165 | weight=None,
166 | avg_factor=None,
167 | reduction_override=None,
168 | **kwargs):
169 | assert reduction_override in (None, 'none', 'mean', 'sum')
170 | reduction = (
171 | reduction_override if reduction_override else self.reduction)
172 |
173 | n_pred_ch, n_target_ch = cls_score.shape[1], label.shape[1]
174 | if n_pred_ch == n_target_ch:
175 | # target is in the one-hot format, convert to BH[WD] format to calculate ce loss
176 | label = torch.argmax(label, dim=1)
177 | else:
178 | label = torch.squeeze(label, dim=1)
179 | label = label.long()
180 | loss_cls = self.loss_weight * self.cls_criterion(
181 | cls_score,
182 | label,
183 | weight,
184 | reduction=reduction,
185 | avg_factor=avg_factor,
186 | **kwargs)
187 | return loss_cls
188 |
189 |
190 | @LOSSES.register_module()
191 | class CrossEntropyLoss(nn.Module):
192 | """Cross entropy loss.
193 |
194 | Args:
195 |         sigmoid (bool): Whether the prediction uses sigmoid
196 |             instead of softmax. Defaults to False.
197 |         softmax (bool): Whether to use the soft version of CrossEntropyLoss.
198 |             Defaults to False.
199 | reduction (str): The method used to reduce the loss.
200 | Options are "none", "mean" and "sum". Defaults to 'mean'.
201 | loss_weight (float): Weight of the loss. Defaults to 1.0.
202 | """
203 |
204 | def __init__(self,
205 | sigmoid=False,
206 | softmax=False,
207 | reduction='mean',
208 | loss_weight=1.0):
209 | super(CrossEntropyLoss, self).__init__()
210 | self.use_sigmoid = sigmoid
211 | self.use_soft = softmax
212 | assert not (
213 | self.use_soft and self.use_sigmoid
214 | ), 'use_sigmoid and use_soft could not be set simultaneously'
215 |
216 | self.reduction = reduction
217 | self.loss_weight = loss_weight
218 |
219 | if self.use_sigmoid:
220 | self.cls_criterion = binary_cross_entropy
221 | elif self.use_soft:
222 | self.cls_criterion = soft_cross_entropy
223 | else:
224 | self.cls_criterion = cross_entropy
225 |
226 | def forward(self,
227 | cls_score,
228 | label,
229 | weight=None,
230 | avg_factor=None,
231 | reduction_override=None,
232 | **kwargs):
233 | assert reduction_override in (None, 'none', 'mean', 'sum')
234 | reduction = (
235 | reduction_override if reduction_override else self.reduction)
236 |
237 | n_pred_ch, n_target_ch = cls_score.shape[1], label.shape[1]
238 | if n_pred_ch == n_target_ch:
239 | label = torch.argmax(label, dim=1)
240 | else:
241 | label = torch.squeeze(label, dim=1)
242 | label = label.long()
243 |
244 | loss_cls = self.loss_weight * self.cls_criterion(
245 | cls_score,
246 | label,
247 | weight,
248 | reduction=reduction,
249 | avg_factor=avg_factor,
250 | **kwargs)
251 | return loss_cls
252 |
253 |
254 | @LOSSES.register_module()
255 | class MCTransAuxLoss(CrossEntropyLoss):
256 | def __init__(self,**kwargs):
257 | super(MCTransAuxLoss, self).__init__(**kwargs)
258 |
259 | def forward(self,
260 | cls_score,
261 | label,
262 | weight=None,
263 | avg_factor=None,
264 | reduction_override=None,
265 | **kwargs):
266 | assert reduction_override in (None, 'none', 'mean', 'sum')
267 |         # build multi-hot targets marking which classes appear in each label map
268 | num_classes = cls_score.shape[1]
269 | one_hot = []
270 | for l in label:
271 | one_hot.append(self.one_hot(torch.unique(l), num_classes=num_classes).sum(dim=0))
272 | label = torch.stack(one_hot)
273 |
274 | reduction = (
275 | reduction_override if reduction_override else self.reduction)
276 | loss_cls = self.loss_weight * self.cls_criterion(
277 | cls_score,
278 | label,
279 | weight,
280 | reduction=reduction,
281 | avg_factor=avg_factor,
282 | **kwargs)
283 | return loss_cls
284 |
285 | def one_hot(self, input, num_classes, dtype=torch.float):
286 | assert input.dim() > 0, "input should have dim of 1 or more."
287 |
288 |         # if 1D, add singleton dim at the end
289 | if input.dim() == 1:
290 | input = input.view(-1, 1)
291 |
292 | sh = list(input.shape)
293 |
294 |         assert sh[1] == 1, "labels should have a channel with length equal to one."
295 | sh[1] = num_classes
296 |
297 | o = torch.zeros(size=sh, dtype=dtype, device=input.device)
298 | labels = o.scatter_(dim=1, index=input.long(), value=1)
299 |
300 | return labels
--------------------------------------------------------------------------------
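
MCTransAuxLoss above converts a dense segmentation label map into a per-image multi-hot vector marking which classes are present, and scores the auxiliary classification logits against that vector. A small standalone sketch of the target construction (using torch.nn.functional.one_hot instead of the scatter-based helper; the label map is made up for illustration):

import torch
import torch.nn.functional as F

num_classes = 6
# one 4x4 label map; classes 0, 2 and 5 appear in it
seg = torch.tensor([[[0, 0, 2, 2],
                     [0, 5, 2, 2],
                     [0, 0, 0, 5],
                     [0, 0, 0, 0]]])

targets = []
for label_map in seg:
    present = torch.unique(label_map)                      # tensor([0, 2, 5])
    targets.append(F.one_hot(present, num_classes).sum(dim=0).float())
targets = torch.stack(targets)
print(targets)  # tensor([[1., 0., 1., 0., 0., 1.]])
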
/mctrans/models/losses/debug_focal.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 - 2021 MONAI Consortium
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | # Unless required by applicable law or agreed to in writing, software
7 | # distributed under the License is distributed on an "AS IS" BASIS,
8 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | # See the License for the specific language governing permissions and
10 | # limitations under the License.
11 |
12 | from typing import Optional, Union
13 |
14 | import torch
15 | import torch.nn.functional as F
16 | from torch.nn.modules.loss import _Loss
17 | from torch import Tensor
18 | from monai.utils import LossReduction
19 |
20 | from ..builder import LOSSES
21 |
22 | class _WeightedLoss(_Loss):
23 | def __init__(self, weight: Optional[Tensor] = None, size_average=None, reduce=None, reduction: str = 'mean') -> None:
24 | super(_WeightedLoss, self).__init__(size_average, reduce, reduction)
25 | self.register_buffer('weight', weight)
26 | @LOSSES.register_module()
27 | class FLoss(_WeightedLoss):
28 | """
29 | Reimplementation of the Focal Loss described in:
30 |
31 | - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
32 | - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
33 | Zhu et al., Medical Physics 2018
34 | """
35 |
36 | def __init__(
37 | self,
38 | gamma: float = 2.0,
39 | weight: Optional[torch.Tensor] = None,
40 | reduction: Union[LossReduction, str] = LossReduction.MEAN,
41 | ) -> None:
42 | """
43 | Args:
44 | gamma: value of the exponent gamma in the definition of the Focal loss.
45 | weight: weights to apply to the voxels of each class. If None no weights are applied.
46 | This corresponds to the weights `\alpha` in [1].
47 | reduction: {``"none"``, ``"mean"``, ``"sum"``}
48 | Specifies the reduction to apply to the output. Defaults to ``"mean"``.
49 |
50 | - ``"none"``: no reduction will be applied.
51 | - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
52 | - ``"sum"``: the output will be summed.
53 |
54 | Example:
55 | .. code-block:: python
56 |
57 | import torch
58 | from monai.losses import FocalLoss
59 |
60 | pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
61 | grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
62 | fl = FocalLoss()
63 | fl(pred, grnd)
64 |
65 | """
66 | super(FLoss, self).__init__(weight=weight, reduction=LossReduction(reduction).value)
67 | self.gamma = gamma
68 | self.weight: Optional[torch.Tensor] = None
69 |
70 | def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
71 | """
72 | Args:
73 | logits: the shape should be BCH[WD].
74 | where C (greater than 1) is the number of classes.
75 | Softmax over the logits is integrated in this module for improved numerical stability.
76 | target: the shape should be B1H[WD] or BCH[WD].
77 | If the target's shape is B1H[WD], the target that this loss expects should be a class index
78 | in the range [0, C-1] where C is the number of classes.
79 |
80 | Raises:
81 | ValueError: When ``target`` ndim differs from ``logits``.
82 | ValueError: When ``target`` channel is not 1 and ``target`` shape differs from ``logits``.
83 | ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
84 |
85 | """
86 | i = logits
87 | t = target
88 |
89 | if i.ndimension() != t.ndimension():
90 | raise ValueError(f"logits and target ndim must match, got logits={i.ndimension()} target={t.ndimension()}.")
91 |
92 | if t.shape[1] != 1 and t.shape[1] != i.shape[1]:
93 | raise ValueError(
94 | "target must have one channel or have the same shape as the logits. "
95 | "If it has one channel, it should be a class index in the range [0, C-1] "
96 | f"where C is the number of classes inferred from 'logits': C={i.shape[1]}. "
97 | )
98 | if i.shape[1] == 1:
99 | raise NotImplementedError("Single-channel predictions not supported.")
100 |
101 | # Change the shape of logits and target to
102 | # num_batch x num_class x num_voxels.
103 | if i.dim() > 2:
104 | i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W
105 | t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W or N,C,H*W
106 | else: # Compatibility with classification.
107 | i = i.unsqueeze(2) # N,C => N,C,1
108 | t = t.unsqueeze(2) # N,1 => N,1,1 or N,C,1
109 |
110 | # Compute the log proba (more stable numerically than softmax).
111 | logpt = F.log_softmax(i, dim=1) # N,C,H*W
112 | # Keep only log proba values of the ground truth class for each voxel.
113 | if target.shape[1] == 1:
114 | logpt = logpt.gather(1, t.long()) # N,C,H*W => N,1,H*W
115 | logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W
116 |
117 | # Get the proba
118 | pt = torch.exp(logpt) # N,H*W or N,C,H*W
119 |
120 | if self.weight is not None:
121 | self.weight = self.weight.to(i)
122 | # Convert the weight to a map in which each voxel
123 | # has the weight associated with the ground-truth label
124 | # associated with this voxel in target.
125 | at = self.weight[None, :, None] # C => 1,C,1
126 | at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W
127 | if target.shape[1] == 1:
128 | at = at.gather(1, t.long()) # selection of the weights => N,1,H*W
129 | at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W
130 | # Multiply the log proba by their weights.
131 | logpt = logpt * at
132 |
133 | # Compute the loss mini-batch.
134 | weight = torch.pow(-pt + 1.0, self.gamma)
135 | if target.shape[1] == 1:
136 | loss = torch.mean(-weight * logpt, dim=1) # N
137 | else:
138 | loss = torch.mean(-weight * t * logpt, dim=-1) # N,C
139 |
140 | if self.reduction == LossReduction.SUM.value:
141 | return loss.sum()
142 | if self.reduction == LossReduction.NONE.value:
143 | return loss
144 | if self.reduction == LossReduction.MEAN.value:
145 | return loss.mean()
146 | raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')
147 |
--------------------------------------------------------------------------------
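
For reference, a short usage sketch of the FLoss defined above on segmentation-shaped tensors (targets given as class indices of shape B x 1 x H x W); it assumes MONAI and torch are installed and uses random values purely for illustration:

import torch
from mctrans.models.losses import FLoss

logits = torch.randn(2, 4, 8, 8)             # B, C, H, W with C = 4 classes
target = torch.randint(0, 4, (2, 1, 8, 8))   # B, 1, H, W class indices

loss_fn = FLoss(gamma=2.0, reduction="mean")
print(loss_fn(logits, target))               # scalar focal loss
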
/mctrans/models/losses/monai.py:
--------------------------------------------------------------------------------
1 | from ..builder import LOSSES
2 | from monai.losses import DiceLoss, DiceCELoss, FocalLoss, DiceFocalLoss
3 |
4 |
5 | class DiceLoss(DiceLoss):
6 | def __init__(self, loss_weight=1.0, **kwargs):
7 | self.loss_weight = loss_weight
8 | super(DiceLoss, self).__init__(**kwargs)
9 |
10 | def forward(self, input, target):
11 | loss = self.loss_weight * super().forward(input=input, target=target)
12 | return loss
13 |
14 |
15 | class DiceCELoss(DiceCELoss):
16 | def __init__(self, loss_weight=1.0, **kwargs):
17 | self.loss_weight = loss_weight
18 | super(DiceCELoss, self).__init__(**kwargs)
19 |
20 | def forward(self, input, target):
21 | loss = self.loss_weight * super().forward(input=input, target=target)
22 | return loss
23 |
24 |
25 | LOSSES.register_module(name="DiceLoss", module=DiceLoss)
26 | LOSSES.register_module(name="FocalLoss", module=FocalLoss)
27 |
28 | LOSSES.register_module(name="DiceCELoss", module=DiceCELoss)
29 | LOSSES.register_module(name="DiceFocalLoss", module=DiceFocalLoss)
30 |
--------------------------------------------------------------------------------
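
The wrappers above only multiply the MONAI losses by loss_weight and register them in LOSSES so configs can refer to them by name. A minimal direct-use sketch (assuming MONAI is installed; softmax and the one-hot target layout are standard MONAI DiceCELoss options, and the shapes are illustrative):

import torch
import torch.nn.functional as F
from mctrans.models.losses import DiceCELoss

pred = torch.randn(2, 6, 16, 16)  # logits for 6 classes
target = F.one_hot(torch.randint(0, 6, (2, 16, 16)), num_classes=6)
target = target.permute(0, 3, 1, 2).float()  # one-hot, shape (B, C, H, W)

loss_fn = DiceCELoss(loss_weight=0.5, softmax=True)
print(loss_fn(pred, target))  # 0.5 * (Dice + CE)
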
/mctrans/models/ops/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/mctrans/models/ops/__init__.py
--------------------------------------------------------------------------------
/mctrans/models/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/mctrans/models/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
--------------------------------------------------------------------------------
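
ms_deform_attn_core_pytorch above is the pure-PyTorch reference path used for debugging the CUDA kernel. A tiny CPU sketch with arbitrary shapes (note that importing the module still needs the compiled MultiScaleDeformableAttention extension, since it is imported at module scope):

import torch
from mctrans.models.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D = 1, 2, 4                           # batch, heads, channels per head
spatial_shapes = [(4, 4), (2, 2)]           # two feature levels
S = sum(h * w for h, w in spatial_shapes)   # 20 flattened locations
Lq, L, P = 3, len(spatial_shapes), 2        # queries, levels, points per level

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P)
attention_weights /= attention_weights.sum((-1, -2), keepdim=True)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)  # torch.Size([1, 3, 8]) -> (N, Lq, M*D)
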
/mctrans/models/ops/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | python setup.py build install
11 |
--------------------------------------------------------------------------------
/mctrans/models/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/mctrans/models/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import warnings
14 | import math
15 |
16 | import torch
17 | from torch import nn
18 | import torch.nn.functional as F
19 | from torch.nn.init import xavier_uniform_, constant_
20 |
21 | from ..functions import MSDeformAttnFunction
22 |
23 |
24 | def _is_power_of_2(n):
25 | if (not isinstance(n, int)) or (n < 0):
26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
27 | return (n & (n-1) == 0) and n != 0
28 |
29 |
30 | class MSDeformAttn(nn.Module):
31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
32 | """
33 | Multi-Scale Deformable Attention Module
34 | :param d_model hidden dimension
35 | :param n_levels number of feature levels
36 | :param n_heads number of attention heads
37 | :param n_points number of sampling points per attention head per feature level
38 | """
39 | super().__init__()
40 | if d_model % n_heads != 0:
41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
42 | _d_per_head = d_model // n_heads
43 |         # _d_per_head should be a power of 2 for the most efficient CUDA implementation
44 |         if not _is_power_of_2(_d_per_head):
45 |             warnings.warn("Set d_model in MSDeformAttn so that the dimension of each attention head "
46 |                           "is a power of 2, which is more efficient in the CUDA implementation.")
47 |
48 | self.im2col_step = 64
49 |
50 | self.d_model = d_model
51 | self.n_levels = n_levels
52 | self.n_heads = n_heads
53 | self.n_points = n_points
54 |
55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
57 | self.value_proj = nn.Linear(d_model, d_model)
58 | self.output_proj = nn.Linear(d_model, d_model)
59 |
60 | self._reset_parameters()
61 |
62 | def _reset_parameters(self):
63 | constant_(self.sampling_offsets.weight.data, 0.)
64 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
65 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
66 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
67 | for i in range(self.n_points):
68 | grid_init[:, :, i, :] *= i + 1
69 | with torch.no_grad():
70 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
71 | constant_(self.attention_weights.weight.data, 0.)
72 | constant_(self.attention_weights.bias.data, 0.)
73 | xavier_uniform_(self.value_proj.weight.data)
74 | constant_(self.value_proj.bias.data, 0.)
75 | xavier_uniform_(self.output_proj.weight.data)
76 | constant_(self.output_proj.bias.data, 0.)
77 |
78 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
79 | """
80 | :param query (N, Length_{query}, C)
81 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
82 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
83 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
84 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
85 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
86 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
87 |
88 | :return output (N, Length_{query}, C)
89 | """
90 | N, Len_q, _ = query.shape
91 | N, Len_in, _ = input_flatten.shape
92 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
93 |
94 | value = self.value_proj(input_flatten)
95 | if input_padding_mask is not None:
96 | value = value.masked_fill(input_padding_mask[..., None], float(0))
97 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
98 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
101 | # N, Len_q, n_heads, n_levels, n_points, 2
102 | if reference_points.shape[-1] == 2:
103 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
104 | sampling_locations = reference_points[:, :, None, :, None, :] \
105 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
106 | elif reference_points.shape[-1] == 4:
107 | sampling_locations = reference_points[:, :, None, :, None, :2] \
108 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
109 | else:
110 | raise ValueError(
111 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
112 | output = MSDeformAttnFunction.apply(
113 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
114 | output = self.output_proj(output)
115 | return output
116 |
--------------------------------------------------------------------------------
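
A sketch of calling the MSDeformAttn module above directly, assuming the extension has been built (see make.sh) and a CUDA device is available; the sizes and the number of queries are illustrative:

import torch
from mctrans.models.ops.modules import MSDeformAttn

device = "cuda"
spatial_shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long, device=device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
len_in = int(spatial_shapes.prod(1).sum())         # 320 flattened locations

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).to(device)

src = torch.rand(2, len_in, 256, device=device)    # flattened multi-level features
query = torch.rand(2, 100, 256, device=device)     # e.g. 100 query tokens
reference_points = torch.rand(2, 100, 2, 2, device=device)  # in [0, 1], per query and level

out = attn(query, reference_points, src, spatial_shapes, level_start_index)
print(out.shape)  # torch.Size([2, 100, 256])
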
/mctrans/models/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | import os
10 | import glob
11 |
12 | import torch
13 |
14 | from torch.utils.cpp_extension import CUDA_HOME
15 | from torch.utils.cpp_extension import CppExtension
16 | from torch.utils.cpp_extension import CUDAExtension
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | requirements = ["torch", "torchvision"]
22 |
23 | def get_extensions():
24 | this_dir = os.path.dirname(os.path.abspath(__file__))
25 | extensions_dir = os.path.join(this_dir, "src")
26 |
27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30 |
31 | sources = main_file + source_cpu
32 | extension = CppExtension
33 | extra_compile_args = {"cxx": []}
34 | define_macros = []
35 |
36 | if torch.cuda.is_available() and CUDA_HOME is not None:
37 | extension = CUDAExtension
38 | sources += source_cuda
39 | define_macros += [("WITH_CUDA", None)]
40 | extra_compile_args["nvcc"] = [
41 | "-DCUDA_HAS_FP16=0",
42 | "-D__CUDA_NO_HALF_OPERATORS__",
43 | "-D__CUDA_NO_HALF_CONVERSIONS__",
44 | "-D__CUDA_NO_HALF2_OPERATORS__",
45 | ]
46 | else:
47 |         raise NotImplementedError('CUDA is not available')
48 |
49 | sources = [os.path.join(extensions_dir, s) for s in sources]
50 | include_dirs = [extensions_dir]
51 | ext_modules = [
52 | extension(
53 | "MultiScaleDeformableAttention",
54 | sources,
55 | include_dirs=include_dirs,
56 | define_macros=define_macros,
57 | extra_compile_args=extra_compile_args,
58 | )
59 | ]
60 | return ext_modules
61 |
62 | setup(
63 | name="MultiScaleDeformableAttention",
64 | version="1.0",
65 | author="Weijie Su",
66 | url="https://github.com/fundamentalvision/Deformable-DETR",
67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
68 | packages=find_packages(exclude=("configs", "tests",)),
69 | ext_modules=get_extensions(),
70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
72 |
--------------------------------------------------------------------------------
/mctrans/models/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | 
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 |     AT_ERROR("Not implemented on the CPU");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 |     AT_ERROR("Not implemented on the CPU");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/mctrans/models/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/mctrans/models/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 | 
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 |
20 | at::Tensor ms_deform_attn_cuda_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step)
27 | {
28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 |
34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 |
40 | const int batch = value.size(0);
41 | const int spatial_size = value.size(1);
42 | const int num_heads = value.size(2);
43 | const int channels = value.size(3);
44 |
45 | const int num_levels = spatial_shapes.size(0);
46 |
47 | const int num_query = sampling_loc.size(1);
48 | const int num_point = sampling_loc.size(4);
49 |
50 | const int im2col_step_ = std::min(batch, im2col_step);
51 |
52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 |
54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 |
56 | const int batch_n = im2col_step_;
57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 | auto per_value_size = spatial_size * num_heads * channels;
59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 | for (int n = 0; n < batch/im2col_step_; ++n)
62 | {
63 | auto columns = output_n.select(0, n);
64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 |                 spatial_shapes.data<int64_t>(),
68 |                 level_start_index.data<int64_t>(),
69 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 |                 columns.data<scalar_t>());
73 |
74 | }));
75 | }
76 |
77 | output = output.view({batch, num_query, num_heads*channels});
78 |
79 | return output;
80 | }
81 |
82 |
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 | const at::Tensor &value,
85 | const at::Tensor &spatial_shapes,
86 | const at::Tensor &level_start_index,
87 | const at::Tensor &sampling_loc,
88 | const at::Tensor &attn_weight,
89 | const at::Tensor &grad_output,
90 | const int im2col_step)
91 | {
92 |
93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 |
100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 |
107 | const int batch = value.size(0);
108 | const int spatial_size = value.size(1);
109 | const int num_heads = value.size(2);
110 | const int channels = value.size(3);
111 |
112 | const int num_levels = spatial_shapes.size(0);
113 |
114 | const int num_query = sampling_loc.size(1);
115 | const int num_point = sampling_loc.size(4);
116 |
117 | const int im2col_step_ = std::min(batch, im2col_step);
118 |
119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 |
121 | auto grad_value = at::zeros_like(value);
122 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 | auto grad_attn_weight = at::zeros_like(attn_weight);
124 |
125 | const int batch_n = im2col_step_;
126 | auto per_value_size = spatial_size * num_heads * channels;
127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 |
131 | for (int n = 0; n < batch/im2col_step_; ++n)
132 | {
133 | auto grad_output_g = grad_output_n.select(0, n);
134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 |                 grad_output_g.data<scalar_t>(),
137 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 |                 spatial_shapes.data<int64_t>(),
139 |                 level_start_index.data<int64_t>(),
140 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 |                 grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 |                 grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 |                 grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 |
147 | }));
148 | }
149 |
150 | return {
151 | grad_value, grad_sampling_loc, grad_attn_weight
152 | };
153 | }
--------------------------------------------------------------------------------
/mctrans/models/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/mctrans/models/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cpu/ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/mctrans/models/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include "ms_deform_attn.h"
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 |
--------------------------------------------------------------------------------
/mctrans/models/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import time
14 | import torch
15 | import torch.nn as nn
16 | from torch.autograd import gradcheck
17 |
18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19 |
20 |
21 | N, M, D = 1, 2, 2
22 | Lq, L, P = 2, 2, 2
23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25 | S = sum([(H*W).item() for H, W in shapes])
26 |
27 |
28 | torch.manual_seed(3)
29 |
30 |
31 | @torch.no_grad()
32 | def check_forward_equal_with_pytorch_double():
33 | value = torch.rand(N, S, M, D).cuda() * 0.01
34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37 | im2col_step = 2
38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40 | fwdok = torch.allclose(output_cuda, output_pytorch)
41 | max_abs_err = (output_cuda - output_pytorch).abs().max()
42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43 |
44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45 |
46 |
47 | @torch.no_grad()
48 | def check_forward_equal_with_pytorch_float():
49 | value = torch.rand(N, S, M, D).cuda() * 0.01
50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53 | im2col_step = 2
54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57 | max_abs_err = (output_cuda - output_pytorch).abs().max()
58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59 |
60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61 |
62 |
63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64 |
65 | value = torch.rand(N, S, M, channels).cuda() * 0.01
66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69 | im2col_step = 2
70 | func = MSDeformAttnFunction.apply
71 |
72 | value.requires_grad = grad_value
73 | sampling_locations.requires_grad = grad_sampling_loc
74 | attention_weights.requires_grad = grad_attn_weight
75 |
76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77 |
78 | print(f'* {gradok} check_gradient_numerical(D={channels})')
79 |
80 |
81 | if __name__ == '__main__':
82 | check_forward_equal_with_pytorch_double()
83 | check_forward_equal_with_pytorch_float()
84 |
85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86 | check_gradient_numerical(channels, True, True, True)
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/mctrans/models/segmentors/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseSegmentor
2 | from .encoder_decoder import EncoderDecoder
3 |
4 | __all__ = ['BaseSegmentor', "EncoderDecoder"]
5 |
--------------------------------------------------------------------------------
/mctrans/models/segmentors/base.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | from abc import ABCMeta, abstractmethod
4 | from collections import OrderedDict
5 |
6 | import mmcv
7 | import numpy as np
8 | import torch
9 | import torch.distributed as dist
10 | import torch.nn as nn
11 | from mmcv.runner import auto_fp16
12 |
13 |
14 | class BaseSegmentor(nn.Module):
15 | """Base class for segmentors."""
16 |
17 | __metaclass__ = ABCMeta
18 |
19 | def __init__(self):
20 | super(BaseSegmentor, self).__init__()
21 | self.fp16_enabled = False
22 |
23 | @property
24 | def with_center(self):
25 |         """bool: whether the segmentor has a center module"""
26 | return hasattr(self, 'center') and self.center is not None
27 |
28 | @property
29 | def with_auxiliary_head(self):
30 | """bool: whether the segmentor has auxiliary head"""
31 | return hasattr(self,
32 | 'aux_head') and self.aux_head is not None
33 |
34 | @property
35 | def with_decode_head(self):
36 | """bool: whether the segmentor has decode head"""
37 | return hasattr(self, 'decode_head') and self.decode_head is not None
38 |
39 | @abstractmethod
40 | def extract_feat(self, imgs):
41 | """Placeholder for extract features from images."""
42 | pass
43 |
44 | @abstractmethod
45 | def encode_decode(self, img, img_metas):
46 | """Placeholder for encode images with backbone and decode into a
47 | semantic segmentation map of the same size as input."""
48 | pass
49 |
50 | @abstractmethod
51 | def forward_train(self, imgs, img_metas, **kwargs):
52 | """Placeholder for Forward function for training."""
53 | pass
54 |
55 | @abstractmethod
56 | def simple_test(self, img, img_meta, **kwargs):
57 | """Placeholder for single image test."""
58 | pass
59 |
60 | @abstractmethod
61 | def aug_test(self, imgs, img_metas, **kwargs):
62 | """Placeholder for augmentation test."""
63 | pass
64 |
65 | def init_weights(self, pretrained=None):
66 | """Initialize the weights in segmentor.
67 |
68 | Args:
69 | pretrained (str, optional): Path to pre-trained weights.
70 | Defaults to None.
71 | """
72 | if pretrained is not None:
73 | logger = logging.getLogger()
74 | logger.info(f'load model from: {pretrained}')
75 |
76 | def forward_test(self, imgs, img_metas, **kwargs):
77 | """
78 | Args:
79 | imgs (List[Tensor]): the outer list indicates test-time
80 | augmentations and inner Tensor should have a shape NxCxHxW,
81 | which contains all images in the batch.
82 | img_metas (List[List[dict]]): the outer list indicates test-time
83 | augs (multiscale, flip, etc.) and the inner list indicates
84 | images in a batch.
85 | """
86 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
87 | if not isinstance(var, list):
88 | raise TypeError(f'{name} must be a list, but got '
89 | f'{type(var)}')
90 |
91 | num_augs = len(imgs)
92 | if num_augs != len(img_metas):
93 | raise ValueError(f'num of augmentations ({len(imgs)}) != '
94 | f'num of image meta ({len(img_metas)})')
95 |         # all images in the same aug batch must share the same ori_shape,
96 |         # img_shape and pad_shape
97 | for img_meta in img_metas:
98 | ori_shapes = [_['ori_shape'] for _ in img_meta]
99 | assert all(shape == ori_shapes[0] for shape in ori_shapes)
100 | img_shapes = [_['img_shape'] for _ in img_meta]
101 | assert all(shape == img_shapes[0] for shape in img_shapes)
102 | pad_shapes = [_['pad_shape'] for _ in img_meta]
103 | assert all(shape == pad_shapes[0] for shape in pad_shapes)
104 |
105 | if num_augs == 1:
106 | return self.simple_test(imgs[0], img_metas[0], **kwargs)
107 | else:
108 | return self.aug_test(imgs, img_metas, **kwargs)
109 |
110 | @auto_fp16(apply_to=('img', ))
111 | def forward(self, img, img_metas, return_loss=True, **kwargs):
112 | """Calls either :func:`forward_train` or :func:`forward_test` depending
113 | on whether ``return_loss`` is ``True``.
114 |
115 | Note this setting will change the expected inputs. When
116 | ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
117 |         and List[dict]), and when ``return_loss=False``, img and img_meta
118 | should be double nested (i.e. List[Tensor], List[List[dict]]), with
119 | the outer list indicating test time augmentations.
120 | """
121 | if return_loss:
122 | return self.forward_train(img, img_metas, **kwargs)
123 | else:
124 | return self.forward_test(img, img_metas, **kwargs)
125 |
126 | def train_step(self, data_batch, optimizer, **kwargs):
127 | """The iteration step during training.
128 |
129 | This method defines an iteration step during training, except for the
130 | back propagation and optimizer updating, which are done in an optimizer
131 | hook. Note that in some complicated cases or models, the whole process
132 | including back propagation and optimizer updating is also defined in
133 | this method, such as GAN.
134 |
135 | Args:
136 |             data_batch (dict): The output of dataloader.
137 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
138 | runner is passed to ``train_step()``. This argument is unused
139 | and reserved.
140 |
141 | Returns:
142 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``,
143 | ``num_samples``.
144 | ``loss`` is a tensor for back propagation, which can be a
145 | weighted sum of multiple losses.
146 | ``log_vars`` contains all the variables to be sent to the
147 | logger.
148 | ``num_samples`` indicates the batch size (when the model is
149 | DDP, it means the batch size on each GPU), which is used for
150 | averaging the logs.
151 | """
152 | losses = self(**data_batch)
153 | loss, log_vars = self._parse_losses(losses)
154 |
155 | outputs = dict(
156 | loss=loss,
157 | log_vars=log_vars,
158 | num_samples=len(data_batch['img'].data))
159 |
160 | return outputs
161 |
162 | def val_step(self, data_batch, **kwargs):
163 | """The iteration step during validation.
164 |
165 |         This method shares the same signature as :func:`train_step`, but is used
166 | during val epochs. Note that the evaluation after training epochs is
167 | not implemented with this method, but an evaluation hook.
168 | """
169 | output = self(**data_batch, **kwargs)
170 | return output
171 |
172 | @staticmethod
173 | def _parse_losses(losses):
174 | """Parse the raw outputs (losses) of the network.
175 |
176 | Args:
177 | losses (dict): Raw output of the network, which usually contain
178 | losses and other necessary information.
179 |
180 | Returns:
181 | tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor
182 | which may be a weighted sum of all losses, log_vars contains
183 | all the variables to be sent to the logger.
184 | """
185 |
186 | log_vars = OrderedDict()
187 | for loss_name, loss_value in losses.items():
188 | if isinstance(loss_value, torch.Tensor):
189 | log_vars[loss_name] = loss_value.mean()
190 | elif isinstance(loss_value, list):
191 | log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
192 | else:
193 | raise TypeError(
194 | f'{loss_name} is not a tensor or list of tensors')
195 |
196 | loss = sum(_value for _key, _value in log_vars.items() if 'Loss' in _key)
197 |
198 | log_vars['Loss'] = loss
199 | for loss_name, loss_value in log_vars.items():
200 | # reduce loss when distributed training
201 | if dist.is_available() and dist.is_initialized():
202 | loss_value = loss_value.data.clone()
203 | dist.all_reduce(loss_value.div_(dist.get_world_size()))
204 | log_vars[loss_name] = loss_value.item()
205 |
206 | return loss, log_vars
207 |
208 | def show_result(self,
209 | img,
210 | result,
211 | palette=None,
212 | win_name='',
213 | show=False,
214 | wait_time=0,
215 | out_file=None):
216 | """Draw `result` over `img`.
217 |
218 | Args:
219 | img (str or Tensor): The image to be displayed.
220 | result (Tensor): The semantic segmentation results to draw over
221 | `img`.
222 |             palette (list[list[int]] | np.ndarray | None): The palette of
223 |                 the segmentation map. If None is given, a random palette will be
224 | generated. Default: None
225 | win_name (str): The window name.
226 | wait_time (int): Value of waitKey param.
227 | Default: 0.
228 | show (bool): Whether to show the image.
229 | Default: False.
230 | out_file (str or None): The filename to write the image.
231 | Default: None.
232 |
233 | Returns:
234 |                 img (Tensor): The drawn image; returned only if neither `show` nor `out_file` is given.
235 | """
236 | img = mmcv.imread(img)
237 | img = img.copy()
238 | seg = result[0]
239 | if palette is None:
240 | if self.PALETTE is None:
241 | palette = np.random.randint(
242 | 0, 255, size=(len(self.CLASSES), 3))
243 | else:
244 | palette = self.PALETTE
245 | palette = np.array(palette)
246 | assert palette.shape[0] == len(self.CLASSES)
247 | assert palette.shape[1] == 3
248 | assert len(palette.shape) == 2
249 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
250 | for label, color in enumerate(palette):
251 | color_seg[seg == label, :] = color
252 | # convert to BGR
253 | color_seg = color_seg[..., ::-1]
254 |
255 | img = img * 0.5 + color_seg * 0.5
256 | img = img.astype(np.uint8)
257 | # if out_file specified, do not show image in window
258 | if out_file is not None:
259 | show = False
260 |
261 | if show:
262 | mmcv.imshow(img, win_name, wait_time)
263 | if out_file is not None:
264 | mmcv.imwrite(img, out_file)
265 |
266 | if not (show or out_file):
267 |             warnings.warn('show is False and out_file is not specified; only '
268 |                           'the result image will be returned')
269 | return img
270 |
--------------------------------------------------------------------------------
/mctrans/models/segmentors/encoder_decoder.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from .base import BaseSegmentor
4 | from ..builder import build_network, build_losses, MODEL, build_encoder, build_decoder, build_head, build_center
5 | from ...data.transforms.utils import resize
6 | from ...metrics import build_metrics
7 |
8 |
9 | @MODEL.register_module()
10 | class EncoderDecoder(BaseSegmentor):
11 | def __init__(self,
12 | encoder,
13 | decoder,
14 | seg_head,
15 | center=None,
16 | aux_head=None,
17 | pretrained=None):
18 | super(EncoderDecoder, self).__init__()
19 | self.encoder = build_encoder(encoder)
20 | self.decoder = build_decoder(decoder)
21 | self.seg_head = build_head(seg_head)
22 |
23 | if center is not None:
24 | self.center = build_center(center)
25 | if aux_head is not None:
26 | self.aux_head = build_head(aux_head)
27 |
28 | self.init_weights(pretrained=pretrained)
29 |
30 | def init_weights(self, pretrained=None):
31 | """Initialize the weights in backbone and heads.
32 | Args:
33 | pretrained (str, optional): Path to pre-trained weights.
34 | Defaults to None.
35 | """
36 |
37 | super(EncoderDecoder, self).init_weights(pretrained)
38 | self.encoder.init_weights(pretrained=pretrained)
39 | self.decoder.init_weights()
40 | self.seg_head.init_weights()
41 | if self.with_auxiliary_head:
42 | if isinstance(self.aux_head, nn.ModuleList):
43 | for aux_head in self.aux_head:
44 | aux_head.init_weights()
45 | else:
46 | self.aux_head.init_weights()
47 |
48 | def extract_feat(self, img):
49 | x = self.encoder(img)
50 | if self.with_center:
51 | x = self.center(x)
52 | return x
53 |
54 | def encode_decode(self, img, img_metas, rescale=True):
55 | """Encode images with backbone and decode into a semantic segmentation
56 | map of the same size as input."""
57 | x = self.extract_feat(img)
58 |         pred = self._decode_head_forward(x, return_loss=False)
59 |         # TODO: should evaluate on more datasets
60 | if rescale:
61 |             if "height" not in img_metas[0]:
62 | re_size = img_metas[0]["spatial_shape"][:2]
63 | else:
64 | re_size = (img_metas[0]['height'], img_metas[0]['width'])
65 | pred = resize(
66 | pred,
67 | size=re_size,
68 | mode='nearest',
69 | align_corners=None,
70 | warning=False)
71 | return pred
72 |
73 | def _decode_head_forward(self, x, seg_label=None, return_loss=False):
74 | x = self.decoder(x)
75 | if return_loss:
76 | return self.seg_head.forward_train(x, seg_label)
77 | else:
78 | return self.seg_head.forward_test(x)
79 |
80 | def _auxiliary_head_forward(self, x, seg_label, return_loss=True):
81 | return self.aux_head.forward_train(x, seg_label)
82 |
83 | def forward_train(self, img, img_metas, seg_label, **kwargs):
84 |         # the img_metas may be useful in other frameworks
85 | x = self.extract_feat(img)
86 |
87 | losses = dict()
88 | loss_decode = self._decode_head_forward(x, seg_label, return_loss=True)
89 | losses.update(loss_decode)
90 |
91 | if self.with_auxiliary_head:
92 | loss_aux = self._auxiliary_head_forward(x, seg_label, return_loss=True)
93 | losses.update(loss_aux)
94 |
95 | return losses
96 |
97 | def forward_test(self, img, img_metas, rescale=True, **kwargs):
98 | # TODO support sliding window evaluator
99 | pred = self.encode_decode(img, img_metas, rescale)
100 | pred = pred.cpu().numpy()
101 | # unravel batch dim
102 | pred = list(pred)
103 | return pred
104 |
--------------------------------------------------------------------------------
/mctrans/models/trans/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JiYuanFeng/MCTrans/9b8b5677eef584b423d5e1630680a4b667cbe823/mctrans/models/trans/__init__.py
--------------------------------------------------------------------------------
/mctrans/models/trans/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from ..ops.modules import MSDeformAttn
4 | from .utils import _get_activation_fn, _get_clones
5 |
6 |
7 | class DSALayer(nn.Module):
8 | def __init__(self,
9 | d_model=256, d_ffn=1024,
10 | dropout=0.1, activation="relu",
11 | n_levels=4, n_heads=8, n_points=4):
12 | super().__init__()
13 |
14 | # self attention
15 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points)
16 | self.dropout1 = nn.Dropout(dropout)
17 | self.norm1 = nn.LayerNorm(d_model)
18 |
19 | # ffn
20 | self.linear1 = nn.Linear(d_model, d_ffn)
21 | self.activation = _get_activation_fn(activation)
22 | self.dropout2 = nn.Dropout(dropout)
23 | self.linear2 = nn.Linear(d_ffn, d_model)
24 | self.dropout3 = nn.Dropout(dropout)
25 | self.norm2 = nn.LayerNorm(d_model)
26 |
27 | @staticmethod
28 | def with_pos_embed(tensor, pos):
29 | return tensor if pos is None else tensor + pos
30 |
31 | def forward_ffn(self, src):
32 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
33 | src = src + self.dropout3(src2)
34 | src = self.norm2(src)
35 | return src
36 |
37 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None):
38 | # self attention
39 | src2 = self.self_attn(self.with_pos_embed(src, pos),
40 | reference_points,
41 | src,
42 | spatial_shapes,
43 | level_start_index,
44 | padding_mask)
45 | src = src + self.dropout1(src2)
46 | src = self.norm1(src)
47 | # ffn
48 | src = self.forward_ffn(src)
49 |
50 | return src
51 |
52 |
53 | class DSA(nn.Module):
54 | def __init__(self, att_layer, n_layers):
55 | super().__init__()
56 | self.layers = _get_clones(att_layer, n_layers)
57 |
58 | @staticmethod
59 | def get_reference_points(spatial_shapes, valid_ratios, device):
60 | reference_points_list = []
61 | for lvl, (H_, W_) in enumerate(spatial_shapes):
62 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
63 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
64 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
65 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
66 | ref = torch.stack((ref_x, ref_y), -1)
67 | reference_points_list.append(ref)
68 | reference_points = torch.cat(reference_points_list, 1)
69 | reference_points = reference_points[:, :, None] * valid_ratios[:, None]
70 | return reference_points
71 |
72 | def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
73 | output = src
74 | reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
75 | for _, layer in enumerate(self.layers):
76 | output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)
77 | return output
78 |
79 |
80 | class CALayer(nn.Module):
81 | def __init__(self, d_model=256, d_ffn=1024,
82 | dropout=0.1, activation="relu", n_heads=8):
83 | super().__init__()
84 | self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
85 | self.dropout1 = nn.Dropout(dropout)
86 | self.norm1 = nn.LayerNorm(d_model)
87 |
88 |         # cross attention
89 | self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
90 | self.dropout2 = nn.Dropout(dropout)
91 | self.norm2 = nn.LayerNorm(d_model)
92 |
93 | # ffn
94 | self.linear1 = nn.Linear(d_model, d_ffn)
95 | self.activation = _get_activation_fn(activation)
96 | self.dropout3 = nn.Dropout(dropout)
97 | self.linear2 = nn.Linear(d_ffn, d_model)
98 | self.dropout4 = nn.Dropout(dropout)
99 | self.norm3 = nn.LayerNorm(d_model)
100 |
101 | @staticmethod
102 | def with_pos_embed(tensor, pos):
103 | return tensor if pos is None else tensor + pos
104 |
105 | def forward_ffn(self, tgt):
106 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt))))
107 | tgt = tgt + self.dropout4(tgt2)
108 | tgt = self.norm3(tgt)
109 | return tgt
110 |
111 | def forward(self, tgt, src):
112 | # self attention
113 | q = k = tgt
114 | tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1)
115 | tgt = tgt + self.dropout2(tgt2)
116 | tgt = self.norm2(tgt)
117 | # cross attention
118 | tgt2 = self.cross_attn(tgt, src, src)[0]
119 | tgt = tgt + self.dropout1(tgt2)
120 | tgt = self.norm1(tgt)
121 | # ffn
122 | tgt = self.forward_ffn(tgt)
123 |
124 | return tgt
125 |
126 |
127 | class CA(nn.Module):
128 | def __init__(self, att_layer, n_layers, n_category=2, d_model=256):
129 | super().__init__()
130 | self.layers = _get_clones(att_layer, n_layers)
131 | self.proxy_embed = nn.Parameter(torch.zeros(1, n_category, d_model))
132 |
133 | def forward(self, src):
134 | query = None
135 | B = src.shape[0]
136 | for idx, layer in enumerate(self.layers):
137 | if idx == 0:
138 | query = self.proxy_embed.expand(B, -1, -1)
139 | else:
140 | query += self.proxy_embed.expand(B, -1, -1)
141 | query = layer(query, src)
142 | return query
--------------------------------------------------------------------------------
/mctrans/models/trans/utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------
6 | # Modified from DETR (https://github.com/facebookresearch/detr)
7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8 | # ------------------------------------------------------------------------
9 |
10 | """
11 | Various positional encodings for the transformer.
12 | """
13 | import copy
14 | import math
15 | import torch
16 | from typing import Optional
17 |
18 | from torch import nn, Tensor
19 | import torch.nn.functional as F
20 |
21 |
22 | def _get_clones(module, N):
23 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
24 |
25 |
26 | def _get_activation_fn(activation):
27 | """Return an activation function given a string"""
28 | if activation == "relu":
29 | return F.relu
30 | if activation == "gelu":
31 | return F.gelu
32 | if activation == "glu":
33 | return F.glu
34 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
35 |
36 |
37 | def build_position_encoding(position_embedding='sine', hidden_dim=240):
38 | N_steps = hidden_dim // 2
39 | if position_embedding in ('v2', 'sine'):
40 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
41 | elif position_embedding in ('v3', 'learned'):
42 | position_embedding = PositionEmbeddingLearned(N_steps)
43 | else:
44 | raise ValueError(f"not supported {position_embedding}")
45 |
46 | return position_embedding
47 |
48 |
49 | class NestedTensor(object):
50 | def __init__(self, tensors, mask: Optional[Tensor]):
51 | self.tensors = tensors
52 | self.mask = mask
53 |
54 | def to(self, device, non_blocking=False):
55 | # type: (Device) -> NestedTensor # noqa
56 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking)
57 | mask = self.mask
58 | if mask is not None:
59 | assert mask is not None
60 | cast_mask = mask.to(device, non_blocking=non_blocking)
61 | else:
62 | cast_mask = None
63 | return NestedTensor(cast_tensor, cast_mask)
64 |
65 | def record_stream(self, *args, **kwargs):
66 | self.tensors.record_stream(*args, **kwargs)
67 | if self.mask is not None:
68 | self.mask.record_stream(*args, **kwargs)
69 |
70 | def decompose(self):
71 | return self.tensors, self.mask
72 |
73 | def __repr__(self):
74 | return str(self.tensors)
75 |
76 |
77 | class DevPositionEmbeddingSine(nn.Module):
78 | """
79 | This is a more standard version of the position embedding, very similar to the one
80 | used by the Attention is all you need paper, generalized to work on images.
81 | """
82 |
83 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
84 | super().__init__()
85 | self.num_pos_feats = num_pos_feats
86 | self.temperature = temperature
87 | self.normalize = normalize
88 | if scale is not None and normalize is False:
89 | raise ValueError("normalize should be True if scale is passed")
90 | if scale is None:
91 | scale = 2 * math.pi
92 | self.scale = scale
93 |
94 | def forward(self, tensor_list: NestedTensor):
95 | x = tensor_list.tensors
96 | mask = tensor_list.mask
97 | assert mask is not None
98 | # (b, D, H, W)
99 | not_mask = ~mask
100 | z_embed = not_mask.cumsum(1, dtype=torch.float32)
101 | y_embed = not_mask.cumsum(2, dtype=torch.float32)
102 | x_embed = not_mask.cumsum(3, dtype=torch.float32)
103 | if self.normalize:
104 | eps = 1e-6
105 |             z_embed = (z_embed - 0.5) / (z_embed[:, -1:, :, :] + eps) * self.scale
106 | y_embed = (y_embed - 0.5) / (y_embed[:, :, -1:, :] + eps) * self.scale
107 | x_embed = (x_embed - 0.5) / (x_embed[:, :, :, -1:] + eps) * self.scale
108 |
109 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
110 | dim_t = self.temperature ** (2 * (dim_t // 3) / self.num_pos_feats)
111 |
112 | pos_x = x_embed[:, :, :, :, None] / dim_t
113 | pos_y = y_embed[:, :, :, :, None] / dim_t
114 | pos_z = z_embed[:, :, :, :, None] / dim_t
115 |
116 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
117 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
118 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
119 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 4, 1, 2, 3)
120 | return pos
121 |
122 |
123 | class PositionEmbeddingSine(nn.Module):
124 | """
125 | This is a more standard version of the position embedding, very similar to the one
126 | used by the Attention is all you need paper, generalized to work on images.
127 | """
128 |
129 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
130 | super().__init__()
131 | self.num_pos_feats = num_pos_feats
132 | self.temperature = temperature
133 | self.normalize = normalize
134 | if scale is not None and normalize is False:
135 | raise ValueError("normalize should be True if scale is passed")
136 | if scale is None:
137 | scale = 2 * math.pi
138 | self.scale = scale
139 |
140 | def forward(self, tensor_list: NestedTensor):
141 | x = tensor_list.tensors
142 | mask = tensor_list.mask
143 | assert mask is not None
144 | not_mask = ~mask
145 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
146 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
147 | if self.normalize:
148 | eps = 1e-6
149 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
150 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
151 |
152 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
153 | dim_t = self.temperature ** (2 * (dim_t // 3) / self.num_pos_feats)
154 |
155 | pos_x = x_embed[:, :, :, None] / dim_t
156 | pos_y = y_embed[:, :, :, None] / dim_t
157 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
158 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
159 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
160 | return pos
161 |
162 |
163 | class PositionEmbeddingLearned(nn.Module):
164 | """
165 | Absolute pos embedding, learned.
166 | """
167 |
168 | def __init__(self, num_pos_feats=256):
169 | super().__init__()
170 | self.row_embed = nn.Embedding(50, num_pos_feats)
171 | self.col_embed = nn.Embedding(50, num_pos_feats)
172 | self.reset_parameters()
173 |
174 | def reset_parameters(self):
175 | nn.init.uniform_(self.row_embed.weight)
176 | nn.init.uniform_(self.col_embed.weight)
177 |
178 | def forward(self, tensor_list: NestedTensor):
179 | x = tensor_list.tensors
180 | h, w = x.shape[-2:]
181 | i = torch.arange(w, device=x.device)
182 | j = torch.arange(h, device=x.device)
183 | x_emb = self.col_embed(i)
184 | y_emb = self.row_embed(j)
185 | pos = torch.cat([
186 | x_emb.unsqueeze(0).repeat(h, 1, 1),
187 | y_emb.unsqueeze(1).repeat(1, w, 1),
188 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
189 | return pos
190 |
191 |
--------------------------------------------------------------------------------
/mctrans/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def conv3x3(in_planes, out_planes, dilation=1):
5 | return nn.Conv2d(
6 | in_planes,
7 | out_planes,
8 | kernel_size=3,
9 | padding=dilation,
10 | dilation=dilation)
11 |
12 |
13 | def conv_bn_relu(in_channels, out_channels, kernel_size=3, padding=1, stride=1):
14 | return nn.Sequential(nn.Conv2d(in_channels,
15 | out_channels,
16 | kernel_size=kernel_size,
17 | padding=padding,
18 | stride=stride),
19 | nn.BatchNorm2d(out_channels),
20 | nn.ReLU(inplace=True))
21 |
22 |
23 | def deconv_bn_relu(in_channels, out_channels, kernel_size=3, padding=1, stride=2):
24 | return nn.Sequential(nn.ConvTranspose2d(in_channels,
25 | out_channels,
26 | kernel_size=kernel_size,
27 | padding=padding,
28 | stride=stride,
29 | output_padding=1),
30 | nn.BatchNorm2d(out_channels),
31 | nn.ReLU(inplace=True))
32 |
33 |
34 | def make_vgg_layer(inplanes,
35 | planes,
36 | num_blocks,
37 | dilation=1,
38 | with_bn=False,
39 | down_sample=False,
40 | ceil_mode=False):
41 | layers = []
42 | if down_sample:
43 | layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
44 | for _ in range(num_blocks):
45 | layers.append(conv3x3(inplanes, planes, dilation))
46 | if with_bn:
47 | layers.append(nn.BatchNorm2d(planes))
48 | layers.append(nn.ReLU(inplace=True))
49 | inplanes = planes
50 | return nn.Sequential(*layers)
51 |
--------------------------------------------------------------------------------
/mctrans/pipline/__init__.py:
--------------------------------------------------------------------------------
1 | from .segpipline import SegPipline
2 |
--------------------------------------------------------------------------------
/mctrans/pipline/segpipline.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
3 | from mmcv.runner import build_runner, build_optimizer
4 | from mmcv.runner import DistSamplerSeedHook, DistEvalHook, EvalHook
5 |
6 | from mctrans.data import build_dataset, build_dataloader
7 | from mctrans.models import build_model
8 | from mctrans.utils import get_root_logger
9 |
10 |
11 | class SegPipline(object):
12 | def __init__(self, cfg, distributed=False, validate=False):
13 | self.cfg = cfg
14 | self.distributed = distributed
15 | self.logger = get_root_logger(cfg.log_level)
16 | self.data_loaders = self._build_data_loader()
17 | self.model = self._build_model()
18 | # need to update
19 | self.optimizer = build_optimizer(self.model, cfg.optimizer)
20 | self.runner = self._build_runner()
21 | if validate:
22 | self._build_eval_hook()
23 | if cfg.resume_from:
24 | self.runner.resume(cfg.resume_from)
25 | elif cfg.load_from:
26 | self.runner.load_checkpoint(cfg.load_from)
27 | else:
28 | pass
29 |
30 | def _build_data_loader(self):
31 | dataset = [build_dataset(self.cfg.data.train)]
32 | if len(self.cfg.workflow) == 2:
33 | dataset.append(build_dataset(self.cfg.data.val))
34 | data_loaders = [
35 | build_dataloader(
36 | ds,
37 | self.cfg.data.samples_per_gpu,
38 | self.cfg.data.workers_per_gpu,
39 | len(self.cfg.gpu_ids),
40 | dist=self.distributed,
41 | seed=self.cfg.seed,
42 | drop_last=True) for ds in dataset]
43 | return data_loaders
44 |
45 | def _build_model(self):
46 | model = build_model(self.cfg.model)
47 | if self.distributed:
48 | find_unused_parameters = self.cfg.get('find_unused_parameters', True)
49 | # Sets the `find_unused_parameters` parameter in
50 | # torch.nn.parallel.DistributedDataParallel
51 | model = MMDistributedDataParallel(
52 | model.cuda(),
53 | device_ids=[torch.cuda.current_device()],
54 | broadcast_buffers=False,
55 | find_unused_parameters=find_unused_parameters)
56 | else:
57 | model = MMDataParallel(
58 | model.cuda(self.cfg.gpu_ids[0]), device_ids=self.cfg.gpu_ids)
59 | return model
60 |
61 | def _build_eval_hook(self):
62 | val_dataset = build_dataset(self.cfg.data.val)
63 | val_dataloader = build_dataloader(
64 | val_dataset,
65 | samples_per_gpu=1,
66 | workers_per_gpu=self.cfg.data.workers_per_gpu,
67 | dist=self.distributed,
68 | shuffle=False)
69 | eval_cfg = self.cfg.get('evaluation', {})
70 | eval_cfg['by_epoch'] = self.cfg.runner['type'] != 'IterBasedRunner'
71 | eval_hook = DistEvalHook if self.distributed else EvalHook
72 | self.runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
73 |
74 | def _build_runner(self):
75 | runner = build_runner(
76 | cfg=self.cfg.runner,
77 | default_args=dict(
78 | model=self.model,
79 | batch_processor=None,
80 | optimizer=self.optimizer,
81 | work_dir=self.cfg.work_dir,
82 | logger=self.logger,
83 | meta=None))
84 |
85 | # register hooks
86 | runner.register_training_hooks(self.cfg.lr_config, self.cfg.optimizer_config,
87 | self.cfg.checkpoint_config, self.cfg.log_config,
88 | self.cfg.get('momentum_config', None))
89 |
90 | if self.distributed:
91 | runner.register_hook(DistSamplerSeedHook())
92 |
93 | return runner
94 |
95 | def _report_details(self):
96 | self.logger.info(self.model)
97 |
98 | def run(self):
99 | self.runner.run(self.data_loaders, self.cfg.workflow, self.cfg.max_epochs)
100 |
--------------------------------------------------------------------------------
/mctrans/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import get_root_logger
2 |
--------------------------------------------------------------------------------
/mctrans/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from mmcv.utils import get_logger
4 |
5 |
6 | def get_root_logger(log_file=None, log_level=logging.INFO):
7 | """Get the root logger.
8 |
9 | The logger will be initialized if it has not been initialized. By default a
10 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
11 | also be added. The name of the root logger is the top-level package name,
12 |     e.g., "MCTrans".
13 |
14 | Args:
15 | log_file (str | None): The log filename. If specified, a FileHandler
16 | will be added to the root logger.
17 | log_level (int): The root logger level. Note that only the process of
18 | rank 0 is affected, while other processes will set the level to
19 | "Error" and be silent most of the time.
20 |
21 | Returns:
22 | logging.Logger: The root logger.
23 | """
24 |
25 | logger = get_logger(name='MCTrans', log_file=log_file, log_level=log_level)
26 |
27 | return logger
28 |
--------------------------------------------------------------------------------
/mctrans/utils/misc.py:
--------------------------------------------------------------------------------
1 | import random
2 | from typing import Optional
3 |
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | from torch import Tensor
8 |
9 | def sync_param(input, reduction='mean'):
10 | if isinstance(input, np.ndarray):
11 | sync_input = torch.from_numpy(input).cuda()
12 | elif isinstance(input, torch.Tensor):
13 | sync_input = input.clone()
14 | else:
15 | raise ValueError('input should be torch tensor or ndarray')
16 | dist.all_reduce(sync_input)
17 | if reduction == 'mean':
18 | sync_input.div_(dist.get_world_size())
19 | return sync_input
20 |
21 | def is_distributed():
22 | if dist.is_available() and dist.is_initialized():
23 | return True
24 | else:
25 | return False
26 |
27 |
28 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.13.0
2 | addict==2.4.0
3 | cachetools==4.2.2
4 | certifi==2021.5.30
5 | chardet==4.0.0
6 | future==0.18.2
7 | google-auth==1.32.1
8 | google-auth-oauthlib==0.4.4
9 | grpcio==1.38.1
10 | idna==2.10
11 | importlib-metadata==4.6.1
12 | joblib==1.0.1
13 | Markdown==3.3.4
14 | mmcv==1.3.9
15 | monai==0.5.3
16 | numpy==1.21.0
17 | oauthlib==3.1.1
18 | opencv-python==4.5.3.56
19 | Pillow==8.3.1
20 | prettytable==2.1.0
21 | protobuf==3.17.3
22 | pyasn1==0.4.8
23 | pyasn1-modules==0.2.8
24 | PyYAML==5.4.1
25 | requests==2.25.1
26 | requests-oauthlib==1.3.0
27 | rsa==4.7.2
28 | scikit-learn==0.24.2
29 | scipy==1.7.0
30 | six==1.16.0
31 | sklearn==0.0
32 | tensorboard==2.5.0
33 | tensorboard-data-server==0.6.1
34 | tensorboard-plugin-wit==1.8.0
35 | threadpoolctl==2.2.0
36 | torch==1.7.1+cu110
37 | torchaudio==0.7.2
38 | torchvision==0.8.2+cu110
39 | typing-extensions==3.10.0.0
40 | urllib3==1.26.6
41 | wcwidth==0.2.5
42 | Werkzeug==2.0.1
43 | yapf==0.31.0
44 | zipp==3.5.0
45 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 |
4 | setuptools.setup(
5 | name='MCTrans',
6 | version='0.0.1',
7 | author='Yuanfeng Ji, Jiamin Ren',
8 | author_email='u3008013@connect.hku.hk',
9 |     description='Simple Framework for Medical Image Analysis',
10 | url='https://github.com/JiYuanFeng/MCTrans',
11 | packages=setuptools.find_packages(),
12 | license='Internal',
13 | include_package_data=True,
14 | zip_safe=False,
15 | classifiers=[
16 | 'License :: OSI Approved :: Internal',
17 | 'Programming Language :: Python',
18 | 'Intended Audience :: Developers',
19 | 'Operating System :: OS Independent',
20 | ]
21 | )
22 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import os
4 | import os.path as osp
5 | import time
6 | import warnings
7 | warnings.filterwarnings("ignore")
8 |
9 | import mmcv
10 | import torch
11 | from mmcv.runner import init_dist, set_random_seed
12 | from mmcv.utils import Config, DictAction
13 |
14 | from mctrans.pipline import SegPipline
15 | from mctrans.utils import get_root_logger
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(description='Train a segmentation')
20 | parser.add_argument('config', help='train config file path')
21 | parser.add_argument('--work-dir', help='the dir to save logs and models')
22 | parser.add_argument(
23 | '--load-from', help='the checkpoint file to load weights from')
24 | parser.add_argument(
25 | '--resume-from', help='the checkpoint file to resume from')
26 | parser.add_argument(
27 | '--no-validate',
28 | action='store_true',
29 | help='whether not to evaluate the checkpoint during training')
30 | group_gpus = parser.add_mutually_exclusive_group()
31 | group_gpus.add_argument(
32 | '--gpus',
33 | type=int,
34 | help='number of gpus to use '
35 | '(only applicable to non-distributed training)')
36 | group_gpus.add_argument(
37 | '--gpu-ids',
38 | type=int,
39 | nargs='+',
40 | help='ids of gpus to use '
41 | '(only applicable to non-distributed training)')
42 | parser.add_argument('--seed', type=int, default=0, help='random seed')
43 | parser.add_argument(
44 | '--deterministic',
45 | action='store_true',
46 | help='whether to set deterministic options for CUDNN backend.')
47 | parser.add_argument(
48 | '--options', nargs='+', action=DictAction, help='custom options')
49 | parser.add_argument(
50 | '--launcher',
51 | choices=['none', 'pytorch', 'slurm', 'mpi'],
52 | default='none',
53 | help='job launcher')
54 | parser.add_argument('--local_rank', type=int, default=0)
55 |     parser.add_argument('--port', type=int, default=29500,
56 |                         help='port only works when launcher=="slurm"')
57 |
58 |     args = parser.parse_args()
59 |     if 'LOCAL_RANK' not in os.environ:
60 |         os.environ['LOCAL_RANK'] = str(args.local_rank)
61 |
62 |     return args
63 |
64 |
65 | def main():
66 | args = parse_args()
67 |
68 | cfg = Config.fromfile(args.config)
69 | if args.options is not None:
70 | cfg.merge_from_dict(args.options)
71 | # set cudnn_benchmark
72 | if cfg.get('cudnn_benchmark', False):
73 | torch.backends.cudnn.benchmark = True
74 |
75 | # work_dir is determined in this priority: CLI > segment in file > filename
76 | if args.work_dir is not None:
77 | # update configs according to CLI args if args.work_dir is not None
78 | cfg.work_dir = args.work_dir
79 | elif cfg.get('work_dir', None) is None:
80 | # use config filename as default work_dir if cfg.work_dir is None
81 | cfg.work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(args.config))[0])
82 |
83 | if args.load_from is not None:
84 | cfg.load_from = args.load_from
85 | if args.resume_from is not None:
86 | cfg.resume_from = args.resume_from
87 | if args.gpu_ids is not None:
88 | cfg.gpu_ids = args.gpu_ids
89 | else:
90 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
91 |
92 | # init distributed env first, since logger depends on the dist info.
93 | if args.launcher == 'none':
94 | distributed = False
95 | else:
96 | distributed = True
97 | init_dist(args.launcher, **cfg.dist_params)
98 |
99 | # create work_dir
100 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
101 | # dump config
102 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
103 | # init the logger before other steps
104 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
105 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
106 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
107 | # init the meta dict to record some important information such as
108 | # environment info and seed, which will be logged
109 | meta = dict()
110 |
111 | # log some basic info
112 | logger.info(f'Distributed training: {distributed}')
113 | logger.info(f'Config:\n{cfg.pretty_text}')
114 |
115 | # set random seeds
116 | if args.seed is not None:
117 | logger.info(f'Set random seed to {args.seed}, deterministic: '
118 | f'{args.deterministic}')
119 | set_random_seed(args.seed, deterministic=args.deterministic)
120 | cfg.seed = args.seed
121 | meta['seed'] = args.seed
122 | meta['exp_name'] = osp.basename(args.config)
123 |
124 | SegPipline(cfg, distributed, not args.no_validate).run()
125 |
126 | if __name__ == '__main__':
127 | main()
128 |
--------------------------------------------------------------------------------
/tools/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | CFG="configs/pannuke-vgg32/mctrans_vgg32_d5_256x256_400ep_pannuke.py"
3 | WORK_DIR=$(echo ${CFG%.*} | sed -e "s/configs/work_dirs/g")/
4 | python3 tools/train.py ${CFG} --work-dir ${WORK_DIR} --gpus 1
5 |
6 |
--------------------------------------------------------------------------------